2
0
Эх сурвалжийг харах

Fix finalize in BeamSearchStrategy

Zhengxiao Du 3 жил өмнө
parent
commit
7be5ba1758

+ 2 - 2
generation/strategies.py

@@ -183,10 +183,10 @@ class BeamSearchStrategy:
 
     def finalize(self, tokens, mems):
         if self.consider_end:
-            batch_size = tokens.shape[0]
+            batch_size, num_beams = tokens.shape[:2]
             for batch_idx in range(batch_size):
                 if not self._is_done[batch_idx]:
-                    for i in range(batch_size):
+                    for i in range(num_beams):
                         self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
             mems = None
             ret = self.end_beams[:batch_size]