Sengxian před 2 roky
rodič
revize
1f32e7faa3
1 změnil soubory, kde provedl 2 přidání a 2 odebrání
  1. 2 2
      generation/strategies.py

+ 2 - 2
generation/strategies.py

@@ -119,9 +119,9 @@ class BeamSearchStrategy:
         next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
         probs = F.softmax(next_token_scores, dim=-1)
+        if num_beams < self.num_beams:  # First token
+            probs = probs[..., :vocab_size]
         if self.deterministic:
-            if num_beams < self.num_beams:  # First token
-                probs = probs[..., :vocab_size]
             next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices  # [2*nb]
         else:
             next_tokens = torch.multinomial(