
Remove max_gen_length argument in generate_text

duzx16, 3 years ago
parent commit 2948e546b1
2 changed files with 3 additions and 3 deletions
  1. evaluation/model.py (+2 -1)
  2. evaluation/tasks.py (+1 -2)

evaluation/model.py (+2 -1)

@@ -117,7 +117,7 @@ class ModelForEvaluation(torch.nn.Module):
                 log_probs.append(log_probs_single)
         return log_probs
 
-    def generate_text(self, sample, strategy, return_all_beams=False, max_gen_length=128) -> Union[
+    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[
         List[int], List[List[int]]]:
         """
        @return: A list of generated texts, sorted by score in descending order
@@ -128,6 +128,7 @@ class ModelForEvaluation(torch.nn.Module):
 
         def get_masks_and_position_ids(seq):
             batch_size = seq.shape[0]
+            max_gen_length = sample['target_position_ids'].shape[-1]
             tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
             position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1)
             position_ids = position_ids.to(device=torch.cuda.current_device()).long()
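
The added line infers the generation length from the batch itself rather than from a caller-supplied argument. Below is a minimal sketch of that idea (the dict keys mirror the diff; the helper name and the concrete shapes are illustrative assumptions, and the device transfer is omitted so the snippet runs on CPU):

import torch

# Sketch, not the repository's code: the batch already encodes the
# generation budget as the width of sample['target_position_ids'].
def infer_gen_length(sample, seq):
    # One target position id is provided per token still to be generated,
    # so the last dimension of target_position_ids is the generation length.
    max_gen_length = sample['target_position_ids'].shape[-1]
    # Pad the prompt with -1 placeholders that generation will overwrite
    # (the real code then moves the tensors to the current CUDA device).
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
    position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1)
    return tokens, position_ids

# Hypothetical batch: a 4-token prompt with a 3-token generation budget.
sample = {
    'position_ids': torch.arange(4).unsqueeze(0),            # shape [1, 4]
    'target_position_ids': torch.arange(4, 7).unsqueeze(0),  # shape [1, 3]
}
seq = torch.tensor([[10, 11, 12, 13]])
tokens, position_ids = infer_gen_length(sample, seq)
assert tokens.shape == (1, 7) and position_ids.shape == (1, 7)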

evaluation/tasks.py (+1 -2)

@@ -189,8 +189,7 @@ class GenerationTask(BaseTask, ABC):
     def predict_single_batch(self, batch) -> List[List[int]]:
         # micro batch size = 1 for generation task,
         # but we still need to return a list of predictions for consistency
-        output = self.model.generate_text(batch, self.strategy, return_all_beams=False,
-                                          max_gen_length=self.config.max_gen_length)
+        output = self.model.generate_text(batch, self.strategy, return_all_beams=False)
         return output
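
With this change, predict_single_batch no longer threads self.config.max_gen_length into the model call: the generation budget travels with the batch as the width of target_position_ids. Presumably the dataloader still consults config.max_gen_length when it constructs target_position_ids upstream, so the limit is configured in one place and the model's generate_text API stays decoupled from the task configuration.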