|
@@ -74,6 +74,17 @@ def extract_answer(prediction, task_name, chain_of_thought=True):
|
|
|
answer = match.group(0)
|
|
|
else:
|
|
|
answer = "no"
|
|
|
+ elif task_name.startswith("lastletter"):
|
|
|
+ prediction = prediction.lower()
|
|
|
+ if chain_of_thought:
|
|
|
+ pattern = r'(?<=the answer is )[a-z]+'
|
|
|
+ else:
|
|
|
+ pattern = r'[a-z]+'
|
|
|
+ match = re.search(pattern, prediction)
|
|
|
+ if match:
|
|
|
+ answer = match.group(0)
|
|
|
+ else:
|
|
|
+ answer = ""
|
|
|
else:
|
|
|
raise NotImplementedError(task_name)
|
|
|
return answer
|
|
@@ -111,7 +122,7 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
|
|
|
class GSM8KDataset(ChainOfThoughtDataset):
    """Dataset variant whose target is the text after the ``####`` marker."""

    def process_single_item(self, item, **kwargs):
        """Fill ``item["targets"]`` from ``item["answer"]`` and delegate upward.

        The target is the second ``####``-separated segment of
        ``item["answer"]`` with surrounding whitespace stripped. ``item`` is
        modified in place.

        Args:
            item: Mapping providing an ``"answer"`` string.
            **kwargs: Forwarded unchanged to the parent implementation.

        Returns:
            Whatever the parent ``process_single_item`` returns.

        Raises:
            IndexError: If ``item["answer"]`` contains no ``####`` marker.
        """
        item["targets"] = item["answer"].split("####")[1].strip()
        # Forward **kwargs: calling the parent with ``item`` alone would
        # silently discard every option the caller supplied.
        return super().process_single_item(item, **kwargs)
|
|
|
|
|
|
|
|
|
class SportsDataset(ChainOfThoughtDataset):
|
|
@@ -128,6 +139,14 @@ class SportsDataset(ChainOfThoughtDataset):
|
|
|
self.data.extend(self.process_single_item(item))
|
|
|
|
|
|
|
|
|
class LastLetterDataset(ChainOfThoughtDataset):
    """Dataset variant that builds its question from a first and last name."""

    def process_single_item(self, item, **kwargs):
        """Write the last-letter prompt into ``item["question"]`` and delegate.

        Reads ``item["first_name"]`` and ``item["last_name"]``, joins them with
        a single space, and stores the resulting instruction text under
        ``item["question"]`` (``item`` is modified in place).

        Args:
            item: Mapping providing ``"first_name"`` and ``"last_name"``.
            **kwargs: Forwarded unchanged to the parent implementation.

        Returns:
            Whatever the parent ``process_single_item`` returns.
        """
        full_name = f'{item["first_name"]} {item["last_name"]}'
        item["question"] = (
            f'Take the last letters of the words in "{full_name}" and concatenate them.'
        )
        return super().process_single_item(item, **kwargs)
|
|
|
+
|
|
|
+
|
|
|
class ChainOfThoughtTask(GenerationTask):
|
|
|
config: ChainOfThoughtConfig
|
|
|
|
|
@@ -155,6 +174,8 @@ class ChainOfThoughtTask(GenerationTask):
|
|
|
return GSM8KDataset(os.path.join(self.config.path, relative_path), self.config)
|
|
|
elif self.config.name.startswith("sports"):
|
|
|
return SportsDataset(os.path.join(self.config.path, relative_path), self.config)
|
|
|
+ elif self.config.name.startswith("lastletter"):
|
|
|
+ return LastLetterDataset(os.path.join(self.config.path, relative_path), self.config)
|
|
|
else:
|
|
|
raise NotImplementedError
|
|
|
|