소스 검색

Fix unidirectional language modeling evaluation

Sengxian 2 년 전
부모
커밋
66725ecf50
1개의 변경된 파일, 15개의 추가 및 13개의 삭제
  1. 15 13
      evaluation/dataset.py

+ 15 - 13
evaluation/dataset.py

@@ -79,6 +79,8 @@ class GenerationTaskDataset(EvaluationDataset):
 
     def process_single_item(self, item, **kwargs):
         text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
+        if len(targets) and (not isinstance(targets[0], list)):
+            targets = [targets]
         if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
             text_length = self.config.max_seq_length - self.config.max_gen_length - 2
             text = text[len(text) - text_length : len(text)]
@@ -102,8 +104,8 @@ class GenerationTaskDataset(EvaluationDataset):
             token_batch.append(token)
             position_id_batch.append(position_id)
             attention_mask_batch.append(attention_mask)
-            context_length_batch.append(sample['context_length'])
-            target_position_id_batch.append(sample['target_position_id'])
+            context_length_batch.append(sample["context_length"])
+            target_position_id_batch.append(sample["target_position_id"])
         return {
             "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
             "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
@@ -138,7 +140,7 @@ class GenerationTaskDataset(EvaluationDataset):
         position_id = np.arange(0, context_length, dtype=np.int64)
         target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
         if not use_task_mask:
-            position_id[context_length - 1:] = mask_position
+            position_id[context_length - 1 :] = mask_position
             target_position_id[:] = mask_position
 
         attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
@@ -222,16 +224,12 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         if tgt_seq_length != 1:
             self.is_single_token = False
 
-        return [{
-            "text": text,
-            "choices": choices,
-            "label": label,
-            **kwargs
-        }]
+        return [{"text": text, "choices": choices, "label": label, **kwargs}]
 
     @staticmethod
-    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False,
-                                     unidirectional=False, use_task_mask=False):
+    def build_multiple_choice_sample(
+        text, choices, is_single_token, unified_multitask_encoding=False, unidirectional=False, use_task_mask=False
+    ):
         tokenizer = get_tokenizer()
 
         sop_id = tokenizer.get_command("sop")
@@ -267,7 +265,7 @@ class MultiChoiceTaskDataset(EvaluationDataset):
 
         for choice in choices:
             if not choice:
-                choice = [tokenizer.get_command('eop')]
+                choice = [tokenizer.get_command("eop")]
             position_id = np.concatenate(
                 (
                     position_id,
@@ -368,10 +366,14 @@ class LanguageModelTaskDataset(EvaluationDataset):
         attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
         attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1
 
+        gen_length = min(len(text), self.config.generation_length)
         return {
             "tokens": np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64),
             "targets": np.array(prompt + [mask_id] + text, dtype=np.int64),
             "position_ids": np.arange(0, seq_length, dtype=np.int64),
             "attention_mask": attention_mask < 0.5,
-            "loss_masks": np.array([0] * (len(prompt) + 1) + [1] * len(text), dtype=np.int64),
+            "loss_masks": np.array(
+                [0] * (seq_length - gen_length) + [1] * gen_length,
+                dtype=np.int64,
+            ),
         }