|
@@ -55,9 +55,9 @@ def extract_answer(prediction, task_name, chain_of_thought=True):
|
|
|
if task_name.startswith("gsm8k"):
|
|
|
prediction = prediction.lower()
|
|
|
if chain_of_thought:
|
|
|
- pattern = r'(?<=the answer is )\d+'
|
|
|
+ pattern = r"(?<=the answer is )\d+"
|
|
|
else:
|
|
|
- pattern = r'\d+'
|
|
|
+ pattern = r"\d+"
|
|
|
match = re.search(pattern, prediction)
|
|
|
if match:
|
|
|
answer = match.group(0)
|
|
@@ -66,9 +66,9 @@ def extract_answer(prediction, task_name, chain_of_thought=True):
|
|
|
elif task_name.startswith("sports") or task_name.startswith("coinflip"):
|
|
|
prediction = prediction.lower()
|
|
|
if chain_of_thought:
|
|
|
- pattern = r'(?<=the answer is )(yes|no)'
|
|
|
+ pattern = r"(?<=the answer is )(yes|no)"
|
|
|
else:
|
|
|
- pattern = r'yes|no'
|
|
|
+ pattern = r"yes|no"
|
|
|
match = re.search(pattern, prediction)
|
|
|
if match:
|
|
|
answer = match.group(0)
|
|
@@ -77,9 +77,9 @@ def extract_answer(prediction, task_name, chain_of_thought=True):
|
|
|
elif task_name.startswith("lastletter"):
|
|
|
prediction = prediction.lower()
|
|
|
if chain_of_thought:
|
|
|
- pattern = r'(?<=the answer is )[a-z]+'
|
|
|
+ pattern = r"(?<=the answer is )[a-z]+"
|
|
|
else:
|
|
|
- pattern = r'[a-z]+'
|
|
|
+ pattern = r"[a-z]+"
|
|
|
match = re.search(pattern, prediction)
|
|
|
if match:
|
|
|
answer = match.group(0)
|
|
@@ -90,7 +90,19 @@ def extract_answer(prediction, task_name, chain_of_thought=True):
|
|
|
if chain_of_thought:
|
|
|
pattern = r'(?<=the answer is ")[a-z|,| ]+'
|
|
|
else:
|
|
|
- pattern = r'[a-z|,| ]+'
|
|
|
+ pattern = r"[a-z|,| ]+"
|
|
|
+ match = re.search(pattern, prediction)
|
|
|
+ if match:
|
|
|
+ answer = match.group(0)
|
|
|
+ else:
|
|
|
+ answer = ""
|
|
|
+ elif task_name.startswith("date"):
|
|
|
+ prediction = prediction.lower()
|
|
|
+ date_regex = r"(((0[0-9])|(1[012]))\/((0[1-9])|([12][0-9])|(3[01]))\/((20[012]\d|19\d\d)|(1\d|2[0123])))"
|
|
|
+ if chain_of_thought:
|
|
|
+ pattern = r"(?<=the answer is )" + date_regex
|
|
|
+ else:
|
|
|
+ pattern = date_regex
|
|
|
match = re.search(pattern, prediction)
|
|
|
if match:
|
|
|
answer = match.group(0)
|
|
@@ -106,8 +118,9 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
|
|
|
|
|
|
def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
|
|
|
self.labeled_examples = read_examples(config.prompt_path)
|
|
|
- self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought,
|
|
|
- prompt_type=config.prompt_type)
|
|
|
+ self.labeled_prompt = build_prompt(
|
|
|
+ self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought, prompt_type=config.prompt_type
|
|
|
+ )
|
|
|
# print_rank_0(self.labeled_prompt)
|
|
|
self.printed_count = 0
|
|
|
super().__init__(path, config)
|
|
@@ -124,7 +137,7 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
|
|
|
text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
|
|
|
if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
|
|
|
text_length = self.config.max_seq_length - self.config.max_gen_length - 2
|
|
|
- text = text[len(text) - text_length: len(text)]
|
|
|
+ text = text[len(text) - text_length : len(text)]
|
|
|
# if self.printed_count < 3:
|
|
|
# print_rank_0(self.tokenizer.detokenize(text))
|
|
|
# self.printed_count += 1
|
|
@@ -143,7 +156,7 @@ class SportsDataset(ChainOfThoughtDataset):
|
|
|
dataset = json.load(file)
|
|
|
for item in dataset["examples"]:
|
|
|
sentence = item["input"]
|
|
|
- item["question"] = f'Is the following sentence plausible? \"{sentence}.\"'
|
|
|
+ item["question"] = f'Is the following sentence plausible? "{sentence}."'
|
|
|
if item["target_scores"]["plausible"] == 1:
|
|
|
item["targets"] = "yes"
|
|
|
else:
|
|
@@ -151,6 +164,19 @@ class SportsDataset(ChainOfThoughtDataset):
|
|
|
self.data.extend(self.process_single_item(item))
|
|
|
|
|
|
|
|
|
class DateDataset(ChainOfThoughtDataset):
    """Chain-of-thought dataset whose examples carry per-label target scores."""

    def process_single_file(self, path):
        """Load one JSON file and append its processed examples to ``self.data``.

        The file must contain a top-level ``"examples"`` list. For every
        example, the ``"input"`` text becomes ``"question"`` and the label
        whose entry in ``"target_scores"`` equals 1 becomes ``"targets"``.
        The example is then handed to ``self.process_single_item`` and the
        returned items are appended to ``self.data``.

        Args:
            path: Location of the JSON file to read.
        """
        # JSON is read as UTF-8 explicitly rather than relying on the locale
        # default, and the handle is released before the examples are processed.
        with open(path, encoding="utf-8") as file:
            dataset = json.load(file)

        for item in dataset["examples"]:
            item["question"] = item["input"]
            correct_labels = [label for label, score in item["target_scores"].items() if score == 1]
            if correct_labels:
                # When several labels are scored 1 the last one wins, matching
                # the overwrite order of a plain loop over the scores. When no
                # label is scored 1, "targets" is left untouched.
                item["targets"] = correct_labels[-1]
            self.data.extend(self.process_single_item(item))
|
|
|
+
|
|
|
+
|
|
|
class LastLetterDataset(ChainOfThoughtDataset):
|
|
|
def process_single_item(self, item, **kwargs):
|
|
|
first_name, last_name = item["first_name"], item["last_name"]
|
|
@@ -168,7 +194,7 @@ class ChainOfThoughtTask(GenerationTask):
|
|
|
|
|
|
@property
def metrics(self) -> Dict[str, Callable]:
    """Return the mapping from reported metric name to the callable computing it.

    The only entry is ``"Accuracy"``, bound to ``self.extracted_accuracy_metric``.
    """
    return {"Accuracy": self.extracted_accuracy_metric}
|
|
|
|
|
|
def extracted_accuracy_metric(self, predictions, examples):
|
|
|
count = 0
|
|
@@ -190,6 +216,8 @@ class ChainOfThoughtTask(GenerationTask):
|
|
|
return LastLetterDataset(os.path.join(self.config.path, relative_path), self.config)
|
|
|
elif self.config.name.startswith("coinflip") or self.config.name.startswith("reverse"):
|
|
|
return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)
|
|
|
+ elif self.config.name.startswith("date"):
|
|
|
+ return DateDataset(os.path.join(self.config.path, relative_path), self.config)
|
|
|
else:
|
|
|
raise NotImplementedError
|
|
|
|
|
@@ -203,7 +231,8 @@ class ChainOfThoughtTask(GenerationTask):
|
|
|
file_name = file.split(".")[0]
|
|
|
if not os.path.exists("outputs"):
|
|
|
os.mkdir("outputs")
|
|
|
- with open("outputs/" + self.config.name + "_" + datetime.now().strftime(
|
|
|
- '%m-%d-%H-%M_') + file_name + ".json", "w") as output:
|
|
|
+ with open(
|
|
|
+ "outputs/" + self.config.name + "_" + datetime.now().strftime("%m-%d-%H-%M_") + file_name + ".json", "w"
|
|
|
+ ) as output:
|
|
|
for result in results:
|
|
|
output.write(json.dumps(result) + "\n")
|