|
@@ -183,10 +183,10 @@ class BeamSearchStrategy:
|
|
|
|
|
|
def finalize(self, tokens, mems):
|
|
|
if self.consider_end:
|
|
|
- batch_size = tokens.shape[0]
|
|
|
+ batch_size, num_beams = tokens.shape[:2]
|
|
|
for batch_idx in range(batch_size):
|
|
|
if not self._is_done[batch_idx]:
|
|
|
- for i in range(batch_size):
|
|
|
+ for i in range(num_beams):
|
|
|
self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
|
|
|
mems = None
|
|
|
ret = self.end_beams[:batch_size]
|