Prechádzať zdrojové kódy

Add deterministic option in config and BaseStrategy

duzx16 2 rokov pred
rodič
commit
20feedb486
3 zmenil súbory, kde vykonal 12 pridanie a 5 odobranie
  1. 1 0
      evaluation/configs.py
  2. 2 2
      evaluation/tasks.py
  3. 9 3
      generation/strategies.py

+ 1 - 0
evaluation/configs.py

@@ -50,6 +50,7 @@ class GenerationTaskConfig(BaseConfig):
     no_repeat_ngram_size: int = 3
     min_gen_length: int = 0
     max_gen_length: int = 128
+    deterministic: bool = False
     end_tokens: List[str] = field(default_factory=lambda: [])
 
 

+ 2 - 2
evaluation/tasks.py

@@ -185,7 +185,7 @@ class GenerationTask(BaseTask, ABC):
             print_rank_0(f"End tokens {end_tokens}")
         if self.config.sampling_strategy == "BaseStrategy":
             self.strategy = BaseStrategy(batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1,
-                                         end_tokens=end_tokens)
+                                         end_tokens=end_tokens, deterministic=self.config.deterministic)
         elif self.config.sampling_strategy == "BeamSearchStrategy":
             self.strategy = BeamSearchStrategy(
                 self.config.micro_batch_size,
@@ -195,7 +195,7 @@ class GenerationTask(BaseTask, ABC):
                 end_tokens=end_tokens,
                 no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                 min_gen_length=self.config.min_gen_length,
-                deterministic=False,  # For evaluation, we need a determined generation strategy
+                deterministic=self.config.deterministic,  # For evaluation, a deterministic generation strategy is typically desired (now configurable)
             )
         else:
             raise ValueError(f"unknown strategy {self.config.sampling_strategy}")

+ 9 - 3
generation/strategies.py

@@ -4,7 +4,8 @@ import torch.nn.functional as F
 from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits
 
 class BaseStrategy:
-    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
+    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None,
+                 deterministic=False):
         self.batch_size = batch_size
         self.invalid_slices = invalid_slices
         self.temperature = temperature
@@ -14,6 +15,8 @@ class BaseStrategy:
         if end_tokens is None:
             end_tokens = []
         self.end_tokens = end_tokens
+        self.deterministic = deterministic
+        print(self.deterministic)
         self._is_done = np.zeros(self.batch_size, dtype=np.bool)
 
     @property
@@ -30,8 +33,11 @@ class BaseStrategy:
             logits[..., invalid_slice] = -65504
 
         logits = top_k_logits(logits, self.topk, self.top_p)
-        probs = F.softmax(logits.float(), dim=-1)  # float is essetial, due to a bug in Pytorch
-        pred = torch.multinomial(probs, num_samples=1)
+        if self.deterministic:
+            pred = logits.max(dim=-1)[1]
+        else:
+            probs = F.softmax(logits.float(), dim=-1)  # casting to float is essential, due to a bug in PyTorch
+            pred = torch.multinomial(probs, num_samples=1)
         for i in range(self.batch_size):
             if i >= batch_size:
                 self._is_done[i] = True