Kaynağa Gözat

Remove redundant imports

Zhengxiao Du 3 yıl önce
ebeveyn
işleme
26543554f8
2 değiştirilmiş dosya ile 1 ekleme ve 2 silme
  1. 0 1
      generate.py
  2. 1 1
      generation/strategies.py

+ 0 - 1
generate.py

@@ -9,7 +9,6 @@ from typing import List, Tuple
 from SwissArmyTransformer import mpu
 from evaluation.model import batch_filling_sequence
 from generation import BeamSearchStrategy, BaseStrategy
-from generation import BeamSearchStrategy
 from SwissArmyTransformer.generation.utils import timed_name, generate_continually
 from initialize import initialize, initialize_model_and_tokenizer
 

+ 1 - 1
generation/strategies.py

@@ -124,7 +124,7 @@ class BeamSearchStrategy:
 
         probs = F.softmax(next_token_scores, dim=-1)
         if self.deterministic:
-            if mems.shape[2] < self.num_beams:  # First token
+            if num_beams < self.num_beams:  # First token
                 probs = probs[..., :vocab_size]
             next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices  # [2*nb]
         else: