Bläddra i källkod

Fix finalize in BeamSearchStrategy

Zhengxiao Du 3 år sedan
förälder
incheckning
7be5ba1758
1 ändrade filer med 2 tillägg och 2 borttagningar
  1. 2 2
      generation/strategies.py

+ 2 - 2
generation/strategies.py

@@ -183,10 +183,10 @@ class BeamSearchStrategy:
 
     def finalize(self, tokens, mems):
         if self.consider_end:
-            batch_size = tokens.shape[0]
+            batch_size, num_beams = tokens.shape[:2]
             for batch_idx in range(batch_size):
                 if not self._is_done[batch_idx]:
-                    for i in range(batch_size):
+                    for i in range(num_beams):
                         self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
             mems = None
             ret = self.end_beams[:batch_size]