Browse Source

Fix finalize in BeamSearchStrategy

Zhengxiao Du 3 years ago
parent
commit
7be5ba1758
1 changed files with 2 additions and 2 deletions
  1. 2 2
      generation/strategies.py

+ 2 - 2
generation/strategies.py

@@ -183,10 +183,10 @@ class BeamSearchStrategy:
 
 
     def finalize(self, tokens, mems):
     def finalize(self, tokens, mems):
         if self.consider_end:
         if self.consider_end:
-            batch_size = tokens.shape[0]
+            batch_size, num_beams = tokens.shape[:2]
             for batch_idx in range(batch_size):
             for batch_idx in range(batch_size):
                 if not self._is_done[batch_idx]:
                 if not self._is_done[batch_idx]:
-                    for i in range(batch_size):
+                    for i in range(num_beams):
                         self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
                         self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
             mems = None
             mems = None
             ret = self.end_beams[:batch_size]
             ret = self.end_beams[:batch_size]