2
0

task.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. import os
  2. import json
  3. import re
  4. from typing import Union, List, Dict, Callable
  5. from datetime import datetime
  6. from evaluation.model import ModelForEvaluation
  7. from evaluation.tasks import GenerationTask, GenerationTaskDataset, GenerationTaskConfig
  8. from evaluation.utils import print_rank_0
  9. from dataclasses import dataclass
@dataclass
class ChainOfThoughtConfig(GenerationTaskConfig):
    """Generation-task config extended with chain-of-thought prompting options."""

    # Path to the few-shot prompt file ("Q: ..." / "A: ..." lines).
    prompt_path: str = None
    # If False, few-shot answers are reduced to the bare final answer.
    chain_of_thought: bool = True
    # Prompt layout: "number" (numbered examples), "return" (" <n>" separator), or None.
    prompt_type: str = None
  15. def read_examples(prompt_path):
  16. examples = []
  17. item = {"question": None, "answer": None}
  18. with open(prompt_path) as file:
  19. for line in file:
  20. line = line.strip()
  21. if line.startswith("Q:"):
  22. question = line[3:]
  23. item["question"] = question
  24. elif line.startswith("A:"):
  25. answer = line[3:]
  26. item["answer"] = answer
  27. examples.append(item)
  28. item = {"question": None, "answer": None}
  29. else:
  30. raise NotImplementedError
  31. return examples
  32. def build_prompt(examples, task_name, chain_of_thought=True, prompt_type=None):
  33. prompts = []
  34. for i, item in enumerate(examples):
  35. question, answer = item["question"], item["answer"]
  36. if not chain_of_thought:
  37. answer = extract_answer(answer, task_name)
  38. if prompt_type == "number":
  39. prompts.append(f"{i+1}. Question: {question} Answer: {answer}")
  40. else:
  41. prompts.append(f"Question: {question} Answer: {answer}")
  42. if prompt_type == "return":
  43. prompt = " <n>".join(prompts)
  44. else:
  45. prompt = " ".join(prompts)
  46. return prompt
  47. def extract_answer(prediction, task_name, chain_of_thought=True):
  48. if task_name.startswith("gsm8k"):
  49. prediction = prediction.lower()
  50. if chain_of_thought:
  51. pattern = r"(?<=the answer is )\d+"
  52. else:
  53. pattern = r"\d+"
  54. match = re.search(pattern, prediction)
  55. if match:
  56. answer = match.group(0)
  57. else:
  58. answer = ""
  59. elif task_name.startswith("sports") or task_name.startswith("coinflip"):
  60. prediction = prediction.lower()
  61. if chain_of_thought:
  62. pattern = r"(?<=the answer is )(yes|no)"
  63. else:
  64. pattern = r"yes|no"
  65. match = re.search(pattern, prediction)
  66. if match:
  67. answer = match.group(0)
  68. else:
  69. answer = "no"
  70. elif task_name.startswith("lastletter"):
  71. prediction = prediction.lower()
  72. if chain_of_thought:
  73. pattern = r"(?<=the answer is )[a-z]+"
  74. else:
  75. pattern = r"[a-z]+"
  76. match = re.search(pattern, prediction)
  77. if match:
  78. answer = match.group(0)
  79. else:
  80. answer = ""
  81. elif task_name.startswith("reverse"):
  82. prediction = prediction.lower()
  83. if chain_of_thought:
  84. pattern = r'(?<=the answer is ")[a-z|,| ]+'
  85. else:
  86. pattern = r"[a-z|,| ]+"
  87. match = re.search(pattern, prediction)
  88. if match:
  89. answer = match.group(0)
  90. else:
  91. answer = ""
  92. elif task_name.startswith("date"):
  93. prediction = prediction.lower()
  94. date_regex = r"(((0[0-9])|(1[012]))\/((0[1-9])|([12][0-9])|(3[01]))\/((20[012]\d|19\d\d)|(1\d|2[0123])))"
  95. if chain_of_thought:
  96. pattern = r"(?<=the answer is )" + date_regex
  97. else:
  98. pattern = date_regex
  99. match = re.search(pattern, prediction)
  100. if match:
  101. answer = match.group(0)
  102. else:
  103. answer = ""
  104. else:
  105. raise NotImplementedError(task_name)
  106. return answer
  107. class ChainOfThoughtDataset(GenerationTaskDataset):
  108. config: ChainOfThoughtConfig
  109. def __init__(self, path: Union[str, List[str]], model: ModelForEvaluation, config: ChainOfThoughtConfig):
  110. self.labeled_examples = read_examples(config.prompt_path)
  111. self.labeled_prompt = build_prompt(
  112. self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought, prompt_type=config.prompt_type
  113. )
  114. # print_rank_0(self.labeled_prompt)
  115. self.printed_count = 0
  116. super().__init__(path, model, config)
  117. # print_rank_0(len(self.tokenizer.tokenize(self.labeled_prompt)))
  118. def process_single_item(self, item, **kwargs):
  119. question, targets = item["question"], item["targets"]
  120. if self.config.prompt_type == "number":
  121. text = self.labeled_prompt + f" {len(self.labeled_examples) + 1}. Question: {question} Answer:"
  122. elif self.config.prompt_type == "return":
  123. text = self.labeled_prompt + f" <n>Question: {question} Answer:"
  124. else:
  125. text = self.labeled_prompt + f" Question: {question} Answer:"
  126. text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
  127. if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
  128. text_length = self.config.max_seq_length - self.config.max_gen_length - 2
  129. text = text[len(text) - text_length : len(text)]
  130. # if self.printed_count < 3:
  131. # print_rank_0(self.tokenizer.detokenize(text))
  132. # self.printed_count += 1
  133. return [{"text": text, "targets": targets, **kwargs}]
  134. class GSM8KDataset(ChainOfThoughtDataset):
  135. def process_single_item(self, item, **kwargs):
  136. item["targets"] = item["answer"].split("####")[1].strip()
  137. return super().process_single_item(item, **kwargs)
  138. class SportsDataset(ChainOfThoughtDataset):
  139. def process_single_file(self, path):
  140. with open(path) as file:
  141. dataset = json.load(file)
  142. for item in dataset["examples"]:
  143. sentence = item["input"]
  144. item["question"] = f'Is the following sentence plausible? "{sentence}."'
  145. if item["target_scores"]["plausible"] == 1:
  146. item["targets"] = "yes"
  147. else:
  148. item["targets"] = "no"
  149. self.data.extend(self.process_single_item(item))
  150. class DateDataset(ChainOfThoughtDataset):
  151. def process_single_file(self, path):
  152. with open(path) as file:
  153. dataset = json.load(file)
  154. for item in dataset["examples"]:
  155. sentence = item["input"]
  156. item["question"] = sentence
  157. for key, value in item["target_scores"].items():
  158. if value == 1:
  159. item["targets"] = key
  160. self.data.extend(self.process_single_item(item))
  161. class LastLetterDataset(ChainOfThoughtDataset):
  162. def process_single_item(self, item, **kwargs):
  163. first_name, last_name = item["first_name"], item["last_name"]
  164. question = f'Take the last letters of the words in "{first_name} {last_name}" and concatenate them.'
  165. item["question"] = question
  166. return super().process_single_item(item, **kwargs)
  167. class ChainOfThoughtTask(GenerationTask):
  168. config: ChainOfThoughtConfig
  169. @classmethod
  170. def config_class(cls):
  171. return ChainOfThoughtConfig
  172. @property
  173. def metrics(self) -> Dict[str, Callable]:
  174. return {"Accuracy": self.extracted_accuracy_metric}
  175. def extracted_accuracy_metric(self, predictions, examples):
  176. count = 0
  177. num_predictions = max(len(predictions), 1)
  178. assert len(predictions) == len(examples)
  179. for prediction, example in zip(predictions, examples):
  180. output = self.tokenizer.detokenize(prediction)
  181. prediction = extract_answer(output, self.config.name, self.config.chain_of_thought).strip()
  182. target = self.tokenizer.detokenize(example["targets"]).strip()
  183. count += prediction == target
  184. return count * 100.0 / num_predictions
  185. def build_dataset(self, relative_path):
  186. if self.config.name.startswith("gsm8k"):
  187. return GSM8KDataset(os.path.join(self.config.path, relative_path), self.model, self.config)
  188. elif self.config.name.startswith("sports"):
  189. return SportsDataset(os.path.join(self.config.path, relative_path), self.model, self.config)
  190. elif self.config.name.startswith("lastletter"):
  191. return LastLetterDataset(os.path.join(self.config.path, relative_path), self.model, self.config)
  192. elif self.config.name.startswith("coinflip") or self.config.name.startswith("reverse"):
  193. return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.model, self.config)
  194. elif self.config.name.startswith("date"):
  195. return DateDataset(os.path.join(self.config.path, relative_path), self.model, self.config)
  196. else:
  197. raise NotImplementedError
  198. def save_prediction_to_file(self, file, predictions, data):
  199. results = []
  200. for output, item in zip(predictions, data):
  201. output = self.tokenizer.detokenize(output)
  202. prediction = extract_answer(output, self.config.name, self.config.chain_of_thought)
  203. target = self.tokenizer.detokenize(item["targets"])
  204. results.append({"output": output, "prediction": prediction, "answer": target})
  205. file_name = file.split(".")[0]
  206. if not os.path.exists("outputs"):
  207. os.mkdir("outputs")
  208. with open(
  209. "outputs/" + self.config.name + "_" + datetime.now().strftime("%m-%d-%H-%M_") + file_name + ".json", "w"
  210. ) as output:
  211. for result in results:
  212. output.write(json.dumps(result) + "\n")