|
@@ -11,6 +11,7 @@ from dataclasses import dataclass
|
|
|
@dataclass
|
|
|
class ChainOfThoughtConfig(GenerationTaskConfig):
|
|
|
prompt_path: str = None
|
|
|
+ chain_of_thought: bool = True
|
|
|
|
|
|
|
|
|
def read_examples(prompt_path):
|
|
@@ -32,23 +33,40 @@ def read_examples(prompt_path):
|
|
|
return examples
|
|
|
|
|
|
|
|
|
-def build_prompt(examples):
|
|
|
+def build_prompt(examples, task_name, chain_of_thought=True):
|
|
|
prompts = []
|
|
|
for item in examples:
|
|
|
question, answer = item["question"], item["answer"]
|
|
|
+ if not chain_of_thought:
|
|
|
+ answer = extract_answer(answer, task_name)
|
|
|
prompts.append(f"Question: {question} Answer: {answer}")
|
|
|
prompt = " ".join(prompts)
|
|
|
return prompt
|
|
|
|
|
|
|
|
|
-def extract_answer(prediction, task_name):
|
|
|
- if task_name == "gsm8k":
|
|
|
+def extract_answer(prediction, task_name, chain_of_thought=True):
|
|
|
+ if task_name.startswith("gsm8k"):
|
|
|
prediction = prediction.lower()
|
|
|
- match = re.search(r'(?<=the answer is )\d+', prediction)
|
|
|
+ if chain_of_thought:
|
|
|
+ pattern = r'(?<=the answer is )\d+'
|
|
|
+ else:
|
|
|
+ pattern = r'\d+'
|
|
|
+ match = re.search(pattern, prediction)
|
|
|
if match:
|
|
|
answer = match.group(0)
|
|
|
else:
|
|
|
answer = ""
|
|
|
+ elif task_name.startswith("sports"):
|
|
|
+ prediction = prediction.lower()
|
|
|
+ if chain_of_thought:
|
|
|
+ pattern = r'(?<=the answer is )(yes|no)'
|
|
|
+ else:
|
|
|
+ pattern = r'yes|no'
|
|
|
+ match = re.search(pattern, prediction)
|
|
|
+ if match:
|
|
|
+ answer = match.group(0)
|
|
|
+ else:
|
|
|
+ answer = "no"
|
|
|
else:
|
|
|
raise NotImplementedError(task_name)
|
|
|
return answer
|
|
@@ -58,14 +76,13 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
|
|
|
|
|
|
def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
|
|
|
self.labeled_examples = read_examples(config.prompt_path)
|
|
|
- self.labeled_prompt = build_prompt(self.labeled_examples)
|
|
|
+ self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought)
|
|
|
print_rank_0(self.labeled_prompt)
|
|
|
self.printed_count = 0
|
|
|
super().__init__(path, config)
|
|
|
|
|
|
def process_single_item(self, item, **kwargs):
|
|
|
- question = item["question"]
|
|
|
- targets = item["answer"].split("####")[1].strip()
|
|
|
+ question, targets = item["question"], item["targets"]
|
|
|
text = self.labeled_prompt + f" Question: {question} Answer:"
|
|
|
text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
|
|
|
if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
|
|
@@ -77,6 +94,26 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
|
|
|
return [{"text": text, "targets": targets, **kwargs}]
|
|
|
|
|
|
|
|
|
+class GSM8KDataset(ChainOfThoughtDataset):
|
|
|
+ def process_single_item(self, item, **kwargs):
|
|
|
+ item["targets"] = item["answer"].split("####")[1].strip()
|
|
|
+ return super().process_single_item(item)
|
|
|
+
|
|
|
+
|
|
|
+class SportsDataset(ChainOfThoughtDataset):
|
|
|
+ def process_single_file(self, path):
|
|
|
+ with open(path) as file:
|
|
|
+ dataset = json.load(file)
|
|
|
+ for item in dataset["examples"]:
|
|
|
+ sentence = item["input"]
|
|
|
+ item["question"] = f'Is the following sentence plausible? \"{sentence}.\"'
|
|
|
+ if item["target_scores"]["plausible"] == 1:
|
|
|
+ item["targets"] = "yes"
|
|
|
+ else:
|
|
|
+ item["targets"] = "no"
|
|
|
+ self.data.extend(self.process_single_item(item))
|
|
|
+
|
|
|
+
|
|
|
class ChainOfThoughtTask(GenerationTask):
|
|
|
config: ChainOfThoughtConfig
|
|
|
|
|
@@ -94,19 +131,24 @@ class ChainOfThoughtTask(GenerationTask):
|
|
|
assert len(predictions) == len(examples)
|
|
|
for prediction, example in zip(predictions, examples):
|
|
|
output = self.tokenizer.detokenize(prediction)
|
|
|
- prediction = extract_answer(output, self.config.name).strip()
|
|
|
+ prediction = extract_answer(output, self.config.name, self.config.chain_of_thought).strip()
|
|
|
target = self.tokenizer.detokenize(example["targets"]).strip()
|
|
|
count += prediction == target
|
|
|
return count * 100.0 / num_predictions
|
|
|
|
|
|
def build_dataset(self, relative_path, split):
|
|
|
- return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)
|
|
|
+ if self.config.name.startswith("gsm8k"):
|
|
|
+ return GSM8KDataset(os.path.join(self.config.path, relative_path), self.config)
|
|
|
+ elif self.config.name.startswith("sports"):
|
|
|
+ return SportsDataset(os.path.join(self.config.path, relative_path), self.config)
|
|
|
+ else:
|
|
|
+ raise NotImplementedError
|
|
|
|
|
|
def save_prediction_to_file(self, file, predictions, data):
|
|
|
results = []
|
|
|
for output, item in zip(predictions, data):
|
|
|
output = self.tokenizer.detokenize(output)
|
|
|
- prediction = extract_answer(output, self.config.name)
|
|
|
+ prediction = extract_answer(output, self.config.name, self.config.chain_of_thought)
|
|
|
target = self.tokenizer.detokenize(item["targets"])
|
|
|
results.append({"output": output, "prediction": prediction, "answer": target})
|
|
|
file_name = file.split(".")[0]
|