Merge CoT branch

Sengxian, 2 years ago
commit ab6285e319

5 changed files with 18 additions and 39 deletions
  1. evaluation/configs.py      +1 -2
  2. evaluation/dataset.py      +0 -12
  3. evaluation/tasks.py        +8 -11
  4. generation/strategies.py   +3 -8
  5. tasks/cot/task.py          +6 -6

evaluation/configs.py  +1 -2

@@ -43,14 +43,13 @@ class MultiChoiceTaskConfig(BaseConfig):
 @dataclass
 class GenerationTaskConfig(BaseConfig):
     module = "evaluation.GenerationTask"
-    metrics: List[str] = field(default_factory=lambda: [])
+    metrics: List[str] = field(default_factory=lambda: ["EM", "F1"])
     sampling_strategy: str = "BaseStrategy"
     num_beams: int = 4
     length_penalty: float = 1.0
     no_repeat_ngram_size: int = 3
     min_gen_length: int = 0
     max_gen_length: int = 128
-    deterministic: bool = False
     end_tokens: List[str] = field(default_factory=lambda: [])

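For context on the first change: with the new default_factory, a generation task whose config leaves `metrics` unset is now scored with exact match and F1 instead of nothing, and the per-task `deterministic` field is removed from GenerationTaskConfig entirely. A minimal, hypothetical sketch of the new default (a stand-in dataclass, not the repo's BaseConfig machinery):

    # Hypothetical stand-in for GenerationTaskConfig, reduced to the fields this
    # commit touches; only the default_factory value mirrors the merged change.
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class GenerationTaskConfigSketch:
        module: str = "evaluation.GenerationTask"
        metrics: List[str] = field(default_factory=lambda: ["EM", "F1"])
        max_gen_length: int = 128

    cfg = GenerationTaskConfigSketch()
    print(cfg.metrics)  # ['EM', 'F1'] -- previously an empty list, i.e. no metrics by default
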
evaluation/dataset.py  +0 -12

@@ -61,15 +61,6 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
         return None

     def process_single_file(self, path):
-        if not path.endswith("jsonl"):
-            try:
-                with open(os.path.join(path), "r", encoding="utf-8") as file:
-                    dataset = json.load(file)
-                for item in dataset:
-                    self.data.extend(self.process_single_item(item))
-                return
-            except json.decoder.JSONDecodeError:
-                pass
         with open(os.path.join(path), "r", encoding="utf-8") as file:
             for line in file:
                 item = json.loads(line)
@@ -171,8 +162,6 @@ class GenerationTaskDataset(EvaluationDataset):
             use_task_mask=self.config.use_task_mask,
             unidirectional=self.config.unidirectional,
         )
-        if "target" in item:
-            sample["targets"] = [np.array(target, dtype=self.dtype) for target in item["targets"]]
         return sample

@@ -323,7 +312,6 @@ class MultiChoiceTaskDataset(EvaluationDataset):
             unidirectional=self.config.unidirectional,
             use_task_mask=self.config.use_task_mask,
         )
-        sample["label"] = item["label"]
         return sample

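With the non-JSONL branch deleted, process_single_file now assumes JSON Lines input, one JSON object per line; whole-file JSON arrays are no longer accepted. A standalone sketch of the surviving code path (the file name and fields below are illustrative, not from the repo):

    import json

    # Expected layout: one JSON object per line, e.g.
    # {"question": "...", "targets": ["..."]}
    def load_jsonl(path):
        items = []
        with open(path, "r", encoding="utf-8") as file:
            for line in file:
                items.append(json.loads(line))
        return items

    # usage sketch: items = load_jsonl("data/cot/test.jsonl")  # illustrative path
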
evaluation/tasks.py  +8 -11

@@ -2,7 +2,6 @@ import torch
 import time
 import numpy as np
 import torch.distributed as dist
-from tqdm import tqdm

 from typing import Dict, Callable, Type, Tuple, List, Any
 from abc import ABC, abstractmethod
@@ -75,7 +74,7 @@ class BaseTask(ABC):

             result_dict_group = {}
             for file in filelist:
-                dataset = self.build_dataset(file, group_name)
+                dataset = self.build_dataset(file)
                 dataloader = build_data_loader(
                     dataset,
                     micro_batch_size=self.config.micro_batch_size,
@@ -85,11 +84,9 @@ class BaseTask(ABC):
                 )

                 prediction = []
-                tqdm_wrapper = tqdm if torch.distributed.get_rank() == 0 else lambda x:x
                 with torch.no_grad():
-                    for idx, batch in tqdm_wrapper(enumerate(dataloader)):
-                        p_batch = self.predict_single_batch(batch)
-                        prediction.append(p_batch)
+                    for _, batch in enumerate(dataloader):
+                        prediction.append(self.predict_single_batch(batch))

                 prediction = gather_result(prediction, len(dataset), self.config.micro_batch_size)
@@ -161,7 +158,7 @@ class BaseTask(ABC):
         pass

     @abstractmethod
-    def build_dataset(self, relative_path: str, split: str) -> EvaluationDataset:
+    def build_dataset(self, relative_path: str) -> EvaluationDataset:
         pass
@@ -172,7 +169,7 @@ class GenerationTask(BaseTask, ABC):
     def config_class(cls):
         return GenerationTaskConfig

-    def build_dataset(self, relative_path, split):
+    def build_dataset(self, relative_path):
         return GenerationTaskDataset(join(self.config.path, relative_path), self.config)

     def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: GenerationTaskConfig):
@@ -185,7 +182,7 @@ class GenerationTask(BaseTask, ABC):
             print_rank_0(f"End tokens {end_tokens}")
         if self.config.sampling_strategy == "BaseStrategy":
             self.strategy = BaseStrategy(batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1,
-                                         end_tokens=end_tokens, deterministic=self.config.deterministic)
+                                         end_tokens=end_tokens)
         elif self.config.sampling_strategy == "BeamSearchStrategy":
             self.strategy = BeamSearchStrategy(
                 self.config.micro_batch_size,
@@ -195,7 +192,7 @@ class GenerationTask(BaseTask, ABC):
                 end_tokens=end_tokens,
                 no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                 min_gen_length=self.config.min_gen_length,
-                deterministic=self.config.deterministic,  # For evaluation, we need a determined generation strategy
+                deterministic=True,  # For evaluation, we need a determined generation strategy
             )
         else:
             raise ValueError(f"unknown strategy {self.config.sampling_strategy}")
@@ -212,7 +209,7 @@ class MultiChoiceTask(BaseTask, ABC):
     def config_class(cls):
         return MultiChoiceTaskConfig

-    def build_dataset(self, relative_path, split):
+    def build_dataset(self, relative_path):
         return MultiChoiceTaskDataset(join(self.config.path, relative_path), self.config)

     def predict_single_batch(self, batch) -> List[int]:

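Taken together, the tasks.py hunks drop the unused `split` argument from build_dataset and collect predictions without a tqdm progress bar. A condensed sketch of the resulting evaluation loop, with hypothetical names where the diff context is cut off (a plain DataLoader stands in for the repo's build_data_loader helper):

    import torch
    from torch.utils.data import DataLoader

    def evaluate_file(task, file, micro_batch_size):
        # build_dataset now takes only the relative path; the former split argument is gone.
        dataset = task.build_dataset(file)
        # The repo uses its own build_data_loader with distributed sampling; a plain
        # DataLoader keeps this sketch self-contained.
        dataloader = DataLoader(dataset, batch_size=micro_batch_size)

        prediction = []
        with torch.no_grad():  # inference only, no gradients needed
            for _, batch in enumerate(dataloader):
                prediction.append(task.predict_single_batch(batch))
        return prediction
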
generation/strategies.py  +3 -8

@@ -4,8 +4,7 @@ import torch.nn.functional as F
 from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits

 class BaseStrategy:
-    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None,
-                 deterministic=False):
+    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
         self.batch_size = batch_size
         self.invalid_slices = invalid_slices
         self.temperature = temperature
@@ -15,7 +14,6 @@ class BaseStrategy:
         if end_tokens is None:
             end_tokens = []
         self.end_tokens = end_tokens
-        self.deterministic = deterministic
         self._is_done = np.zeros(self.batch_size, dtype=np.bool)

     @property
@@ -32,11 +30,8 @@ class BaseStrategy:
             logits[..., invalid_slice] = -65504

         logits = top_k_logits(logits, self.topk, self.top_p)
-        if self.deterministic:
-            pred = logits.max(dim=-1)[1]
-        else:
-            probs = F.softmax(logits.float(), dim=-1)  # float is essetial, due to a bug in Pytorch
-            pred = torch.multinomial(probs, num_samples=1)
+        probs = F.softmax(logits.float(), dim=-1)  # float is essetial, due to a bug in Pytorch
+        pred = torch.multinomial(probs, num_samples=1)
         for i in range(self.batch_size):
             if i >= batch_size:
                 self._is_done[i] = True

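Since BaseStrategy no longer has a deterministic branch, it always samples with torch.multinomial; the GenerationTask above still gets deterministic greedy decoding by constructing it with top_k=1, because sampling from a top-1-filtered distribution can only return the argmax token. A small self-contained check of that equivalence (the logit values are made up):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[0.2, 1.5, -0.3, 0.9]])  # made-up logits for a single position

    # top_k=1 filtering: every logit below the maximum is pushed to -inf,
    # mirroring what top_k_logits does inside BaseStrategy.
    filtered = logits.clone()
    filtered[filtered < filtered.max(dim=-1, keepdim=True).values] = float("-inf")

    probs = F.softmax(filtered.float(), dim=-1)      # [0., 1., 0., 0.]
    pred = torch.multinomial(probs, num_samples=1)   # can only pick the surviving token

    assert pred.item() == logits.argmax(dim=-1).item()  # sampling degenerates to greedy
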
tasks/cot/task.py  +6 -6

@@ -108,10 +108,10 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
         self.labeled_examples = read_examples(config.prompt_path)
         self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought,
                                            prompt_type=config.prompt_type)
-        print_rank_0(self.labeled_prompt)
+        # print_rank_0(self.labeled_prompt)
         self.printed_count = 0
         super().__init__(path, config)
-        print_rank_0(len(self.tokenizer.tokenize(self.labeled_prompt)))
+        # print_rank_0(len(self.tokenizer.tokenize(self.labeled_prompt)))

     def process_single_item(self, item, **kwargs):
         question, targets = item["question"], item["targets"]
@@ -125,9 +125,9 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
         if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
             text_length = self.config.max_seq_length - self.config.max_gen_length - 2
             text = text[len(text) - text_length: len(text)]
-        if self.printed_count < 3:
-            print_rank_0(self.tokenizer.detokenize(text))
-            self.printed_count += 1
+        # if self.printed_count < 3:
+        #     print_rank_0(self.tokenizer.detokenize(text))
+        #     self.printed_count += 1
         return [{"text": text, "targets": targets, **kwargs}]
@@ -181,7 +181,7 @@ class ChainOfThoughtTask(GenerationTask):
             count += prediction == target
         return count * 100.0 / num_predictions

-    def build_dataset(self, relative_path, split):
+    def build_dataset(self, relative_path):
         if self.config.name.startswith("gsm8k"):
             return GSM8KDataset(os.path.join(self.config.path, relative_path), self.config)
         elif self.config.name.startswith("sports"):
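
For reference, the truncation visible in the second hunk's context keeps only the rightmost tokens of the encoded few-shot prompt, so that the prompt plus max_gen_length plus two special tokens still fits in max_seq_length. A standalone sketch with made-up sizes (the real values come from the task config):

    # Made-up sizes; in the repo they come from the task config.
    max_seq_length, max_gen_length = 2048, 256
    text = list(range(2000))  # stand-in for the tokenized few-shot prompt + question

    if len(text) + max_gen_length + 2 > max_seq_length:
        text_length = max_seq_length - max_gen_length - 2   # 1790 tokens may remain
        text = text[len(text) - text_length:]                # drop the oldest tokens on the left

    assert len(text) == 1790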