
Merge CoT branch

Sengxian, 2 years ago
parent commit ab6285e319
5 changed files with 18 additions and 39 deletions
  1. +1 -2   evaluation/configs.py
  2. +0 -12  evaluation/dataset.py
  3. +8 -11  evaluation/tasks.py
  4. +3 -8   generation/strategies.py
  5. +6 -6   tasks/cot/task.py

+ 1 - 2
evaluation/configs.py

@@ -43,14 +43,13 @@ class MultiChoiceTaskConfig(BaseConfig):
 @dataclass
 class GenerationTaskConfig(BaseConfig):
     module = "evaluation.GenerationTask"
-    metrics: List[str] = field(default_factory=lambda: [])
+    metrics: List[str] = field(default_factory=lambda: ["EM", "F1"])
     sampling_strategy: str = "BaseStrategy"
     num_beams: int = 4
     length_penalty: float = 1.0
     no_repeat_ngram_size: int = 3
     min_gen_length: int = 0
     max_gen_length: int = 128
-    deterministic: bool = False
     end_tokens: List[str] = field(default_factory=lambda: [])
 
 
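With this change, a GenerationTaskConfig created without explicit metrics now evaluates exact match and F1 instead of nothing, and the per-task `deterministic` switch is gone. A minimal sketch of the `default_factory` behaviour, with the class trimmed down to the one field that changed:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class GenerationTaskConfig:
    # default_factory builds a fresh list per instance, so instances
    # cannot accidentally share (and mutate) one default list
    metrics: List[str] = field(default_factory=lambda: ["EM", "F1"])

cfg = GenerationTaskConfig()
print(cfg.metrics)                     # ['EM', 'F1']
cfg.metrics.append("BLEU")
print(GenerationTaskConfig().metrics)  # still ['EM', 'F1']
```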

+ 0 - 12
evaluation/dataset.py

@@ -61,15 +61,6 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
         return None
 
     def process_single_file(self, path):
-        if not path.endswith("jsonl"):
-            try:
-                with open(os.path.join(path), "r", encoding="utf-8") as file:
-                    dataset = json.load(file)
-                for item in dataset:
-                    self.data.extend(self.process_single_item(item))
-                return
-            except json.decoder.JSONDecodeError:
-                pass
         with open(os.path.join(path), "r", encoding="utf-8") as file:
             for line in file:
                 item = json.loads(line)
@@ -171,8 +162,6 @@ class GenerationTaskDataset(EvaluationDataset):
             use_task_mask=self.config.use_task_mask,
             unidirectional=self.config.unidirectional,
         )
-        if "target" in item:
-            sample["targets"] = [np.array(target, dtype=self.dtype) for target in item["targets"]]
         return sample
 
 
@@ -323,7 +312,6 @@ class MultiChoiceTaskDataset(EvaluationDataset):
             unidirectional=self.config.unidirectional,
             use_task_mask=self.config.use_task_mask,
         )
-        sample["label"] = item["label"]
         return sample
 
 
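With the silent JSON-array fallback deleted, `process_single_file` now accepts JSONL input only. A dataset stored as one whole-file JSON array would need converting up front; a hypothetical one-off converter (`json_to_jsonl` is not part of this repo):

```python
import json

def json_to_jsonl(src_path: str, dst_path: str) -> None:
    """Rewrite a whole-file JSON array (the format the removed
    fallback used to read) as one JSON object per line."""
    with open(src_path, "r", encoding="utf-8") as src:
        items = json.load(src)
    with open(dst_path, "w", encoding="utf-8") as dst:
        for item in items:
            dst.write(json.dumps(item, ensure_ascii=False) + "\n")
```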

+ 8 - 11
evaluation/tasks.py

@@ -2,7 +2,6 @@ import torch
 import time
 import numpy as np
 import torch.distributed as dist
-from tqdm import tqdm
 
 from typing import Dict, Callable, Type, Tuple, List, Any
 from abc import ABC, abstractmethod
@@ -75,7 +74,7 @@ class BaseTask(ABC):
 
             result_dict_group = {}
             for file in filelist:
-                dataset = self.build_dataset(file, group_name)
+                dataset = self.build_dataset(file)
                 dataloader = build_data_loader(
                     dataset,
                     micro_batch_size=self.config.micro_batch_size,
@@ -85,11 +84,9 @@ class BaseTask(ABC):
                 )
 
                 prediction = []
-                tqdm_wrapper = tqdm if torch.distributed.get_rank() == 0 else lambda x:x
                 with torch.no_grad():
-                    for idx, batch in tqdm_wrapper(enumerate(dataloader)):
-                        p_batch = self.predict_single_batch(batch)
-                        prediction.append(p_batch)
+                    for _, batch in enumerate(dataloader):
+                        prediction.append(self.predict_single_batch(batch))
 
 
                 prediction = gather_result(prediction, len(dataset), self.config.micro_batch_size)
@@ -161,7 +158,7 @@ class BaseTask(ABC):
         pass
 
     @abstractmethod
-    def build_dataset(self, relative_path: str, split: str) -> EvaluationDataset:
+    def build_dataset(self, relative_path: str) -> EvaluationDataset:
         pass
 
 
@@ -172,7 +169,7 @@ class GenerationTask(BaseTask, ABC):
     def config_class(cls):
         return GenerationTaskConfig
 
-    def build_dataset(self, relative_path, split):
+    def build_dataset(self, relative_path):
         return GenerationTaskDataset(join(self.config.path, relative_path), self.config)
 
     def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: GenerationTaskConfig):
@@ -185,7 +182,7 @@ class GenerationTask(BaseTask, ABC):
             print_rank_0(f"End tokens {end_tokens}")
         if self.config.sampling_strategy == "BaseStrategy":
             self.strategy = BaseStrategy(batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1,
-                                         end_tokens=end_tokens, deterministic=self.config.deterministic)
+                                         end_tokens=end_tokens)
         elif self.config.sampling_strategy == "BeamSearchStrategy":
             self.strategy = BeamSearchStrategy(
                 self.config.micro_batch_size,
@@ -195,7 +192,7 @@ class GenerationTask(BaseTask, ABC):
                 end_tokens=end_tokens,
                 no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                 min_gen_length=self.config.min_gen_length,
-                deterministic=self.config.deterministic,  # For evaluation, we need a determined generation strategy
+                deterministic=True,  # For evaluation, we need a deterministic generation strategy
             )
         else:
             raise ValueError(f"unknown strategy {self.config.sampling_strategy}")
@@ -212,7 +209,7 @@ class MultiChoiceTask(BaseTask, ABC):
     def config_class(cls):
         return MultiChoiceTaskConfig
 
-    def build_dataset(self, relative_path, split):
+    def build_dataset(self, relative_path):
         return MultiChoiceTaskDataset(join(self.config.path, relative_path), self.config)
 
     def predict_single_batch(self, batch) -> List[int]:

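The tqdm progress bar is gone from the evaluation loop, and `build_dataset` lost its unused `split` argument. If progress reporting is still wanted, a caller can log from its own loop; a minimal sketch (`run_eval` is hypothetical), assuming a task object with the `predict_single_batch` method shown above:

```python
import torch

def run_eval(task, dataloader, log_every=50):
    # Plain periodic logging in place of the removed tqdm wrapper.
    predictions = []
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            predictions.append(task.predict_single_batch(batch))
            if idx % log_every == 0:
                print(f"batch {idx}/{len(dataloader)}")
    return predictions
```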
+ 3 - 8
generation/strategies.py

@@ -4,8 +4,7 @@ import torch.nn.functional as F
 from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits
 
 class BaseStrategy:
-    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None,
-                 deterministic=False):
+    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
         self.batch_size = batch_size
         self.invalid_slices = invalid_slices
         self.temperature = temperature
@@ -15,7 +14,6 @@ class BaseStrategy:
         if end_tokens is None:
             end_tokens = []
         self.end_tokens = end_tokens
-        self.deterministic = deterministic
         self._is_done = np.zeros(self.batch_size, dtype=np.bool)
 
     @property
@@ -32,11 +30,8 @@ class BaseStrategy:
             logits[..., invalid_slice] = -65504
 
         logits = top_k_logits(logits, self.topk, self.top_p)
-        if self.deterministic:
-            pred = logits.max(dim=-1)[1]
-        else:
-            probs = F.softmax(logits.float(), dim=-1)  # float is essetial, due to a bug in Pytorch
-            pred = torch.multinomial(probs, num_samples=1)
+        probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in PyTorch
+        pred = torch.multinomial(probs, num_samples=1)
         for i in range(self.batch_size):
             if i >= batch_size:
                 self._is_done[i] = True

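Dropping the greedy branch is safe for evaluation because `GenerationTask` constructs `BaseStrategy` with `top_k=1` (see tasks.py above): after top-k filtering only the argmax logit survives, so multinomial sampling collapses to greedy decoding. A standalone sketch of that equivalence:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])

# Emulate top-k filtering with k=1: mask everything below the max logit.
top1 = logits.clone()
top1[top1 < top1.max(dim=-1, keepdim=True)[0]] = -65504

probs = F.softmax(top1.float(), dim=-1)        # ~[1, 0, 0]
pred = torch.multinomial(probs, num_samples=1)
assert pred.item() == logits.argmax(dim=-1).item()  # sampling == argmax
```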
+ 6 - 6
tasks/cot/task.py

@@ -108,10 +108,10 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
         self.labeled_examples = read_examples(config.prompt_path)
         self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought,
                                            prompt_type=config.prompt_type)
-        print_rank_0(self.labeled_prompt)
+        # print_rank_0(self.labeled_prompt)
         self.printed_count = 0
         super().__init__(path, config)
-        print_rank_0(len(self.tokenizer.tokenize(self.labeled_prompt)))
+        # print_rank_0(len(self.tokenizer.tokenize(self.labeled_prompt)))
 
     def process_single_item(self, item, **kwargs):
         question, targets = item["question"], item["targets"]
@@ -125,9 +125,9 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
         if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
             text_length = self.config.max_seq_length - self.config.max_gen_length - 2
             text = text[len(text) - text_length: len(text)]
-        if self.printed_count < 3:
-            print_rank_0(self.tokenizer.detokenize(text))
-            self.printed_count += 1
+        # if self.printed_count < 3:
+        #     print_rank_0(self.tokenizer.detokenize(text))
+        #     self.printed_count += 1
         return [{"text": text, "targets": targets, **kwargs}]
 
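The truncation in `process_single_item` keeps the tail of the token sequence, so the question closest to the generation slot survives while the front of the few-shot prompt is dropped, reserving room for `max_gen_length` plus a two-token margin. The arithmetic, with illustrative numbers (not from any real config):

```python
# Illustrative values, not from the repo's configs:
max_seq_length, max_gen_length = 512, 128
text = list(range(600))  # stand-in for a too-long token sequence

if len(text) + max_gen_length + 2 > max_seq_length:
    text_length = max_seq_length - max_gen_length - 2  # 382 tokens of context remain
    text = text[len(text) - text_length:]              # keep the tail, drop the head

assert len(text) + max_gen_length + 2 <= max_seq_length
```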
 
@@ -181,7 +181,7 @@ class ChainOfThoughtTask(GenerationTask):
             count += prediction == target
         return count * 100.0 / num_predictions
 
-    def build_dataset(self, relative_path, split):
+    def build_dataset(self, relative_path):
         if self.config.name.startswith("gsm8k"):
             return GSM8KDataset(os.path.join(self.config.path, relative_path), self.config)
         elif self.config.name.startswith("sports"):