@@ -0,0 +1,118 @@
+import os
+import json
+import re
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Union, List, Dict, Callable
+
+from evaluation.tasks import GenerationTask, GenerationTaskDataset, GenerationTaskConfig
+from evaluation.utils import print_rank_0
+
+
+@dataclass
+class ChainOfThoughtConfig(GenerationTaskConfig):
+    # Path to the text file of few-shot chain-of-thought exemplars
+    prompt_path: str = None
+
+
+def read_examples(prompt_path):
+    """Parse few-shot exemplars from a file of alternating "Q:"/"A:" lines."""
+    examples = []
+    item = {"question": None, "answer": None}
+    with open(prompt_path) as file:
+        for line in file:
+            line = line.strip()
+            if not line:
+                continue  # tolerate blank lines between exemplars
+            if line.startswith("Q:"):
+                item["question"] = line[2:].strip()
+            elif line.startswith("A:"):
+                item["answer"] = line[2:].strip()
+                examples.append(item)
+                item = {"question": None, "answer": None}
+            else:
+                raise NotImplementedError(f"Unrecognized prompt line: {line!r}")
+    return examples
+
+
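+# A minimal sketch of the exemplar file read_examples() expects (illustrative
+# content, not shipped with this change): one "Q:" line followed by one "A:"
+# line per exemplar, with the rationale ending in "The answer is N" so that
+# extract_answer() below can recover the final number:
+#
+#   Q: Tom has 3 apples and buys 2 more. How many apples does he have now?
+#   A: He starts with 3 and buys 2, so 3 + 2 = 5. The answer is 5.
+
+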
+def build_prompt(examples):
+    """Concatenate the labeled exemplars into a single few-shot prompt string."""
+    prompts = []
+    for item in examples:
+        question, answer = item["question"], item["answer"]
+        prompts.append(f"Question: {question} Answer: {answer}")
+    prompt = " ".join(prompts)
+    return prompt
+
+
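+# With the sketch above, build_prompt() would return one flat string, e.g.
+# (hypothetical) "Question: Tom has 3 apples ... Answer: He starts with 3 and
+# buys 2, so 3 + 2 = 5. The answer is 5."; each evaluation question is then
+# appended to this prefix in process_single_item() below.
+
+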
+def extract_answer(prediction, task_name):
+    """Extract the final numeric answer from a generated chain of thought."""
+    if task_name == "gsm8k":
+        prediction = prediction.lower()
+        # take the number that follows the cue phrase "the answer is"
+        match = re.search(r'(?<=the answer is )\d+', prediction)
+        answer = match.group(0) if match else ""
+    else:
+        raise NotImplementedError(task_name)
+    return answer
+
+
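+# Expected behaviour of the extractor on hypothetical generations:
+#
+#   extract_answer("... so 3 + 2 = 5. The answer is 5.", "gsm8k")  # -> "5"
+#   extract_answer("... no cue phrase here.", "gsm8k")             # -> ""
+#
+# Note the pattern only matches bare digits after the cue phrase; negative
+# numbers, thousands separators ("1,000"), and decimals would need a broader
+# regex.
+
+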
+class ChainOfThoughtDataset(GenerationTaskDataset):
+
+    def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
+        self.labeled_examples = read_examples(config.prompt_path)
+        self.labeled_prompt = build_prompt(self.labeled_examples)
+        print_rank_0(self.labeled_prompt)
+        self.printed_count = 0
+        super().__init__(path, config)
+
+    def process_single_item(self, item, **kwargs):
+        question = item["question"]
+        # GSM8K stores the gold answer after the "####" separator
+        targets = item["answer"].split("####")[1].strip()
+        text = self.labeled_prompt + f" Question: {question} Answer:"
+        text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
+        # Truncate from the left so the prompt, the generation budget, and two
+        # special tokens always fit within max_seq_length
+        if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
+            text_length = self.config.max_seq_length - self.config.max_gen_length - 2
+            text = text[len(text) - text_length:]
+        if self.printed_count < 3:
+            # log the first few processed prompts for sanity checking
+            print_rank_0(self.tokenizer.detokenize(text))
+            self.printed_count += 1
+        return [{"text": text, "targets": targets, **kwargs}]
+
+
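+# Worked example of the truncation budget, with assumed config values: for
+# max_seq_length=2048 and max_gen_length=256, a 2000-token prompt exceeds
+# 2000 + 256 + 2 > 2048, so it is cut to its last 2048 - 256 - 2 = 1790
+# tokens. Trimming from the left drops the oldest exemplars while keeping
+# the actual question intact at the end of the prompt.
+
+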
+class ChainOfThoughtTask(GenerationTask):
+    config: ChainOfThoughtConfig
+
+    @classmethod
+    def config_class(cls):
+        return ChainOfThoughtConfig
+
+    @property
+    def metrics(self) -> Dict[str, Callable]:
+        return {"accuracy": self.extracted_accuracy_metric}
+
+    def extracted_accuracy_metric(self, predictions, examples):
+        count = 0
+        num_predictions = max(len(predictions), 1)  # guard against empty splits
+        assert len(predictions) == len(examples)
+        for prediction, example in zip(predictions, examples):
+            output = self.tokenizer.detokenize(prediction)
+            prediction = extract_answer(output, self.config.name).strip()
+            target = self.tokenizer.detokenize(example["targets"]).strip()
+            count += prediction == target
+        return count * 100.0 / num_predictions
+
+    def build_dataset(self, relative_path, split):
+        return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)
+
+    def save_prediction_to_file(self, file, predictions, data):
+        results = []
+        for output, item in zip(predictions, data):
+            output = self.tokenizer.detokenize(output)
+            prediction = extract_answer(output, self.config.name)
+            target = self.tokenizer.detokenize(item["targets"])
+            results.append({"output": output, "prediction": prediction, "answer": target})
+        # use only the base name so paths with directories or extra dots work
+        file_name = os.path.splitext(os.path.basename(file))[0]
+        os.makedirs("outputs", exist_ok=True)
+        output_path = os.path.join(
+            "outputs",
+            self.config.name + "_" + datetime.now().strftime('%m-%d-%H-%M_') + file_name + ".json")
+        with open(output_path, "w") as output:
+            for result in results:
+                output.write(json.dumps(result) + "\n")
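+
+# Each line of the saved file is one JSON record; a hypothetical example:
+#
+#   {"output": "... The answer is 5.", "prediction": "5", "answer": "5"}
+#
+# i.e. raw generation, extracted answer, and gold target, one record per
+# line (JSON Lines, despite the ".json" suffix).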