Browse Source

Implement unidirectional multichoice

duzx16 2 years ago
parent
commit
dff6780e6b
1 changed file with 44 additions and 13 deletions
  1. 44 13
      evaluation/dataset.py

+ 44 - 13
evaluation/dataset.py

@@ -371,41 +371,62 @@ class SmallMultiChoiceTaskDataset(MultiChoiceTaskDataset):
         return 3
         return 3
 
 
     @staticmethod
     @staticmethod
-    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False):
+    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False,
+                                     unidirectional=False):
         tokenizer = get_tokenizer()
         tokenizer = get_tokenizer()
         cls_id = tokenizer.get_command("ENC")
         cls_id = tokenizer.get_command("ENC")
         eos_id = tokenizer.get_command("eos")
         eos_id = tokenizer.get_command("eos")
         sop_id = tokenizer.get_command("sop")
         sop_id = tokenizer.get_command("sop")
-        mask_id = tokenizer.get_command("[MASK]")
+        mask_id = tokenizer.get_command("[MASK]") if not unidirectional else tokenizer.get_command("[gMASK]")
         blank_filling = mask_id in text
         blank_filling = mask_id in text
-        if not blank_filling:
-            text = text + [mask_id]
-        text = [cls_id] + text + [eos_id]
+        text_length = len(text)
+        last_token = text[-1]
+
+        if unidirectional:
+            assert not blank_filling
+            text = [cls_id, mask_id, eos_id, sop_id] + text[:-1]
+            position_id = np.array(list(range(3)) + [1] * text_length, dtype=np.int64)
+            block_position_id = np.array([0] * 3 + list(range(1, text_length + 1)), dtype=np.int64)
+        else:
+            if not blank_filling:
+                text = text + [mask_id]
+            text = [cls_id] + text + [eos_id]
+            position_id = np.arange(len(text), dtype=np.int64)
+            block_position_id = np.zeros(len(text), dtype=np.int64)
 
 
         token = np.array(text, dtype=np.int64)
         token = np.array(text, dtype=np.int64)
         target = np.array(text, dtype=np.int64)
         target = np.array(text, dtype=np.int64)
-        position_id = np.arange(len(text), dtype=np.int64)
-        block_position_id = np.zeros(len(text), dtype=np.int64)
         mask_position = text.index(mask_id)
         mask_position = text.index(mask_id)
         choice_target_id = []
         choice_target_id = []
 
 
-
         division = len(token)
         division = len(token)
-        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
+        if unidirectional:
+            attention_mask = [np.tril(np.ones((len(token), len(token)), dtype=np.int64))]
+            attention_mask[0][:3, :3] = 1
+        else:
+            attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
 
 
         for choice in choices:
         for choice in choices:
-            position_id = np.concatenate((position_id, [mask_position] * len(choice)))
-            block_position_id = np.concatenate((block_position_id, range(1, 1 + len(choice))))
             choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
             choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
             attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
             attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
-            token = np.concatenate((token, [sop_id], choice[:-1]))
+            position_id = np.concatenate((position_id, [mask_position] * len(choice)))
+            if unidirectional:
+                block_position_id = np.concatenate(
+                    (block_position_id, range(1 + text_length, 1 + text_length + len(choice))))
+                token = np.concatenate((token, [last_token], choice[:-1]))
+            else:
+                block_position_id = np.concatenate((block_position_id, range(1, 1 + len(choice))))
+                token = np.concatenate((token, [sop_id], choice[:-1]))
             target = np.concatenate((target, choice))
             target = np.concatenate((target, choice))
 
 
             if is_single_token:
             if is_single_token:
                 break
                 break
 
 
         attention_mask = block_diag(*attention_mask)
         attention_mask = block_diag(*attention_mask)
-        attention_mask[: len(token), :division] = 1
+        if unidirectional:
+            attention_mask[division:, :division] = 1
+        else:
+            attention_mask[: len(token), :division] = 1
 
 
         if is_single_token:
         if is_single_token:
             choices = np.array(choices, dtype=np.int64).squeeze().tolist()
             choices = np.array(choices, dtype=np.int64).squeeze().tolist()
@@ -421,6 +442,16 @@ class SmallMultiChoiceTaskDataset(MultiChoiceTaskDataset):
         }
         }
         return item
         return item
 
 
+    def __getitem__(self, idx):
+        item = self.data[idx]
+        return self.build_multiple_choice_sample(
+            item["text"],
+            item["choices"],
+            is_single_token=self.is_single_token,
+            unified_multitask_encoding=self.config.use_multitask_encoding,
+            unidirectional=self.config.unidirectional
+        )
+
 
 
 class LanguageModelTaskDataset(EvaluationDataset):
 class LanguageModelTaskDataset(EvaluationDataset):
     config: LanguageModelTaskConfig
     config: LanguageModelTaskConfig