|
@@ -11,6 +11,8 @@ from dataclasses import dataclass
|
|
@dataclass
class ChainOfThoughtConfig(GenerationTaskConfig):
    """Config for chain-of-thought generation tasks."""

    # Path to the file holding the few-shot labeled examples (read by read_examples).
    prompt_path: str = None
    # If True, few-shot answers keep their full reasoning chain; if False,
    # build_prompt reduces each answer to its bare final answer via extract_answer.
    chain_of_thought: bool = True
    # Few-shot formatting variant: "number" prefixes each example with "1.", "2.", …;
    # "return" joins examples with " <n>"; None/other uses plain space joining.
    prompt_type: str = None
|
|
|
|
|
|
|
|
|
|
def read_examples(prompt_path):
|
|
def read_examples(prompt_path):
|
|
@@ -32,19 +34,64 @@ def read_examples(prompt_path):
|
|
return examples
|
|
return examples
|
|
|
|
|
|
|
|
|
|
def build_prompt(examples, task_name, chain_of_thought=True, prompt_type=None):
    """Render the few-shot examples into a single prompt string.

    Each example contributes "Question: ... Answer: ..."; with
    prompt_type == "number" every example is prefixed with its 1-based index,
    and with prompt_type == "return" the examples are joined by " <n>"
    instead of a single space. When chain_of_thought is False the stored
    answer is first reduced to its bare final answer.
    """
    pieces = []
    for index, example in enumerate(examples, start=1):
        question = example["question"]
        answer = example["answer"]
        if not chain_of_thought:
            # Strip the reasoning chain down to the short final answer.
            answer = extract_answer(answer, task_name)
        if prompt_type == "number":
            pieces.append(f"{index}. Question: {question} Answer: {answer}")
        else:
            pieces.append(f"Question: {question} Answer: {answer}")
    separator = " <n>" if prompt_type == "return" else " "
    return separator.join(pieces)
|
|
|
|
|
|
|
|
|
|
def extract_answer(prediction, task_name, chain_of_thought=True):
    """Pull the short final answer out of a generated prediction.

    Dispatches on the task-name prefix; in chain-of-thought mode the answer
    is expected after the phrase "the answer is ", otherwise the first
    bare match anywhere in the (lowercased) prediction is taken.
    NOTE(review): this definition is truncated in the visible hunk — the
    tail after the "reverse" branch is not shown here.
    """
    if task_name.startswith("gsm8k"):
        prediction = prediction.lower()
        if chain_of_thought:
            # Numeric answer announced as "the answer is <digits>".
            pattern = r'(?<=the answer is )\d+'
        else:
            pattern = r'\d+'
        match = re.search(pattern, prediction)
        if match:
            answer = match.group(0)
        else:
            # No number found: score as empty (always counts as wrong).
            answer = ""
    elif task_name.startswith("sports") or task_name.startswith("coinflip"):
        prediction = prediction.lower()
        if chain_of_thought:
            pattern = r'(?<=the answer is )(yes|no)'
        else:
            pattern = r'yes|no'
        match = re.search(pattern, prediction)
        if match:
            answer = match.group(0)
        else:
            # Default to "no" when no yes/no token is found.
            answer = "no"
    elif task_name.startswith("lastletter"):
        prediction = prediction.lower()
        if chain_of_thought:
            pattern = r'(?<=the answer is )[a-z]+'
        else:
            pattern = r'[a-z]+'
        match = re.search(pattern, prediction)
        if match:
            answer = match.group(0)
        else:
            answer = ""
    elif task_name.startswith("reverse"):
        prediction = prediction.lower()
        if chain_of_thought:
            # NOTE(review): inside [...] the '|' is a LITERAL pipe, so this
            # class matches letters, commas, spaces AND '|' — likely meant
            # r'[a-z, ]+'. Confirm before changing scoring behavior.
            pattern = r'(?<=the answer is ")[a-z|,| ]+'
        else:
            pattern = r'[a-z|,| ]+'
        match = re.search(pattern, prediction)
        if match:
            answer = match.group(0)
|
|
answer = match.group(0)
|
|
else:
|
|
else:
|
|
@@ -55,18 +102,25 @@ def extract_answer(prediction, task_name):
|
|
|
|
|
|
|
|
|
|
class ChainOfThoughtDataset(GenerationTaskDataset):
    """Generation dataset that prepends a fixed few-shot prompt to every item."""

    # Narrow the inherited config attribute to the chain-of-thought config.
    config: ChainOfThoughtConfig

    def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
        # Build the few-shot prefix BEFORE calling super().__init__:
        # NOTE(review): super().__init__ presumably loads and processes the
        # data via process_single_item, which reads self.labeled_prompt —
        # confirm against GenerationTaskDataset before reordering.
        self.labeled_examples = read_examples(config.prompt_path)
        self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought,
                                           prompt_type=config.prompt_type)
        print_rank_0(self.labeled_prompt)
        self.printed_count = 0
        super().__init__(path, config)
        # self.tokenizer only exists after super().__init__; log the prompt's
        # token count so over-long few-shot prefixes are easy to spot.
        print_rank_0(len(self.tokenizer.tokenize(self.labeled_prompt)))
|
|
|
|
|
|
def process_single_item(self, item, **kwargs):
|
|
def process_single_item(self, item, **kwargs):
|
|
- question = item["question"]
|
|
|
|
- targets = item["answer"].split("####")[1].strip()
|
|
|
|
- text = self.labeled_prompt + f" Question: {question} Answer:"
|
|
|
|
|
|
+ question, targets = item["question"], item["targets"]
|
|
|
|
+ if self.config.prompt_type == "number":
|
|
|
|
+ text = self.labeled_prompt + f" {len(self.labeled_examples) + 1}. Question: {question} Answer:"
|
|
|
|
+ elif self.config.prompt_type == "return":
|
|
|
|
+ text = self.labeled_prompt + f" <n>Question: {question} Answer:"
|
|
|
|
+ else:
|
|
|
|
+ text = self.labeled_prompt + f" Question: {question} Answer:"
|
|
text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
|
|
text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
|
|
if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
|
|
if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
|
|
text_length = self.config.max_seq_length - self.config.max_gen_length - 2
|
|
text_length = self.config.max_seq_length - self.config.max_gen_length - 2
|
|
@@ -77,6 +131,34 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
|
|
return [{"text": text, "targets": targets, **kwargs}]
|
|
return [{"text": text, "targets": targets, **kwargs}]
|
|
|
|
|
|
|
|
|
|
|
|
class GSM8KDataset(ChainOfThoughtDataset):
    """GSM8K items: the gold label is the text after the '####' delimiter."""

    def process_single_item(self, item, **kwargs):
        # "answer" holds the worked solution; the final answer follows "####".
        solution = item["answer"]
        item["targets"] = solution.split("####")[1].strip()
        return super().process_single_item(item, **kwargs)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
class SportsDataset(ChainOfThoughtDataset):
    """Sports-plausibility items loaded from a BIG-bench style JSON file."""

    def process_single_file(self, path):
        with open(path) as file:
            dataset = json.load(file)
            for item in dataset["examples"]:
                sentence = item["input"]
                item["question"] = f'Is the following sentence plausible? "{sentence}."'
                # target_scores marks the plausible option with a 1.
                is_plausible = item["target_scores"]["plausible"] == 1
                item["targets"] = "yes" if is_plausible else "no"
                self.data.extend(self.process_single_item(item))
|
|
|
|
+
|
|
|
|
+
|
|
|
|
class LastLetterDataset(ChainOfThoughtDataset):
    """Last-letter concatenation: the question is templated from a name pair."""

    def process_single_item(self, item, **kwargs):
        full_name = f'{item["first_name"]} {item["last_name"]}'
        item["question"] = f'Take the last letters of the words in "{full_name}" and concatenate them.'
        return super().process_single_item(item, **kwargs)
|
|
|
|
+
|
|
|
|
+
|
|
class ChainOfThoughtTask(GenerationTask):
    """Generation task for chain-of-thought evaluation: builds the per-task
    dataset by config name and scores predictions via extract_answer."""

    # Narrow the inherited config attribute to the chain-of-thought config.
    config: ChainOfThoughtConfig
|
|
|
|
|
|
@@ -94,19 +176,28 @@ class ChainOfThoughtTask(GenerationTask):
|
|
assert len(predictions) == len(examples)
|
|
assert len(predictions) == len(examples)
|
|
for prediction, example in zip(predictions, examples):
|
|
for prediction, example in zip(predictions, examples):
|
|
output = self.tokenizer.detokenize(prediction)
|
|
output = self.tokenizer.detokenize(prediction)
|
|
- prediction = extract_answer(output, self.config.name).strip()
|
|
|
|
|
|
+ prediction = extract_answer(output, self.config.name, self.config.chain_of_thought).strip()
|
|
target = self.tokenizer.detokenize(example["targets"]).strip()
|
|
target = self.tokenizer.detokenize(example["targets"]).strip()
|
|
count += prediction == target
|
|
count += prediction == target
|
|
return count * 100.0 / num_predictions
|
|
return count * 100.0 / num_predictions
|
|
|
|
|
|
def build_dataset(self, relative_path, split):
    """Instantiate the dataset subclass matching the configured task name."""
    data_path = os.path.join(self.config.path, relative_path)
    task = self.config.name
    if task.startswith("gsm8k"):
        dataset_cls = GSM8KDataset
    elif task.startswith("sports"):
        dataset_cls = SportsDataset
    elif task.startswith("lastletter"):
        dataset_cls = LastLetterDataset
    elif task.startswith(("coinflip", "reverse")):
        # These tasks use the generic dataset without extra preprocessing.
        dataset_cls = ChainOfThoughtDataset
    else:
        raise NotImplementedError
    return dataset_cls(data_path, self.config)
|
|
|
|
|
|
def save_prediction_to_file(self, file, predictions, data):
|
|
def save_prediction_to_file(self, file, predictions, data):
|
|
results = []
|
|
results = []
|
|
for output, item in zip(predictions, data):
|
|
for output, item in zip(predictions, data):
|
|
output = self.tokenizer.detokenize(output)
|
|
output = self.tokenizer.detokenize(output)
|
|
- prediction = extract_answer(output, self.config.name)
|
|
|
|
|
|
+ prediction = extract_answer(output, self.config.name, self.config.chain_of_thought)
|
|
target = self.tokenizer.detokenize(item["targets"])
|
|
target = self.tokenizer.detokenize(item["targets"])
|
|
results.append({"output": output, "prediction": prediction, "answer": target})
|
|
results.append({"output": output, "prediction": prediction, "answer": target})
|
|
file_name = file.split(".")[0]
|
|
file_name = file.split(".")[0]
|