123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- import os
- import json
- import re
- from typing import Union, List, Dict, Callable
- from datetime import datetime
- from evaluation.tasks import GenerationTask, GenerationTaskDataset, GenerationTaskConfig
- from evaluation.utils import print_rank_0
- from dataclasses import dataclass
@dataclass
class ChainOfThoughtConfig(GenerationTaskConfig):
    """Generation-task configuration extended with a few-shot prompt file."""

    # Path of the text file holding the labeled "Q:"/"A:" examples; it is
    # handed to read_examples() by ChainOfThoughtDataset. The default is None,
    # so the annotation admits None explicitly.
    prompt_path: Union[str, None] = None
def read_examples(prompt_path):
    """Parse a few-shot prompt file into a list of question/answer dicts.

    The file is read line by line. A line starting with ``Q:`` sets the
    question of the example being built; a line starting with ``A:`` sets its
    answer and completes the example. Blank lines are ignored.

    Args:
        prompt_path: Path of the prompt file to read.

    Returns:
        A list of ``{"question": ..., "answer": ...}`` dicts in file order.

    Raises:
        NotImplementedError: If a non-blank line starts with neither ``Q:``
            nor ``A:``.
    """
    examples = []
    item = {"question": None, "answer": None}
    with open(prompt_path, encoding="utf-8") as file:
        for line_number, raw_line in enumerate(file, start=1):
            line = raw_line.strip()
            if not line:
                # Blank separator lines between examples carry no content.
                continue
            if line.startswith("Q:"):
                # Drop the two-character marker, then any padding after it, so
                # both "Q: text" and "Q:text" yield "text".
                item["question"] = line[2:].lstrip()
            elif line.startswith("A:"):
                item["answer"] = line[2:].lstrip()
                examples.append(item)
                # Start a fresh dict so completed examples are never mutated.
                item = {"question": None, "answer": None}
            else:
                raise NotImplementedError(
                    f"{prompt_path}:{line_number}: expected a line starting "
                    f"with 'Q:' or 'A:', got {line!r}"
                )
    return examples
def build_prompt(examples):
    """Render labeled examples as one space-separated few-shot prompt.

    Args:
        examples: Iterable of dicts with ``"question"`` and ``"answer"`` keys.

    Returns:
        A single string in which each example appears as
        ``Question: <question> Answer: <answer>``; empty when there are no
        examples.
    """
    return " ".join(
        f"Question: {example['question']} Answer: {example['answer']}"
        for example in examples
    )
def extract_answer(prediction, task_name):
    """Pull the final numeric answer out of a generated explanation.

    Args:
        prediction: Generated text to search.
        task_name: Name of the task; only ``"gsm8k"`` is handled.

    Returns:
        The run of digits that directly follows the first occurrence of
        "the answer is " (matched case-insensitively), or an empty string
        when no such occurrence exists.

    Raises:
        NotImplementedError: If ``task_name`` is not ``"gsm8k"``.
    """
    if task_name != "gsm8k":
        raise NotImplementedError(task_name)

    # Lower-casing first makes the fixed phrase match regardless of case.
    found = re.search(r'(?<=the answer is )\d+', prediction.lower())
    return found.group(0) if found else ""
class ChainOfThoughtDataset(GenerationTaskDataset):
    """Generation dataset that prefixes every question with few-shot examples."""

    def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
        """Load the labeled examples, build the shared prompt, then initialize.

        Args:
            path: Data file path(s), forwarded to the parent constructor.
            config: Task configuration; ``config.prompt_path`` names the
                few-shot prompt file.
        """
        # These attributes are assigned before delegating to the parent
        # constructor - presumably because it processes items and therefore
        # needs them already in place (TODO confirm against the base class).
        few_shot_examples = read_examples(config.prompt_path)
        self.labeled_examples = few_shot_examples
        self.labeled_prompt = build_prompt(few_shot_examples)
        print_rank_0(self.labeled_prompt)
        self.printed_count = 0
        super().__init__(path, config)

    def process_single_item(self, item, **kwargs):
        """Turn one raw record into a tokenized prompt/target pair.

        Args:
            item: Record with a ``"question"`` string and an ``"answer"``
                string whose final answer follows a ``####`` marker.
            **kwargs: Extra fields copied verbatim into the result.

        Returns:
            A one-element list holding a dict with the tokenized prompt under
            ``"text"``, the tokenized gold answer under ``"targets"``, and
            any extra keyword fields.
        """
        query = item["question"]
        gold_answer = item["answer"].split("####")[1].strip()
        prompt = self.labeled_prompt + f" Question: {query} Answer:"

        prompt_tokens = self.tokenizer.tokenize(prompt)
        target_tokens = self.tokenizer.tokenize(gold_answer)

        # Room left for the prompt once the generation length and two extra
        # positions are reserved out of the maximum sequence length.
        prompt_budget = self.config.max_seq_length - self.config.max_gen_length - 2
        if len(prompt_tokens) > prompt_budget:
            # Keep the tail so the actual question survives the truncation.
            prompt_tokens = prompt_tokens[len(prompt_tokens) - prompt_budget:]

        # Echo the first few processed prompts for inspection.
        if self.printed_count < 3:
            print_rank_0(self.tokenizer.detokenize(prompt_tokens))
            self.printed_count += 1

        return [{"text": prompt_tokens, "targets": target_tokens, **kwargs}]
class ChainOfThoughtTask(GenerationTask):
    """Generation task scored by comparing extracted final answers to targets."""

    config: ChainOfThoughtConfig

    @classmethod
    def config_class(cls):
        """Return the configuration class used by this task."""
        return ChainOfThoughtConfig

    @property
    def metrics(self) -> Dict[str, Callable]:
        """Map metric names to the callables that compute them."""
        # NOTE(review): the key is misspelled ("acuracy"); it is kept as-is
        # because anything consuming results by this name would break if it
        # were renamed here.
        return {'acuracy': self.extracted_accuracy_metric}

    def extracted_accuracy_metric(self, predictions, examples):
        """Compute the percentage of predictions whose answer equals the target.

        Args:
            predictions: Token sequences produced by the model.
            examples: Records whose ``"targets"`` entry holds the tokenized
                gold answer; must have the same length as ``predictions``.

        Returns:
            Accuracy as a float in ``[0, 100]``; ``0.0`` for empty input.
        """
        assert len(predictions) == len(examples)
        correct = 0
        for prediction, example in zip(predictions, examples):
            generated = self.tokenizer.detokenize(prediction)
            predicted = extract_answer(generated, self.config.name).strip()
            target = self.tokenizer.detokenize(example["targets"]).strip()
            correct += predicted == target
        # Guard the denominator so an empty evaluation set yields 0.0.
        return correct * 100.0 / max(len(predictions), 1)

    def build_dataset(self, relative_path, split):
        """Create the dataset for ``relative_path`` under the configured root."""
        return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)

    def save_prediction_to_file(self, file, predictions, data):
        """Write each output, extracted answer and target as one JSON line.

        The records go to
        ``outputs/<task name>_<MM-DD-HH-MM>_<file stem>.json``.

        Args:
            file: Name of the evaluated data file; its base name up to the
                first dot becomes part of the output file name.
            predictions: Token sequences produced by the model.
            data: Records whose ``"targets"`` entry holds the tokenized gold
                answer.
        """
        results = []
        for tokens, item in zip(predictions, data):
            generated = self.tokenizer.detokenize(tokens)
            results.append({
                "output": generated,
                "prediction": extract_answer(generated, self.config.name),
                "answer": self.tokenizer.detokenize(item["targets"]),
            })

        # Use only the base name: a directory component in ``file`` would
        # otherwise point the output path into a directory that does not exist.
        file_name = os.path.basename(file).split(".")[0]
        # exist_ok avoids the check-then-create race when several processes
        # reach this point at the same time.
        os.makedirs("outputs", exist_ok=True)
        timestamp = datetime.now().strftime('%m-%d-%H-%M_')
        output_path = os.path.join(
            "outputs", self.config.name + "_" + timestamp + file_name + ".json"
        )
        with open(output_path, "w", encoding="utf-8") as handle:
            for result in results:
                handle.write(json.dumps(result) + "\n")
|