
Support 2D position encoding for generation task

Sengxian 2 years ago
parent
commit
389fdf93b1
2 changed files with 63 additions and 43 deletions
  1. evaluation/dataset.py (+1 −43)
  2. evaluation/model.py (+62 −0)

+ 1 - 43
evaluation/dataset.py

@@ -119,51 +119,9 @@ class GenerationTaskDataset(EvaluationDataset):
             "target_position_ids": torch.tensor(np.array(target_position_id_batch), dtype=torch.int64),
         }
 
-    @staticmethod
-    def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
-        tokenizer = get_tokenizer()
-
-        sop_id = tokenizer.get_command("sop")
-        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
-
-        token = np.array(text, dtype=np.int64)
-
-        blank_filling = mask_id in text
-        if blank_filling:
-            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
-            assert not use_task_mask, "Unidirectional attention doesn't support task mask"
-            mask_position = text.index(mask_id)
-            token = np.concatenate((token, [sop_id]))
-        else:
-            mask_position = len(token)
-            if unidirectional:
-                token = np.concatenate(([mask_id, sop_id], token))
-            else:
-                token = np.concatenate((token, [mask_id, sop_id]))
-        context_length = len(token)
-
-        position_id = np.arange(0, context_length, dtype=np.int64)
-        target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
-        if not use_task_mask:
-            position_id[context_length - 1 :] = mask_position
-            target_position_id[:] = mask_position
-
-        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
-        if not unidirectional:
-            attention_mask[: context_length - 1, : context_length - 1] = 1
-
-        item = {
-            "token": token,
-            "position_id": position_id,
-            "target_position_id": target_position_id,
-            "attention_mask": attention_mask,
-            "context_length": context_length,
-        }
-        return item
-
     def __getitem__(self, idx):
         item = self.data[idx]
-        sample = self.build_generation_sample(
+        sample = self.model.build_generation_sample(
             item["text"],
             max_gen_length=self.config.max_gen_length,
             use_task_mask=self.config.use_task_mask,

+ 62 - 0
evaluation/model.py

@@ -248,6 +248,68 @@ class ModelForEvaluation(torch.nn.Module):
                 log_probs.append(log_probs_single)
         return log_probs
 
+    def build_generation_sample(self, text, max_gen_length, use_task_mask, unidirectional):
+        tokenizer = get_tokenizer()
+
+        sop_id = tokenizer.get_command("sop")
+        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
+
+        token = np.array(text, dtype=np.int64)
+        position_id = np.arange(len(text), dtype=np.int64)
+        block_position_id = np.zeros(len(text), dtype=np.int64)
+        target_position_id = np.zeros(len(text), dtype=np.int64)
+        target_block_position_id = np.zeros(len(text), dtype=np.int64)
+
+        blank_filling = mask_id in text
+
+        if unidirectional:
+            assert use_task_mask, "Unidirectional attention only supports gMASK"
+            assert not blank_filling, "Unidirectional attention doesn't support blank filling"
+            token = np.concatenate(([mask_id, sop_id], token))
+            if self.position_encoding_2d:
+                position_id = np.zeros(len(token), dtype=np.int64)
+                target_position_id = np.zeros(max_gen_length, dtype=np.int64)
+                block_position_id = np.arange(len(token), dtype=np.int64)
+                target_block_position_id = np.arange(len(token), len(token) + max_gen_length, dtype=np.int64)
+            else:
+                position_id = np.arange(len(token), dtype=np.int64)
+                target_position_id = np.arange(len(token), len(token) + max_gen_length, dtype=np.int64)
+        else:
+            if not blank_filling:
+                mask_position = len(token)
+                token = np.concatenate((token, [mask_id, sop_id]))
+            else:
+                assert not use_task_mask, "Blank filling only supports MASK"
+                mask_position = text.index(mask_id)
+                token = np.concatenate((token, [sop_id]))
+
+            position_id = np.concatenate((np.arange(len(token) - 1, dtype=np.int64), [mask_position]))
+            target_position_id = np.full(max_gen_length, mask_position, dtype=np.int64)
+            if self.position_encoding_2d:
+                block_position_id = np.zeros(len(token), dtype=np.int64)
+                target_block_position_id = np.arange(1, max_gen_length + 1, dtype=np.int64)
+            elif use_task_mask:
+                position_id = np.arange(len(token), dtype=np.int64)
+                target_position_id = np.arange(len(token), len(token) + max_gen_length, dtype=np.int64)
+
+        context_length = len(token)
+        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
+        if not unidirectional:
+            attention_mask[: context_length - 1, : context_length - 1] = 1
+
+        if self.position_encoding_2d:
+            position_id = np.stack((position_id, block_position_id), axis=0)
+            target_position_id = np.stack((target_position_id, target_block_position_id), axis=0)
+
+        item = {
+            "token": token,
+            "position_id": position_id,
+            "target_position_id": target_position_id,
+            "attention_mask": attention_mask,
+            "context_length": context_length,
+        }
+        return item
+
     def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[List[int]], List[List[List[int]]]]:
         """
         @return: A list of text model generated, sorted by score in descending order
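
For comparison (again not part of the commit), a sketch of the bidirectional blank-filling branch ([MASK] with 2D encoding), with hypothetical token ids, showing how generated tokens stay anchored at the mask position while the block position counts through the blank.

import numpy as np

# Hypothetical ids, for illustration only.
mask_id, sop_id = 50003, 50006   # [MASK] / sop command ids (hypothetical)
text = [10, 11, mask_id, 12]     # prompt with one blank to fill
max_gen_length = 3

mask_position = text.index(mask_id)        # 2
token = np.concatenate((text, [sop_id]))   # sop appended after the context

# Row 0: the sop token reuses the [MASK] position; row 1: the context sits at block position 0.
position_id = np.stack((
    np.concatenate((np.arange(len(token) - 1, dtype=np.int64), [mask_position])),
    np.zeros(len(token), dtype=np.int64),
))
# Generated tokens keep the [MASK] position and count up inside the blank.
target_position_id = np.stack((
    np.full(max_gen_length, mask_position, dtype=np.int64),
    np.arange(1, max_gen_length + 1, dtype=np.int64),
))

print(position_id)
# [[0 1 2 3 2]
#  [0 0 0 0 0]]
print(target_position_id)
# [[2 2 2]
#  [1 2 3]]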