Răsfoiți Sursa

Add DATE dataset for CoT

Sengxian 2 ani în urmă
părinte
comite
60bb22bb13
4 a modificat fișierele cu 43 adăugiri și 17 ștergeri
  1. 0 1
      tasks/cot/lastletter.yaml
  2. 0 1
      tasks/cot/reverse.yaml
  3. 0 1
      tasks/cot/sports.yaml
  4. 43 14
      tasks/cot/task.py

+ 0 - 1
tasks/cot/lastletter.yaml

@@ -6,7 +6,6 @@ file-pattern:
   test: "lastletter.jsonl"
 sampling_strategy: "BaseStrategy"
 prompt_path: "tasks/cot/lastletter_prompt.txt"
-deterministic: true
 unidirectional: true
 max_gen_length: 64
 use_task_mask: true

+ 0 - 1
tasks/cot/reverse.yaml

@@ -6,7 +6,6 @@ file-pattern:
   test: "reverse_5.jsonl"
 sampling_strategy: "BaseStrategy"
 prompt_path: "tasks/cot/reverse_prompt.txt"
-deterministic: true
 max_gen_length: 128
 use_task_mask: true
 save_prediction: true

+ 0 - 1
tasks/cot/sports.yaml

@@ -6,7 +6,6 @@ file-pattern:
   test: "sports.json"
 sampling_strategy: "BaseStrategy"
 prompt_path: "tasks/cot/sports_prompt.txt"
-deterministic: true
 unidirectional: true
 max_gen_length: 128
 use_task_mask: true

+ 43 - 14
tasks/cot/task.py

@@ -55,9 +55,9 @@ def extract_answer(prediction, task_name, chain_of_thought=True):
     if task_name.startswith("gsm8k"):
         prediction = prediction.lower()
         if chain_of_thought:
-            pattern = r'(?<=the answer is )\d+'
+            pattern = r"(?<=the answer is )\d+"
         else:
-            pattern = r'\d+'
+            pattern = r"\d+"
         match = re.search(pattern, prediction)
         if match:
             answer = match.group(0)
@@ -66,9 +66,9 @@ def extract_answer(prediction, task_name, chain_of_thought=True):
     elif task_name.startswith("sports") or task_name.startswith("coinflip"):
         prediction = prediction.lower()
         if chain_of_thought:
-            pattern = r'(?<=the answer is )(yes|no)'
+            pattern = r"(?<=the answer is )(yes|no)"
         else:
-            pattern = r'yes|no'
+            pattern = r"yes|no"
         match = re.search(pattern, prediction)
         if match:
             answer = match.group(0)
@@ -77,9 +77,9 @@ def extract_answer(prediction, task_name, chain_of_thought=True):
     elif task_name.startswith("lastletter"):
         prediction = prediction.lower()
         if chain_of_thought:
-            pattern = r'(?<=the answer is )[a-z]+'
+            pattern = r"(?<=the answer is )[a-z]+"
         else:
-            pattern = r'[a-z]+'
+            pattern = r"[a-z]+"
         match = re.search(pattern, prediction)
         if match:
             answer = match.group(0)
@@ -90,7 +90,19 @@ def extract_answer(prediction, task_name, chain_of_thought=True):
         if chain_of_thought:
             pattern = r'(?<=the answer is ")[a-z|,| ]+'
         else:
-            pattern = r'[a-z|,| ]+'
+            pattern = r"[a-z|,| ]+"
+        match = re.search(pattern, prediction)
+        if match:
+            answer = match.group(0)
+        else:
+            answer = ""
+    elif task_name.startswith("date"):
+        prediction = prediction.lower()
+        date_regex = r"(((0[0-9])|(1[012]))\/((0[1-9])|([12][0-9])|(3[01]))\/((20[012]\d|19\d\d)|(1\d|2[0123])))"
+        if chain_of_thought:
+            pattern = r"(?<=the answer is )" + date_regex
+        else:
+            pattern = date_regex
         match = re.search(pattern, prediction)
         if match:
             answer = match.group(0)
@@ -106,8 +118,9 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
 
     def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
         self.labeled_examples = read_examples(config.prompt_path)
-        self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought,
-                                           prompt_type=config.prompt_type)
+        self.labeled_prompt = build_prompt(
+            self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought, prompt_type=config.prompt_type
+        )
         # print_rank_0(self.labeled_prompt)
         self.printed_count = 0
         super().__init__(path, config)
@@ -124,7 +137,7 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
         text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
         if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
             text_length = self.config.max_seq_length - self.config.max_gen_length - 2
-            text = text[len(text) - text_length: len(text)]
+            text = text[len(text) - text_length : len(text)]
         # if self.printed_count < 3:
         #     print_rank_0(self.tokenizer.detokenize(text))
         #     self.printed_count += 1
@@ -143,7 +156,7 @@ class SportsDataset(ChainOfThoughtDataset):
             dataset = json.load(file)
         for item in dataset["examples"]:
             sentence = item["input"]
-            item["question"] = f'Is the following sentence plausible? \"{sentence}.\"'
+            item["question"] = f'Is the following sentence plausible? "{sentence}."'
             if item["target_scores"]["plausible"] == 1:
                 item["targets"] = "yes"
             else:
@@ -151,6 +164,19 @@ class SportsDataset(ChainOfThoughtDataset):
             self.data.extend(self.process_single_item(item))
 
 
+class DateDataset(ChainOfThoughtDataset):
+    def process_single_file(self, path):
+        with open(path) as file:
+            dataset = json.load(file)
+        for item in dataset["examples"]:
+            sentence = item["input"]
+            item["question"] = sentence
+            for key, value in item["target_scores"].items():
+                if value == 1:
+                    item["targets"] = key
+            self.data.extend(self.process_single_item(item))
+
+
 class LastLetterDataset(ChainOfThoughtDataset):
     def process_single_item(self, item, **kwargs):
         first_name, last_name = item["first_name"], item["last_name"]
@@ -168,7 +194,7 @@ class ChainOfThoughtTask(GenerationTask):
 
     @property
     def metrics(self) -> Dict[str, Callable]:
-        return {'acuracy': self.extracted_accuracy_metric}
+        return {"Accuracy": self.extracted_accuracy_metric}
 
     def extracted_accuracy_metric(self, predictions, examples):
         count = 0
@@ -190,6 +216,8 @@ class ChainOfThoughtTask(GenerationTask):
             return LastLetterDataset(os.path.join(self.config.path, relative_path), self.config)
         elif self.config.name.startswith("coinflip") or self.config.name.startswith("reverse"):
             return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)
+        elif self.config.name.startswith("date"):
+            return DateDataset(os.path.join(self.config.path, relative_path), self.config)
         else:
             raise NotImplementedError
 
@@ -203,7 +231,8 @@ class ChainOfThoughtTask(GenerationTask):
         file_name = file.split(".")[0]
         if not os.path.exists("outputs"):
             os.mkdir("outputs")
-        with open("outputs/" + self.config.name + "_" + datetime.now().strftime(
-                '%m-%d-%H-%M_') + file_name + ".json", "w") as output:
+        with open(
+            "outputs/" + self.config.name + "_" + datetime.now().strftime("%m-%d-%H-%M_") + file_name + ".json", "w"
+        ) as output:
             for result in results:
                 output.write(json.dumps(result) + "\n")