|
@@ -188,7 +188,7 @@ class BeamSearchStrategy:
|
|
|
for batch_idx in range(batch_size):
|
|
|
if not self._is_done[batch_idx]:
|
|
|
for i in range(batch_size):
|
|
|
- self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[i], batch_idx)
|
|
|
+ self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
|
|
|
mems = None
|
|
|
ret = self.end_beams[:batch_size]
|
|
|
else:
|