@@ -119,9 +119,9 @@ class BeamSearchStrategy:
         next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
         probs = F.softmax(next_token_scores, dim=-1)
+        if num_beams < self.num_beams: # First token
+            probs = probs[..., :vocab_size]
         if self.deterministic:
-            if num_beams < self.num_beams: # First token
-                probs = probs[..., :vocab_size]
             next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices # [2*nb]
         else:
             next_tokens = torch.multinomial(