
Add saving predictions
Add support for json files
Add end_tokens for generation

duzx16 committed 3 years ago · commit 2cbe915398
5 changed files with 48 additions and 21 deletions:
  1. configs/model_glm_130b.sh (+1 -1)
  2. evaluation/configs.py (+3 -1)
  3. evaluation/dataset.py (+22 -11)
  4. evaluation/tasks.py (+21 -7)
  5. scripts/evaluate.sh (+1 -1)

configs/model_glm_130b.sh (+1 -1)

@@ -1,5 +1,5 @@
 MODEL_TYPE="glm-130b"
-CHECKPOINT_PATH="/thudm/workspace/hanyu/SwissArmyTransformer/data/ckpt/iter_0049300"
+CHECKPOINT_PATH="/zhangpai21/checkpoints/glm-130b-sat"
 MP_SIZE=8
 MODEL_ARGS="--model-parallel-size ${MP_SIZE} \
             --num-layers 70 \

evaluation/configs.py (+3 -1)

@@ -25,6 +25,7 @@ class BaseConfig(YAMLWizard):
     unidirectional: bool = False  # Whether to use unidirectional attention
     max_seq_length: int = 2048  # Max sequence length
     file_pattern: str | Dict[str, str] = "**/*.json*"  # Organize data file in groups
+    save_prediction: bool = False

     micro_batch_size: int = 1  # 'gen' task only support mbs = 1 for now

@@ -41,13 +42,14 @@ class MultiChoiceTaskConfig(BaseConfig):
 @dataclass
 class GenerationTaskConfig(BaseConfig):
     module = "evaluation.GenerationTask"
-    metrics: List[str] = field(default_factory=lambda: ["EM", "F1"])
+    metrics: List[str] = field(default_factory=lambda: [])
     sampling_strategy: str = "BaseStrategy"
     num_beams: int = 4
     length_penalty: float = 1.0
     no_repeat_ngram_size: int = 3
     min_gen_length: int = 0
     max_gen_length: int = 128
+    end_tokens: List[str] = field(default_factory=lambda: [])

     def __post_init__(self):
         assert self.micro_batch_size == 1, "Only support micro batch size = 1 for generation task"
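
For context, a minimal self-contained sketch (the class below is a stand-in, not the repo's GenerationTaskConfig) of why the new list-valued fields use field(default_factory=...) and how a task might opt in to the two new options:

# Stand-in dataclass mirroring only the fields touched by this commit; the
# real GenerationTaskConfig also carries name/path/strategy fields.
from dataclasses import dataclass, field
from typing import List

@dataclass
class DemoGenerationConfig:
    metrics: List[str] = field(default_factory=list)     # "EM"/"F1" are now opt-in
    end_tokens: List[str] = field(default_factory=list)  # extra stop strings, e.g. ["\n"]
    save_prediction: bool = False                         # dump predictions on rank 0

default_cfg = DemoGenerationConfig()
task_cfg = DemoGenerationConfig(end_tokens=["\n"], save_prediction=True)
default_cfg.metrics.append("EM")
assert task_cfg.metrics == []   # default_factory gives each instance its own list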

evaluation/dataset.py (+22 -11)

@@ -4,6 +4,7 @@ import json
 import numpy as np
 import torch

+from typing import List
 from abc import ABC, abstractmethod
 from scipy.linalg import block_diag

@@ -46,10 +47,16 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
         self.gmask_id = tokenizer.get_command("[gMASK]")

         self.data = []
-        with open(os.path.join(path), "r", encoding="utf-8") as file:
-            for line in file:
-                item = json.loads(line)
-                self.data.append(self.process_single_item(item))
+        if path.endswith("jsonl"):
+            with open(os.path.join(path), "r", encoding="utf-8") as file:
+                for line in file:
+                    item = json.loads(line)
+                    self.data.extend(self.process_single_item(item))
+        elif path.endswith("json"):
+            with open(os.path.join(path), "r", encoding="utf-8") as file:
+                dataset = json.load(file)
+            for item in dataset:
+                self.data.extend(self.process_single_item(item))

     @property
     def has_collate_fn(self) -> bool:
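
For reference, an illustrative snippet (not part of the commit) of the two on-disk layouts the loader now distinguishes by extension; file names are placeholders and the record schema is only indicative, since the real keys are whatever get_tokenized_input expects:

import json

records = [{"inputs": "Question: ...", "targets": ["Answer: ..."]}]  # schema is illustrative

# *.jsonl - one JSON object per line, read with json.loads() per line
with open("example.jsonl", "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

# *.json - a single JSON array of objects, read once with json.load()
with open("example.json", "w", encoding="utf-8") as f:
    json.dump(records, f)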
@@ -59,7 +66,7 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
         return None

     @abstractmethod
-    def process_single_item(self, item) -> dict:
+    def process_single_item(self, item, **kwargs) -> List[dict]:
         pass

     def __len__(self):
@@ -69,12 +76,12 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
 class GenerationTaskDataset(EvaluationDataset):
     config: GenerationTaskConfig

-    def process_single_item(self, item):
+    def process_single_item(self, item, **kwargs):
         text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
         if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
             text_length = self.config.max_seq_length - self.config.max_gen_length - 2
             text = text[len(text) - text_length : len(text)]
-        return {"text": text, "targets": targets}
+        return [{"text": text, "targets": targets, **kwargs}]

     @staticmethod
     def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
@@ -124,7 +131,8 @@ class GenerationTaskDataset(EvaluationDataset):
             use_task_mask=self.config.use_task_mask,
             unidirectional=self.config.unidirectional,
         )
-        sample["targets"] = [np.array(target, dtype=self.dtype) for target in item["targets"]]
+        if "targets" in item:
+            sample["targets"] = [np.array(target, dtype=self.dtype) for target in item["targets"]]
         return sample


@@ -165,7 +173,7 @@ class MultiChoiceTaskDataset(EvaluationDataset):
             "is_single_token": self.is_single_token,
             "is_single_token": self.is_single_token,
         }
         }
 
 
-    def process_single_item(self, item):
+    def process_single_item(self, item, **kwargs):
         text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]
         text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]
 
 
         tgt_seq_length = sum([len(choice) for choice in choices])
         tgt_seq_length = sum([len(choice) for choice in choices])
@@ -185,11 +193,12 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         if tgt_seq_length != 1:
             self.is_single_token = False

-        return {
+        return [{
             "text": text,
             "choices": choices,
             "label": label,
-        }
+            **kwargs
+        }]

     @staticmethod
     def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False):
@@ -216,6 +225,8 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]

         for choice in choices:
+            if not choice:
+                choice = [tokenizer.get_command('eop')]
             position_id = np.concatenate(
                 (
                     position_id,
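
Because process_single_item now returns a list of sample dicts and forwards **kwargs into each of them, one raw record can expand into several evaluation samples and callers can attach extra metadata. A hypothetical subclass sketch (class name and import paths are assumptions, not part of the commit):

from typing import List

# Import paths assume this repo's layout; get_tokenized_input is the helper
# already used by the datasets above.
from evaluation.dataset import GenerationTaskDataset, get_tokenized_input

class MultiReferenceGenerationDataset(GenerationTaskDataset):  # hypothetical
    def process_single_item(self, item, **kwargs) -> List[dict]:
        text = get_tokenized_input(item, "inputs")
        # one sample per reference target; kwargs (e.g. a split tag passed by
        # the caller) are stored alongside each sample
        return [{"text": text, "targets": [target], **kwargs}
                for target in get_tokenized_input(item, "targets")]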

evaluation/tasks.py (+21 -7)

@@ -2,6 +2,7 @@ import torch
 import time
 import numpy as np
 import torch.distributed as dist
+from tqdm import tqdm

 from typing import Dict, Callable, Type, Tuple, List, Any
 from abc import ABC, abstractmethod
@@ -42,6 +43,10 @@ class BaseTask(ABC):

         self.file_groups = self.get_file_groups()
         self.verbose = dist.get_rank() == 0
+        self.save_prediction = config.save_prediction
+
+    def save_prediction_to_file(self, file, prediction, data):
+        pass

     def get_file_groups(self):
         pattern_group = {}
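
The new save_prediction_to_file hook is a no-op in BaseTask; tasks that set save_prediction are expected to override it. A hedged sketch of such an override (the task class, output directory, and record format are invented for illustration):

import json
import os

from evaluation.tasks import GenerationTask  # import path assumed from this repo

class MyGenerationTask(GenerationTask):  # hypothetical task
    def save_prediction_to_file(self, file, prediction, data):
        # Called on rank 0 only, and only when save_prediction is enabled.
        # `prediction` holds the gathered per-sample outputs and `data` is the
        # dataset's raw sample list; both are assumed JSON-serializable here
        # (e.g. plain token id lists).
        out_path = os.path.join("outputs", os.path.basename(file) + ".predictions.jsonl")
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as f:
            for pred, item in zip(prediction, data):
                f.write(json.dumps({"prediction": pred, "targets": item.get("targets")}) + "\n")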
@@ -71,7 +76,7 @@

             result_dict_group = {}
             for file in filelist:
-                dataset = self.build_dataset(file)
+                dataset = self.build_dataset(file, group_name)
                 dataloader = build_data_loader(
                     dataset,
                     micro_batch_size=self.config.micro_batch_size,
@@ -81,13 +86,18 @@ class BaseTask(ABC):
                 )

                 prediction = []
+                tqdm_wrapper = tqdm if torch.distributed.get_rank() == 0 else lambda x:x
                 with torch.no_grad():
-                    for _, batch in enumerate(dataloader):
-                        prediction.append(self.predict_single_batch(batch))
+                    for idx, batch in tqdm_wrapper(enumerate(dataloader)):
+                        p_batch = self.predict_single_batch(batch)
+                        prediction.append(p_batch)
+

                 prediction = gather_result(prediction, len(dataset), self.config.micro_batch_size)
                 result_dict = {key: metric(prediction, dataset.data) for key, metric in self.metrics.items()}
                 result_dict_group[file] = (result_dict, len(dataset))
+                if torch.distributed.get_rank() == 0 and self.save_prediction:
+                    self.save_prediction_to_file(file, prediction, dataset.data)

                 if self.verbose:
                     self.report_single_metrics(file, result_dict)
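
The tqdm_wrapper line above keeps the progress bar on rank 0 only; every other rank falls back to a pass-through lambda. A self-contained sketch of the pattern with the distributed rank check replaced by a constant:

from tqdm import tqdm

RANK = 0  # stand-in for torch.distributed.get_rank()
tqdm_wrapper = tqdm if RANK == 0 else (lambda x: x)

results = []
for idx, batch in tqdm_wrapper(enumerate(range(10))):
    results.append(batch)  # placeholder for self.predict_single_batch(batch)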
@@ -152,7 +162,7 @@ class BaseTask(ABC):
         pass

     @abstractmethod
-    def build_dataset(self, relative_path: str) -> EvaluationDataset:
+    def build_dataset(self, relative_path: str, split: str) -> EvaluationDataset:
         pass


@@ -163,13 +173,17 @@ class GenerationTask(BaseTask, ABC):
     def config_class(cls):
         return GenerationTaskConfig

-    def build_dataset(self, relative_path):
+    def build_dataset(self, relative_path, split):
         return GenerationTaskDataset(join(self.config.path, relative_path), self.config)

     def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: GenerationTaskConfig):
         super(GenerationTask, self).__init__(model, tokenizer, config)

         end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
+        if self.config.end_tokens:
+            for token in self.config.end_tokens:
+                end_tokens.append(self.tokenizer.tokenize(token)[-1])
+            print_rank_0(f"End tokens {end_tokens}")
         if self.config.sampling_strategy == "BaseStrategy":
             self.strategy = BaseStrategy(temperature=1.0, top_k=1, end_tokens=end_tokens)
         elif self.config.sampling_strategy == "BeamSearchStrategy":
@@ -180,7 +194,7 @@ class GenerationTask(BaseTask, ABC):
                 end_tokens=end_tokens,
                 no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                 min_gen_length=self.config.min_gen_length,
-                deterministic=True,  # For evaluation, we need a determined generation strategy
+                deterministic=False,  # For evaluation, we need a determined generation strategy
             )
         else:
             raise ValueError(f"unknown strategy {self.config.sampling_strategy}")
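
How a configured stop string ends up in the strategy's stop set: it is tokenized and only the id of its last piece is appended, alongside the built-in eop/eos commands. A runnable sketch with a stub tokenizer (the real task uses the repo's _IceTokenizer):

class StubTokenizer:  # stand-in for _IceTokenizer, with made-up ids
    def get_command(self, name):
        return {"eop": 20002, "eos": 20003}[name]
    def tokenize(self, text):
        return [hash(piece) % 1000 for piece in text.split()] or [0]

tokenizer = StubTokenizer()
config_end_tokens = ["Question:"]  # e.g. end_tokens from the task config

end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
for token in config_end_tokens:
    end_tokens.append(tokenizer.tokenize(token)[-1])
print(end_tokens)  # token ids handed to BaseStrategy / BeamSearchStrategy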
@@ -199,7 +213,7 @@ class MultiChoiceTask(BaseTask, ABC):
     def config_class(cls):
         return MultiChoiceTaskConfig

-    def build_dataset(self, relative_path):
+    def build_dataset(self, relative_path, split):
         return MultiChoiceTaskDataset(join(self.config.path, relative_path), self.config)

     def predict_single_batch(self, batch) -> List[int]:

scripts/evaluate.sh (+1 -1)

@@ -6,7 +6,7 @@ main_dir=$(dirname $script_dir)

 source "${main_dir}/configs/model_glm_130b.sh"

-DATA_PATH="<your evaluation dataset base directory>"
+DATA_PATH="/zhangpai21/workspace/zxdu"

 ARGS="${main_dir}/evaluate.py \
        --mode inference \