# task.py
  1. import os
  2. import json
  3. import re
  4. from typing import Union, List, Dict, Callable
  5. from datetime import datetime
  6. from evaluation.tasks import GenerationTask, GenerationTaskDataset, GenerationTaskConfig
  7. from evaluation.utils import print_rank_0
  8. from dataclasses import dataclass
@dataclass
class ChainOfThoughtConfig(GenerationTaskConfig):
    # Path to the few-shot prompt file consumed by read_examples().
    prompt_path: Union[str, None] = None
    # When False, labeled answers are reduced to the bare final answer
    # (via extract_answer) instead of the full reasoning chain.
    chain_of_thought: bool = True
    # Prompt formatting mode: None (plain), "number" (numbered examples),
    # or "return" (examples joined with a " <n>" separator).
    prompt_type: Union[str, None] = None
  14. def read_examples(prompt_path):
  15. examples = []
  16. item = {"question": None, "answer": None}
  17. with open(prompt_path) as file:
  18. for line in file:
  19. line = line.strip()
  20. if line.startswith("Q:"):
  21. question = line[3:]
  22. item["question"] = question
  23. elif line.startswith("A:"):
  24. answer = line[3:]
  25. item["answer"] = answer
  26. examples.append(item)
  27. item = {"question": None, "answer": None}
  28. else:
  29. raise NotImplementedError
  30. return examples
  31. def build_prompt(examples, task_name, chain_of_thought=True, prompt_type=None):
  32. prompts = []
  33. for i, item in enumerate(examples):
  34. question, answer = item["question"], item["answer"]
  35. if not chain_of_thought:
  36. answer = extract_answer(answer, task_name)
  37. if prompt_type == "number":
  38. prompts.append(f"{i+1}. Question: {question} Answer: {answer}")
  39. else:
  40. prompts.append(f"Question: {question} Answer: {answer}")
  41. if prompt_type == "return":
  42. prompt = " <n>".join(prompts)
  43. else:
  44. prompt = " ".join(prompts)
  45. return prompt
  46. def extract_answer(prediction, task_name, chain_of_thought=True):
  47. if task_name.startswith("gsm8k"):
  48. prediction = prediction.lower()
  49. if chain_of_thought:
  50. pattern = r"(?<=the answer is )\d+"
  51. else:
  52. pattern = r"\d+"
  53. match = re.search(pattern, prediction)
  54. if match:
  55. answer = match.group(0)
  56. else:
  57. answer = ""
  58. elif task_name.startswith("sports") or task_name.startswith("coinflip"):
  59. prediction = prediction.lower()
  60. if chain_of_thought:
  61. pattern = r"(?<=the answer is )(yes|no)"
  62. else:
  63. pattern = r"yes|no"
  64. match = re.search(pattern, prediction)
  65. if match:
  66. answer = match.group(0)
  67. else:
  68. answer = "no"
  69. elif task_name.startswith("lastletter"):
  70. prediction = prediction.lower()
  71. if chain_of_thought:
  72. pattern = r"(?<=the answer is )[a-z]+"
  73. else:
  74. pattern = r"[a-z]+"
  75. match = re.search(pattern, prediction)
  76. if match:
  77. answer = match.group(0)
  78. else:
  79. answer = ""
  80. elif task_name.startswith("reverse"):
  81. prediction = prediction.lower()
  82. if chain_of_thought:
  83. pattern = r'(?<=the answer is ")[a-z|,| ]+'
  84. else:
  85. pattern = r"[a-z|,| ]+"
  86. match = re.search(pattern, prediction)
  87. if match:
  88. answer = match.group(0)
  89. else:
  90. answer = ""
  91. elif task_name.startswith("date"):
  92. prediction = prediction.lower()
  93. date_regex = r"(((0[0-9])|(1[012]))\/((0[1-9])|([12][0-9])|(3[01]))\/((20[012]\d|19\d\d)|(1\d|2[0123])))"
  94. if chain_of_thought:
  95. pattern = r"(?<=the answer is )" + date_regex
  96. else:
  97. pattern = date_regex
  98. match = re.search(pattern, prediction)
  99. if match:
  100. answer = match.group(0)
  101. else:
  102. answer = ""
  103. else:
  104. raise NotImplementedError(task_name)
  105. return answer
  106. class ChainOfThoughtDataset(GenerationTaskDataset):
  107. config: ChainOfThoughtConfig
  108. def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
  109. self.labeled_examples = read_examples(config.prompt_path)
  110. self.labeled_prompt = build_prompt(
  111. self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought, prompt_type=config.prompt_type
  112. )
  113. # print_rank_0(self.labeled_prompt)
  114. self.printed_count = 0
  115. super().__init__(path, config)
  116. # print_rank_0(len(self.tokenizer.tokenize(self.labeled_prompt)))
  117. def process_single_item(self, item, **kwargs):
  118. question, targets = item["question"], item["targets"]
  119. if self.config.prompt_type == "number":
  120. text = self.labeled_prompt + f" {len(self.labeled_examples) + 1}. Question: {question} Answer:"
  121. elif self.config.prompt_type == "return":
  122. text = self.labeled_prompt + f" <n>Question: {question} Answer:"
  123. else:
  124. text = self.labeled_prompt + f" Question: {question} Answer:"
  125. text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
  126. if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
  127. text_length = self.config.max_seq_length - self.config.max_gen_length - 2
  128. text = text[len(text) - text_length : len(text)]
  129. # if self.printed_count < 3:
  130. # print_rank_0(self.tokenizer.detokenize(text))
  131. # self.printed_count += 1
  132. return [{"text": text, "targets": targets, **kwargs}]
  133. class GSM8KDataset(ChainOfThoughtDataset):
  134. def process_single_item(self, item, **kwargs):
  135. item["targets"] = item["answer"].split("####")[1].strip()
  136. return super().process_single_item(item, **kwargs)
  137. class SportsDataset(ChainOfThoughtDataset):
  138. def process_single_file(self, path):
  139. with open(path) as file:
  140. dataset = json.load(file)
  141. for item in dataset["examples"]:
  142. sentence = item["input"]
  143. item["question"] = f'Is the following sentence plausible? "{sentence}."'
  144. if item["target_scores"]["plausible"] == 1:
  145. item["targets"] = "yes"
  146. else:
  147. item["targets"] = "no"
  148. self.data.extend(self.process_single_item(item))
  149. class DateDataset(ChainOfThoughtDataset):
  150. def process_single_file(self, path):
  151. with open(path) as file:
  152. dataset = json.load(file)
  153. for item in dataset["examples"]:
  154. sentence = item["input"]
  155. item["question"] = sentence
  156. for key, value in item["target_scores"].items():
  157. if value == 1:
  158. item["targets"] = key
  159. self.data.extend(self.process_single_item(item))
  160. class LastLetterDataset(ChainOfThoughtDataset):
  161. def process_single_item(self, item, **kwargs):
  162. first_name, last_name = item["first_name"], item["last_name"]
  163. question = f'Take the last letters of the words in "{first_name} {last_name}" and concatenate them.'
  164. item["question"] = question
  165. return super().process_single_item(item, **kwargs)
  166. class ChainOfThoughtTask(GenerationTask):
  167. config: ChainOfThoughtConfig
  168. @classmethod
  169. def config_class(cls):
  170. return ChainOfThoughtConfig
  171. @property
  172. def metrics(self) -> Dict[str, Callable]:
  173. return {"Accuracy": self.extracted_accuracy_metric}
  174. def extracted_accuracy_metric(self, predictions, examples):
  175. count = 0
  176. num_predictions = max(len(predictions), 1)
  177. assert len(predictions) == len(examples)
  178. for prediction, example in zip(predictions, examples):
  179. output = self.tokenizer.detokenize(prediction)
  180. prediction = extract_answer(output, self.config.name, self.config.chain_of_thought).strip()
  181. target = self.tokenizer.detokenize(example["targets"]).strip()
  182. count += prediction == target
  183. return count * 100.0 / num_predictions
  184. def build_dataset(self, relative_path):
  185. if self.config.name.startswith("gsm8k"):
  186. return GSM8KDataset(os.path.join(self.config.path, relative_path), self.config)
  187. elif self.config.name.startswith("sports"):
  188. return SportsDataset(os.path.join(self.config.path, relative_path), self.config)
  189. elif self.config.name.startswith("lastletter"):
  190. return LastLetterDataset(os.path.join(self.config.path, relative_path), self.config)
  191. elif self.config.name.startswith("coinflip") or self.config.name.startswith("reverse"):
  192. return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)
  193. elif self.config.name.startswith("date"):
  194. return DateDataset(os.path.join(self.config.path, relative_path), self.config)
  195. else:
  196. raise NotImplementedError
  197. def save_prediction_to_file(self, file, predictions, data):
  198. results = []
  199. for output, item in zip(predictions, data):
  200. output = self.tokenizer.detokenize(output)
  201. prediction = extract_answer(output, self.config.name, self.config.chain_of_thought)
  202. target = self.tokenizer.detokenize(item["targets"])
  203. results.append({"output": output, "prediction": prediction, "answer": target})
  204. file_name = file.split(".")[0]
  205. if not os.path.exists("outputs"):
  206. os.mkdir("outputs")
  207. with open(
  208. "outputs/" + self.config.name + "_" + datetime.now().strftime("%m-%d-%H-%M_") + file_name + ".json", "w"
  209. ) as output:
  210. for result in results:
  211. output.write(json.dumps(result) + "\n")