
Implement multichoice task for GLM-10B

duzx16 2 years ago
parent
commit
1783d9b6c5

+ 13 - 0
configs/model_glm_10b.sh

@@ -0,0 +1,13 @@
+MODEL_TYPE="glm-10b"
+CHECKPOINT_PATH="/zhangpai21/checkpoints/glm-10b-sat"
+MP_SIZE=1
+MODEL_ARGS="--model-parallel-size ${MP_SIZE} \
+            --vocab 50304 \
+            --num-layers 48 \
+            --hidden-size 4096 \
+            --num-attention-heads 64 \
+            --max-sequence-length 1025 \
+            --tokenizer-type glm_GPT2BPETokenizer \
+            --tokenizer-model-type gpt2 \
+            --task-mask \
+            --load ${CHECKPOINT_PATH}"

+ 1 - 0
evaluation/configs.py

@@ -25,6 +25,7 @@ class BaseConfig(YAMLWizard):
     use_multitask_encoding: bool = False  # Not supported now
     unidirectional: bool = False  # Whether to use unidirectional attention
     max_seq_length: int = 2048  # Max sequence length
+    no_tokenized: bool = False  # Ignore pre-tokenized ids; tokenize the raw *_pretokenized text
     file_pattern: str | Dict[str, str] = "**/*.json*"  # Organize data file in groups
 
     micro_batch_size: int = 1  # 'gen' task only supports mbs = 1 for now
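
The new flag is driven from the task YAML; a hypothetical loading sketch, assuming dataclass-wizard's `from_yaml_file` helper on `YAMLWizard` and its lenient kebab-case to snake_case key matching:

```python
from evaluation.configs import MultiChoiceTaskConfig

# Illustrative only: loads the MMLU task config changed at the end of this commit.
config = MultiChoiceTaskConfig.from_yaml_file("tasks/mmlu/mmlu.yaml")
print(config.no_tokenized)    # True, per the mmlu.yaml change below
print(config.max_seq_length)  # 896
```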

+ 74 - 6
evaluation/dataset.py

@@ -168,14 +168,19 @@ class GenerationTaskDataset(EvaluationDataset):
 
 
 class SmallGenerationTaskDataset(GenerationTaskDataset):
-    config: GenerationTaskConfig
+    def process_single_item(self, item):
+        text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
+        if len(text) + self.config.max_gen_length + 3 > self.config.max_seq_length:
+            text_length = self.config.max_seq_length - self.config.max_gen_length - 3
+            text = text[len(text) - text_length : len(text)]
+        return {"text": text, "targets": targets}
 
     @staticmethod
     def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
         tokenizer = get_tokenizer()
 
         sop_id = tokenizer.get_command("sop")
-        mask_id = tokenizer.get_command("[gMASK]").Id if use_task_mask else tokenizer.get_command("[MASK]").Id
+        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
         cls_id = tokenizer.get_command("ENC")
         eos_id = tokenizer.get_command("eos")
 
@@ -232,6 +237,10 @@ class MultiChoiceTaskDataset(EvaluationDataset):
     def has_collate_fn(self) -> bool:
         return True
 
+    @staticmethod
+    def num_special_tokens():
+        return 2
+
     @staticmethod
     def collate_fn(samples):
         TILE = 32
@@ -263,7 +272,9 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         }
 
     def process_single_item(self, item):
-        text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]
+        text = get_tokenized_input(item, "inputs", no_tokenized=self.config.no_tokenized)
+        choices = get_tokenized_input(item, "choices", no_tokenized=self.config.no_tokenized)
+        label = item["label"]
 
         tgt_seq_length = sum([len(choice) for choice in choices])
         if tgt_seq_length == len(choices):
@@ -271,9 +282,9 @@ class MultiChoiceTaskDataset(EvaluationDataset):
             tgt_seq_length = 1
 
         assert tgt_seq_length < self.config.max_seq_length
-        if len(text) + tgt_seq_length + 2 > self.config.max_seq_length:
-            text_length = self.config.max_seq_length - tgt_seq_length - 2
-            text = text[len(text) - text_length : len(text)]
+        if len(text) + tgt_seq_length + self.num_special_tokens() > self.config.max_seq_length:
+            text_length = self.config.max_seq_length - tgt_seq_length - self.num_special_tokens()
+            text = text[len(text) - text_length: len(text)]
 
         assert not (
             self.mask_id in text and self.config.use_multitask_encoding
@@ -354,6 +365,63 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         )
 
 
+class SmallMultiChoiceTaskDataset(MultiChoiceTaskDataset):
+    @staticmethod
+    def num_special_tokens():
+        return 3
+
+    @staticmethod
+    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False):
+        tokenizer = get_tokenizer()
+        cls_id = tokenizer.get_command("ENC")
+        eos_id = tokenizer.get_command("eos")
+        sop_id = tokenizer.get_command("sop")
+        mask_id = tokenizer.get_command("[MASK]")
+        blank_filling = mask_id in text
+        if not blank_filling:
+            text = text + [mask_id]
+        text = [cls_id] + text + [eos_id]
+
+        token = np.array(text, dtype=np.int64)
+        target = np.array(text, dtype=np.int64)
+        position_id = np.arange(len(text), dtype=np.int64)
+        block_position_id = np.zeros(len(text), dtype=np.int64)
+        mask_position = text.index(mask_id)
+        choice_target_id = []
+
+        division = len(token)
+        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
+
+        for choice in choices:
+            position_id = np.concatenate((position_id, [mask_position] * len(choice)))
+            block_position_id = np.concatenate((block_position_id, range(1, 1 + len(choice))))
+            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
+            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
+            token = np.concatenate((token, [sop_id], choice[:-1]))
+            target = np.concatenate((target, choice))
+
+            if is_single_token:
+                break
+
+        attention_mask = block_diag(*attention_mask)
+        attention_mask[: len(token), :division] = 1
+
+        if is_single_token:
+            choices = np.array(choices, dtype=np.int64).squeeze().tolist()
+
+        position_id = np.stack((position_id, block_position_id), axis=0)
+
+        item = {
+            "token": token,
+            "position_id": position_id,
+            "attention_mask": attention_mask,
+            "choices": choices,
+            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
+        }
+        return item
+
+
 class LanguageModelTaskDataset(EvaluationDataset):
     config: LanguageModelTaskConfig
 

+ 3 - 2
evaluation/tasks.py

@@ -14,7 +14,8 @@ from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceT
 from generation import BaseStrategy, BeamSearchStrategy
 from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig, LanguageModelTaskConfig
 from .model import ModelForEvaluation
-from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset, SmallGenerationTaskDataset
+from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset
+from .dataset import SmallGenerationTaskDataset, SmallMultiChoiceTaskDataset
 from .utils import build_data_loader, gather_result, print_rank_0
 from .metrics import DEFAULT_METRICS
 
@@ -199,7 +200,7 @@ class MultiChoiceTask(BaseTask, ABC):
         return MultiChoiceTaskConfig
 
     def build_dataset(self, relative_path):
-        return MultiChoiceTaskDataset(join(self.config.path, relative_path), self.config)
+        return SmallMultiChoiceTaskDataset(join(self.config.path, relative_path), self.config)
 
     def predict_single_batch(self, batch) -> List[int]:
         log_probs = self.model.cond_log_prob(batch)
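
For reference, the selection step that consumes `cond_log_prob` (context line above) is a plain argmax over per-choice log-probabilities; a minimal sketch with made-up numbers:

```python
import numpy as np

log_probs = [[-1.9, -0.4, -2.7, -3.1]]  # one sample, four choices (made up)
print([int(np.argmax(p)) for p in log_probs])  # [1]
```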

+ 2 - 2
evaluation/utils.py

@@ -52,8 +52,8 @@ def gather_result(prediction, total_length, micro_batch_size):
     return prediction
 
 
-def get_tokenized_input(item, key):
-    if key in item:
+def get_tokenized_input(item, key, no_tokenized=False):
+    if key in item and not no_tokenized:
         return item[key]
     tokenizer = get_tokenizer()
     pretokenized_key = key + "_pretokenized"
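
The flag short-circuits the cached-ids path; a standalone sketch of the dispatch, with `str.split` standing in for the real tokenizer (which needs model initialization):

```python
def get_tokenized_input_sketch(item, key, no_tokenized=False, tokenize=str.split):
    if key in item and not no_tokenized:
        return item[key]                      # use cached token ids
    return tokenize(item[key + "_pretokenized"])  # re-tokenize raw text

item = {"inputs": [101, 2054], "inputs_pretokenized": "What is"}
print(get_tokenized_input_sketch(item, "inputs"))                     # [101, 2054]
print(get_tokenized_input_sketch(item, "inputs", no_tokenized=True))  # ['What', 'is']
```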

+ 2 - 1
initialize.py

@@ -77,7 +77,8 @@ def initialize_model_and_tokenizer(args):
     # Load checkpoint
     torch.distributed.barrier()
     start = time.time()
-    load_checkpoint(model, args)
+    if args.load:
+        load_checkpoint(model, args)
     torch.distributed.barrier()
     if torch.distributed.get_rank() == 0:
         print(f"> Checkpoint loaded in {time.time() - start:.1f}s")

+ 3 - 1
tasks/mmlu/mmlu.yaml

@@ -7,4 +7,6 @@ file-pattern:
   social_sciences: "social_sciences/*.json"
   humanities: "humanities/*.json"
   other: "other/*.json"
-micro-batch-size: 1
+no-tokenized: true
+micro-batch-size: 8
+max-seq-length: 896
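
A worked example of the 896-token budget under the GLM-10B dataset's three reserved special tokens; MMLU's answer letters are typically single tokens, so `process_single_item` collapses `tgt_seq_length` to 1:

```python
max_seq_length, choice_lens = 896, [1, 1, 1, 1]
tgt_seq_length = sum(choice_lens)
if tgt_seq_length == len(choice_lens):  # all choices are single tokens
    tgt_seq_length = 1
print(max_seq_length - tgt_seq_length - 3)  # 892 prompt tokens survive truncation
```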