
Implement multichoice task for GLM-10B

duzx16 2 years ago
commit 1783d9b6c5
7 changed files with 98 additions and 12 deletions
  1. configs/model_glm_10b.sh (+13 -0)
  2. evaluation/configs.py (+1 -0)
  3. evaluation/dataset.py (+74 -6)
  4. evaluation/tasks.py (+3 -2)
  5. evaluation/utils.py (+2 -2)
  6. initialize.py (+2 -1)
  7. tasks/mmlu/mmlu.yaml (+3 -1)

+ 13 - 0
configs/model_glm_10b.sh

@@ -0,0 +1,13 @@
+MODEL_TYPE="glm-10b"
+CHECKPOINT_PATH="/zhangpai21/checkpoints/glm-10b-sat"
+MP_SIZE=1
+MODEL_ARGS="--model-parallel-size ${MP_SIZE} \
+            --vocab 50304 \
+            --num-layers 48 \
+            --hidden-size 4096 \
+            --num-attention-heads 64 \
+            --max-sequence-length 1025 \
+            --tokenizer-type glm_GPT2BPETokenizer \
+            --tokenizer-model-type gpt2 \
+            --task-mask \
+            --load ${CHECKPOINT_PATH}"
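For reference, a minimal sketch of what these flags amount to once parsed. The argparse stub below is an illustration, not the repository's actual parser; only the flag names and values come from the script above.

    import argparse, shlex

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-parallel-size", type=int)
    parser.add_argument("--vocab", type=int)
    parser.add_argument("--num-layers", type=int)
    parser.add_argument("--hidden-size", type=int)
    parser.add_argument("--num-attention-heads", type=int)
    parser.add_argument("--max-sequence-length", type=int)
    parser.add_argument("--tokenizer-type")
    parser.add_argument("--tokenizer-model-type")
    parser.add_argument("--task-mask", action="store_true")
    parser.add_argument("--load")

    args = parser.parse_args(shlex.split(
        "--model-parallel-size 1 --vocab 50304 --num-layers 48 "
        "--hidden-size 4096 --num-attention-heads 64 --max-sequence-length 1025 "
        "--tokenizer-type glm_GPT2BPETokenizer --tokenizer-model-type gpt2 "
        "--task-mask --load /zhangpai21/checkpoints/glm-10b-sat"
    ))
    assert args.hidden_size % args.num_attention_heads == 0  # 4096 / 64 = 64 dims per head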

+ 1 - 0
evaluation/configs.py

@@ -25,6 +25,7 @@ class BaseConfig(YAMLWizard):
     use_multitask_encoding: bool = False  # Not supported now
     unidirectional: bool = False  # Whether to use unidirectional attention
     max_seq_length: int = 2048  # Max sequence length
+    no_tokenized: bool = False  # Ignore pre-tokenized ids and re-tokenize from "*_pretokenized" text
     file_pattern: str | Dict[str, str] = "**/*.json*"  # Organize data file in groups
 
     micro_batch_size: int = 1  # 'gen' task only support mbs = 1 for now
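BaseConfig subclasses YAMLWizard from the dataclass-wizard package (see the hunk header above), so task YAML files populate these fields directly. A minimal sketch, assuming dataclass-wizard's default lenient key matching maps kebab-case YAML keys like no-tokenized onto snake_case fields; DemoConfig is a stand-in, not the real class:

    from dataclasses import dataclass
    from dataclass_wizard import YAMLWizard

    @dataclass
    class DemoConfig(YAMLWizard):
        no_tokenized: bool = False
        micro_batch_size: int = 1
        max_seq_length: int = 2048

    cfg = DemoConfig.from_yaml("""
    no-tokenized: true
    micro-batch-size: 8
    max-seq-length: 896
    """)
    assert cfg.no_tokenized and cfg.micro_batch_size == 8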

+ 74 - 6
evaluation/dataset.py

@@ -168,14 +168,19 @@ class GenerationTaskDataset(EvaluationDataset):
 
 
 class SmallGenerationTaskDataset(GenerationTaskDataset):
-    config: GenerationTaskConfig
+    def process_single_item(self, item):
+        text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
+        if len(text) + self.config.max_gen_length + 3 > self.config.max_seq_length:
+            text_length = self.config.max_seq_length - self.config.max_gen_length - 3
+            text = text[len(text) - text_length : len(text)]
+        return {"text": text, "targets": targets}
 
     @staticmethod
     def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
         tokenizer = get_tokenizer()
 
         sop_id = tokenizer.get_command("sop")
-        mask_id = tokenizer.get_command("[gMASK]").Id if use_task_mask else tokenizer.get_command("[MASK]").Id
+        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
         cls_id = tokenizer.get_command("ENC")
         eos_id = tokenizer.get_command("eos")
 
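The + 3 here reserves room for three special tokens around the prompt (presumably the cls/mask/sop framing whose ids are fetched above), on top of the max_gen_length tokens to be generated. A toy check of the arithmetic, with made-up numbers:

    max_seq_length, max_gen_length = 1025, 128
    text = list(range(2000))                   # over-long fake prompt ids
    if len(text) + max_gen_length + 3 > max_seq_length:
        text_length = max_seq_length - max_gen_length - 3
        text = text[len(text) - text_length:]  # keep the rightmost tokens
    assert len(text) + max_gen_length + 3 == max_seq_length  # 894 + 128 + 3 = 1025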
@@ -232,6 +237,10 @@ class MultiChoiceTaskDataset(EvaluationDataset):
     def has_collate_fn(self) -> bool:
         return True
 
+    @staticmethod
+    def num_special_tokens():
+        return 2
+
     @staticmethod
     def collate_fn(samples):
         TILE = 32
@@ -263,7 +272,9 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         }
 
     def process_single_item(self, item):
-        text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]
+        text = get_tokenized_input(item, "inputs", no_tokenized=self.config.no_tokenized)
+        choices = get_tokenized_input(item, "choices", no_tokenized=self.config.no_tokenized)
+        label = item["label"]
 
         tgt_seq_length = sum([len(choice) for choice in choices])
         if tgt_seq_length == len(choices):
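The comparison above is a single-token test: if every choice tokenizes to exactly one id, the summed length equals the number of choices. A toy illustration with invented ids:

    choices = [[317], [347], [327], [360]]    # e.g. " A" " B" " C" " D"; ids invented
    tgt_seq_length = sum(len(choice) for choice in choices)
    is_single_token = tgt_seq_length == len(choices)   # True: every choice is 1 token
    if is_single_token:
        tgt_seq_length = 1   # all choices can be scored at one position
    assert tgt_seq_length == 1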
@@ -271,9 +282,9 @@ class MultiChoiceTaskDataset(EvaluationDataset):
             tgt_seq_length = 1
 
         assert tgt_seq_length < self.config.max_seq_length
-        if len(text) + tgt_seq_length + 2 > self.config.max_seq_length:
-            text_length = self.config.max_seq_length - tgt_seq_length - 2
-            text = text[len(text) - text_length : len(text)]
+        if len(text) + tgt_seq_length + self.num_special_tokens() > self.config.max_seq_length:
+            text_length = self.config.max_seq_length - tgt_seq_length - self.num_special_tokens()
+            text = text[len(text) - text_length: len(text)]
 
         assert not (
             self.mask_id in text and self.config.use_multitask_encoding
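Lifting the constant into num_special_tokens() lets the GLM-10B subclass below reserve 3 positions instead of 2: its build_multiple_choice_sample wraps the context with [CLS] and <eos> and appends a [MASK] when one is absent. Toy arithmetic for the prompt budget:

    max_seq_length, tgt_seq_length = 2048, 4
    for num_special_tokens in (2, 3):
        print(num_special_tokens, "->", max_seq_length - tgt_seq_length - num_special_tokens)
    # 2 -> 2042 (base MultiChoiceTaskDataset)
    # 3 -> 2041 (SmallMultiChoiceTaskDataset)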
@@ -354,6 +365,63 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         )
 
 
+class SmallMultiChoiceTaskDataset(MultiChoiceTaskDataset):
+    @staticmethod
+    def num_special_tokens():
+        return 3
+
+    @staticmethod
+    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False):
+        tokenizer = get_tokenizer()
+        cls_id = tokenizer.get_command("ENC")
+        eos_id = tokenizer.get_command("eos")
+        sop_id = tokenizer.get_command("sop")
+        mask_id = tokenizer.get_command("[MASK]")
+        blank_filling = mask_id in text
+        if not blank_filling:
+            text = text + [mask_id]
+        text = [cls_id] + text + [eos_id]
+
+        token = np.array(text, dtype=np.int64)
+        target = np.array(text, dtype=np.int64)
+        position_id = np.arange(len(text), dtype=np.int64)
+        block_position_id = np.zeros(len(text), dtype=np.int64)
+        mask_position = text.index(mask_id)
+        choice_target_id = []
+
+
+        division = len(token)
+        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
+
+        for choice in choices:
+            position_id = np.concatenate((position_id, [mask_position] * len(choice)))
+            block_position_id = np.concatenate((block_position_id, range(1, 1 + len(choice))))
+            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
+            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
+            token = np.concatenate((token, [sop_id], choice[:-1]))
+            target = np.concatenate((target, choice))
+
+            if is_single_token:
+                break
+
+        attention_mask = block_diag(*attention_mask)
+        attention_mask[: len(token), :division] = 1
+
+        if is_single_token:
+            choices = np.array(choices, dtype=np.int64).squeeze().tolist()
+
+        position_id = np.stack((position_id, block_position_id), axis=0)
+
+        item = {
+            "token": token,
+            "position_id": position_id,
+            "attention_mask": attention_mask,
+            "choices": choices,
+            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
+        }
+        return item
+
+
 class LanguageModelTaskDataset(EvaluationDataset):
     config: LanguageModelTaskConfig
 
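To make the new sample layout concrete: the context attends bidirectionally, each choice attends causally to itself, and every position attends back to the full context. A toy rendering with invented sizes, using scipy's block_diag exactly as the code above does:

    import numpy as np
    from scipy.linalg import block_diag

    context_len, choice_len = 4, 2
    attention_mask = [np.ones((context_len, context_len), dtype=np.int64),
                      np.tril(np.ones((choice_len, choice_len), dtype=np.int64))]
    attention_mask = block_diag(*attention_mask)
    attention_mask[:, :context_len] = 1   # choice tokens see the whole context
    print(attention_mask)
    # [[1 1 1 1 0 0]
    #  [1 1 1 1 0 0]
    #  [1 1 1 1 0 0]
    #  [1 1 1 1 0 0]
    #  [1 1 1 1 1 0]
    #  [1 1 1 1 1 1]]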

+ 3 - 2
evaluation/tasks.py

@@ -14,7 +14,8 @@ from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceT
 from generation import BaseStrategy, BeamSearchStrategy
 from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig, LanguageModelTaskConfig
 from .model import ModelForEvaluation
-from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset, SmallGenerationTaskDataset
+from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset
+from .dataset import SmallGenerationTaskDataset, SmallMultiChoiceTaskDataset
 from .utils import build_data_loader, gather_result, print_rank_0
 from .metrics import DEFAULT_METRICS
 
@@ -199,7 +200,7 @@ class MultiChoiceTask(BaseTask, ABC):
         return MultiChoiceTaskConfig
 
     def build_dataset(self, relative_path):
-        return MultiChoiceTaskDataset(join(self.config.path, relative_path), self.config)
+        return SmallMultiChoiceTaskDataset(join(self.config.path, relative_path), self.config)
 
     def predict_single_batch(self, batch) -> List[int]:
         log_probs = self.model.cond_log_prob(batch)
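A sketch of what the surrounding predict_single_batch presumably reduces to: an argmax over per-choice log-probabilities. The return shape of cond_log_prob (one list of floats per sample) is an assumption here:

    import numpy as np

    def predict_single_batch(log_probs_batch):
        # pick the highest-scoring choice index for each sample
        return [int(np.argmax(lp)) for lp in log_probs_batch]

    print(predict_single_batch([[-4.2, -0.7, -3.1, -2.8]]))  # -> [1]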

+ 2 - 2
evaluation/utils.py

@@ -52,8 +52,8 @@ def gather_result(prediction, total_length, micro_batch_size):
     return prediction
 
 
-def get_tokenized_input(item, key):
-    if key in item:
+def get_tokenized_input(item, key, no_tokenized=False):
+    if key in item and not no_tokenized:
         return item[key]
     tokenizer = get_tokenizer()
     pretokenized_key = key + "_pretokenized"
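A self-contained sketch of the behavior after this change; the signature and the first branch are verbatim from the diff, while the stub tokenize argument stands in for the repo's get_tokenizer():

    def get_tokenized_input(item, key, no_tokenized=False, tokenize=str.split):
        if key in item and not no_tokenized:
            return item[key]                          # cached token ids
        return tokenize(item[key + "_pretokenized"])  # re-tokenize raw text

    sample = {"inputs_pretokenized": "The capital of France is"}
    print(get_tokenized_input(sample, "inputs", no_tokenized=True))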

+ 2 - 1
initialize.py

@@ -77,7 +77,8 @@ def initialize_model_and_tokenizer(args):
     # Load checkpoint
     torch.distributed.barrier()
     start = time.time()
-    load_checkpoint(model, args)
+    if args.load:
+        load_checkpoint(model, args)
     torch.distributed.barrier()
     if torch.distributed.get_rank() == 0:
         print(f"> Checkpoint loaded in {time.time() - start:.1f}s")
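A sketch of the guard's effect: an empty or unset --load now skips loading and keeps the randomly initialized weights (presumably handy when no checkpoint path is supplied); load_checkpoint below stands in for the loader used above:

    from argparse import Namespace

    def load_checkpoint(model, args):   # stand-in for the real loader
        print(f"loading from {args.load}")

    args = Namespace(load="")           # --load not supplied
    if args.load:                       # falsy, so loading is a no-op
        load_checkpoint(model=None, args=args)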

+ 3 - 1
tasks/mmlu/mmlu.yaml

@@ -7,4 +7,6 @@ file-pattern:
   social_sciences: "social_sciences/*.json"
   humanities: "humanities/*.json"
   other: "other/*.json"
-micro-batch-size: 1
+no-tokenized: true
+micro-batch-size: 8
+max-seq-length: 896
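A quick consistency check on the new numbers (tying them to the 1025-position window in configs/model_glm_10b.sh is an inference): with MMLU's single-token A/B/C/D choices, the SmallMultiChoiceTaskDataset truncation reserves 1 target token plus 3 special tokens, leaving 892 tokens of prompt.

    max_seq_length, tgt_seq_length, num_special_tokens = 896, 1, 3
    assert max_seq_length <= 1025                 # fits GLM-10B's position table
    print(max_seq_length - tgt_seq_length - num_special_tokens)  # 892 prompt tokens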