Sengxian 2 роки тому
батько
коміт
1f32e7faa3
1 змінених файлів з 2 додано та 2 видалено
  1. 2 2
      generation/strategies.py

+ 2 - 2
generation/strategies.py

@@ -119,9 +119,9 @@ class BeamSearchStrategy:
         next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
         probs = F.softmax(next_token_scores, dim=-1)
+        if num_beams < self.num_beams:  # First token
+            probs = probs[..., :vocab_size]
         if self.deterministic:
-            if num_beams < self.num_beams:  # First token
-                probs = probs[..., :vocab_size]
             next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices  # [2*nb]
         else:
             next_tokens = torch.multinomial(