
Implement unidirectional multiple choice

duzx16 2 years ago
commit 39abc5ee41

1 changed file with 43 additions and 19 deletions

evaluation/dataset.py  +43 -19

@@ -47,16 +47,8 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
         self.gmask_id = tokenizer.get_command("[gMASK]")
 
         self.data = []
-        if path.endswith("jsonl"):
-            with open(os.path.join(path), "r", encoding="utf-8") as file:
-                for line in file:
-                    item = json.loads(line)
-                    self.data.extend(self.process_single_item(item))
-        elif path.endswith("json"):
-            with open(os.path.join(path), "r", encoding="utf-8") as file:
-                dataset = json.load(file)
-            for item in dataset:
-                self.data.extend(self.process_single_item(item))
+        for p in self.path:
+            self.process_single_file(p)
 
     @property
     def has_collate_fn(self) -> bool:
@@ -65,6 +57,21 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
     def collate_fn(self, samples):
         return None
 
+    def process_single_file(self, path):
+        if not path.endswith("jsonl"):
+            try:
+                with open(os.path.join(path), "r", encoding="utf-8") as file:
+                    dataset = json.load(file)
+                for item in dataset:
+                    self.data.extend(self.process_single_item(item))
+                return
+            except json.decoder.JSONDecodeError:
+                pass
+        with open(os.path.join(path), "r", encoding="utf-8") as file:
+            for line in file:
+                item = json.loads(line)
+                self.data.extend(self.process_single_item(item))
+
     @abstractmethod
     def process_single_item(self, item, **kwargs) -> List[dict]:
         pass
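For reference, the new loading path no longer dispatches purely on the file extension: any path that does not end in "jsonl" is first parsed as a single JSON document, and a json.decoder.JSONDecodeError falls back to line-delimited (JSON Lines) parsing. A minimal standalone sketch of that fallback, using a hypothetical load_items helper that is not part of the commit:

    import json

    def load_items(path):
        """Try whole-file JSON first, then fall back to JSON Lines."""
        if not path.endswith("jsonl"):
            try:
                with open(path, "r", encoding="utf-8") as file:
                    return json.load(file)  # expected to be a list of items
            except json.decoder.JSONDecodeError:
                pass  # not a single JSON document, re-read it line by line
        with open(path, "r", encoding="utf-8") as file:
            return [json.loads(line) for line in file]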
@@ -201,11 +208,12 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         }]
 
     @staticmethod
-    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False):
+    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False,
+                                     unidirectional=False, use_task_mask=False):
         tokenizer = get_tokenizer()
 
         sop_id = tokenizer.get_command("sop")
-        mask_id = tokenizer.get_command("[MASK]")
+        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
 
         token = np.array(text, dtype=np.int64)
         target = np.array(text, dtype=np.int64)
@@ -214,15 +222,26 @@ class MultiChoiceTaskDataset(EvaluationDataset):
 
         blank_filling = mask_id in text
         if not blank_filling:
-            mask_position = len(token)
-            token = np.concatenate((token, [mask_id]))
-            target = np.concatenate((target, [mask_id]))
-            position_id = np.concatenate((position_id, [mask_position]))
+            if unidirectional:
+                assert use_task_mask
+                token = np.concatenate(([mask_id, sop_id], token[:-1]))
+                target = np.concatenate(([mask_id, sop_id], target[:-1]))
+                position_id = np.arange(len(token), dtype=np.int64)
+                mask_position = len(token)
+            else:
+                mask_position = len(token)
+                token = np.concatenate((token, [mask_id]))
+                target = np.concatenate((target, [mask_id]))
+                position_id = np.concatenate((position_id, [mask_position]))
         else:
+            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
+            assert not use_task_mask, "Blank filling doesn't support task mask"
             mask_position = text.index(mask_id)
 
         division = len(token)
         attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
+        if unidirectional:
+            attention_mask[0] = np.tril(attention_mask[0])
 
         for choice in choices:
             if not choice:
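The branch added above changes how a sample is laid out when unidirectional is set (which requires use_task_mask): instead of appending a mask token to a bidirectionally attended context, [gMASK] and sop are prepended, the context is shifted right by one token, and the context block of the attention mask is made causal with np.tril. A toy illustration of the two layouts, using made-up token ids rather than the commit's code:

    import numpy as np

    # Made-up ids for illustration: gMASK and sop commands plus a 4-token context.
    gmask_id, sop_id = 50007, 50006
    text = [1, 2, 3, 4]

    # Previous behaviour: append the mask token, keep full attention over the context.
    bi_token = np.concatenate((text, [gmask_id]))
    bi_attention = np.ones((len(bi_token), len(bi_token)), dtype=np.int64)

    # Unidirectional layout from this commit: prefix [gMASK] and sop, shift the
    # context right by one, number positions 0..n-1, and make the context causal.
    uni_token = np.concatenate(([gmask_id, sop_id], text[:-1]))
    uni_position = np.arange(len(uni_token), dtype=np.int64)
    uni_attention = np.tril(np.ones((len(uni_token), len(uni_token)), dtype=np.int64))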
@@ -231,20 +250,23 @@ class MultiChoiceTaskDataset(EvaluationDataset):
                 (
                     position_id,
                     [mask_position] * len(choice)
-                    if blank_filling or not unified_multitask_encoding
+                    if (blank_filling or not unified_multitask_encoding) and not use_task_mask
                     else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
                 )
             )
             choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
             attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
-            token = np.concatenate((token, [sop_id], choice[:-1]))
+            if unidirectional:
+                token = np.concatenate((token, [text[-1]], choice[:-1]))
+            else:
+                token = np.concatenate((token, [sop_id], choice[:-1]))
             target = np.concatenate((target, choice))
 
             if is_single_token:
                 break
 
         attention_mask = block_diag(*attention_mask)
-        attention_mask[: len(token), :division] = 1
+        attention_mask[division:, :division] = 1
 
         if is_single_token:
             choices = np.array(choices, dtype=np.int64).squeeze().tolist()
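The final mask assembly builds one attention block per segment (the shared context plus one causal block per choice) and then lets every choice token attend to the whole context. Moving the assignment from rows [: len(token)] to rows [division:] matters for the new mode: the old slice also covered the context rows and would have overwritten the lower-triangular context block with ones. A toy reconstruction of the assembly, with made-up sizes:

    import numpy as np
    from scipy.linalg import block_diag

    division = 3            # length of the shared context (toy value)
    choice_lens = [2, 2]    # two toy choices
    unidirectional = True

    # Context block: causal when unidirectional, fully visible otherwise;
    # each choice block is always causal.
    context = np.ones((division, division), dtype=np.int64)
    if unidirectional:
        context = np.tril(context)
    blocks = [context] + [np.tril(np.ones((n, n), dtype=np.int64)) for n in choice_lens]

    attention_mask = block_diag(*blocks)
    # Choice tokens may attend to the whole context, but not to other choices,
    # and the context block itself is left untouched.
    attention_mask[division:, :division] = 1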
@@ -265,6 +287,8 @@ class MultiChoiceTaskDataset(EvaluationDataset):
             item["choices"],
             is_single_token=self.is_single_token,
             unified_multitask_encoding=self.config.use_multitask_encoding,
+            unidirectional=self.config.unidirectional,
+            use_task_mask=self.config.use_task_mask,
         )
         sample["label"] = item["label"]
         return sample
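Finally, process_single_item now forwards two extra flags read from the evaluation config. A minimal sketch of a config object that satisfies the attributes referenced in this diff (the namespace itself is hypothetical; only the attribute names come from the code):

    from types import SimpleNamespace

    config = SimpleNamespace(
        use_multitask_encoding=False,  # existing flag, passed as unified_multitask_encoding
        unidirectional=True,           # new: prefix [gMASK]/sop and use a causal context
        use_task_mask=True,            # new: must be True whenever unidirectional is True
    )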