task.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. import os
  2. import json
  3. import re
  4. from typing import Union, List, Dict, Callable
  5. from datetime import datetime
  6. from evaluation.tasks import GenerationTask, GenerationTaskDataset, GenerationTaskConfig
  7. from evaluation.utils import print_rank_0
  8. from dataclasses import dataclass
  9. @dataclass
  10. class ChainOfThoughtConfig(GenerationTaskConfig):
  11. prompt_path: str = None
  12. chain_of_thought: bool = True
  13. prompt_type: str = None
  14. def read_examples(prompt_path):
  15. examples = []
  16. item = {"question": None, "answer": None}
  17. with open(prompt_path) as file:
  18. for line in file:
  19. line = line.strip()
  20. if line.startswith("Q:"):
  21. question = line[3:]
  22. item["question"] = question
  23. elif line.startswith("A:"):
  24. answer = line[3:]
  25. item["answer"] = answer
  26. examples.append(item)
  27. item = {"question": None, "answer": None}
  28. else:
  29. raise NotImplementedError
  30. return examples
  31. def build_prompt(examples, task_name, chain_of_thought=True, prompt_type=None):
  32. prompts = []
  33. for i, item in enumerate(examples):
  34. question, answer = item["question"], item["answer"]
  35. if not chain_of_thought:
  36. answer = extract_answer(answer, task_name)
  37. if prompt_type == "number":
  38. prompts.append(f"{i+1}. Question: {question} Answer: {answer}")
  39. else:
  40. prompts.append(f"Question: {question} Answer: {answer}")
  41. if prompt_type == "return":
  42. prompt = " <n>".join(prompts)
  43. else:
  44. prompt = " ".join(prompts)
  45. return prompt
  46. def extract_answer(prediction, task_name, chain_of_thought=True):
  47. if task_name.startswith("gsm8k"):
  48. prediction = prediction.lower()
  49. if chain_of_thought:
  50. pattern = r'(?<=the answer is )\d+'
  51. else:
  52. pattern = r'\d+'
  53. match = re.search(pattern, prediction)
  54. if match:
  55. answer = match.group(0)
  56. else:
  57. answer = ""
  58. elif task_name.startswith("sports") or task_name.startswith("coinflip"):
  59. prediction = prediction.lower()
  60. if chain_of_thought:
  61. pattern = r'(?<=the answer is )(yes|no)'
  62. else:
  63. pattern = r'yes|no'
  64. match = re.search(pattern, prediction)
  65. if match:
  66. answer = match.group(0)
  67. else:
  68. answer = "no"
  69. elif task_name.startswith("lastletter"):
  70. prediction = prediction.lower()
  71. if chain_of_thought:
  72. pattern = r'(?<=the answer is )[a-z]+'
  73. else:
  74. pattern = r'[a-z]+'
  75. match = re.search(pattern, prediction)
  76. if match:
  77. answer = match.group(0)
  78. else:
  79. answer = ""
  80. elif task_name.startswith("reverse"):
  81. prediction = prediction.lower()
  82. if chain_of_thought:
  83. pattern = r'(?<=the answer is ")[a-z|,| ]+'
  84. else:
  85. pattern = r'[a-z|,| ]+'
  86. match = re.search(pattern, prediction)
  87. if match:
  88. answer = match.group(0)
  89. else:
  90. answer = ""
  91. else:
  92. raise NotImplementedError(task_name)
  93. return answer
  94. class ChainOfThoughtDataset(GenerationTaskDataset):
  95. config: ChainOfThoughtConfig
  96. def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
  97. self.labeled_examples = read_examples(config.prompt_path)
  98. self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought,
  99. prompt_type=config.prompt_type)
  100. # print_rank_0(self.labeled_prompt)
  101. self.printed_count = 0
  102. super().__init__(path, config)
  103. # print_rank_0(len(self.tokenizer.tokenize(self.labeled_prompt)))
  104. def process_single_item(self, item, **kwargs):
  105. question, targets = item["question"], item["targets"]
  106. if self.config.prompt_type == "number":
  107. text = self.labeled_prompt + f" {len(self.labeled_examples) + 1}. Question: {question} Answer:"
  108. elif self.config.prompt_type == "return":
  109. text = self.labeled_prompt + f" <n>Question: {question} Answer:"
  110. else:
  111. text = self.labeled_prompt + f" Question: {question} Answer:"
  112. text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
  113. if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
  114. text_length = self.config.max_seq_length - self.config.max_gen_length - 2
  115. text = text[len(text) - text_length: len(text)]
  116. # if self.printed_count < 3:
  117. # print_rank_0(self.tokenizer.detokenize(text))
  118. # self.printed_count += 1
  119. return [{"text": text, "targets": targets, **kwargs}]
  120. class GSM8KDataset(ChainOfThoughtDataset):
  121. def process_single_item(self, item, **kwargs):
  122. item["targets"] = item["answer"].split("####")[1].strip()
  123. return super().process_single_item(item, **kwargs)
  124. class SportsDataset(ChainOfThoughtDataset):
  125. def process_single_file(self, path):
  126. with open(path) as file:
  127. dataset = json.load(file)
  128. for item in dataset["examples"]:
  129. sentence = item["input"]
  130. item["question"] = f'Is the following sentence plausible? \"{sentence}.\"'
  131. if item["target_scores"]["plausible"] == 1:
  132. item["targets"] = "yes"
  133. else:
  134. item["targets"] = "no"
  135. self.data.extend(self.process_single_item(item))
  136. class LastLetterDataset(ChainOfThoughtDataset):
  137. def process_single_item(self, item, **kwargs):
  138. first_name, last_name = item["first_name"], item["last_name"]
  139. question = f'Take the last letters of the words in "{first_name} {last_name}" and concatenate them.'
  140. item["question"] = question
  141. return super().process_single_item(item, **kwargs)
  142. class ChainOfThoughtTask(GenerationTask):
  143. config: ChainOfThoughtConfig
  144. @classmethod
  145. def config_class(cls):
  146. return ChainOfThoughtConfig
  147. @property
  148. def metrics(self) -> Dict[str, Callable]:
  149. return {'acuracy': self.extracted_accuracy_metric}
  150. def extracted_accuracy_metric(self, predictions, examples):
  151. count = 0
  152. num_predictions = max(len(predictions), 1)
  153. assert len(predictions) == len(examples)
  154. for prediction, example in zip(predictions, examples):
  155. output = self.tokenizer.detokenize(prediction)
  156. prediction = extract_answer(output, self.config.name, self.config.chain_of_thought).strip()
  157. target = self.tokenizer.detokenize(example["targets"]).strip()
  158. count += prediction == target
  159. return count * 100.0 / num_predictions
  160. def build_dataset(self, relative_path):
  161. if self.config.name.startswith("gsm8k"):
  162. return GSM8KDataset(os.path.join(self.config.path, relative_path), self.config)
  163. elif self.config.name.startswith("sports"):
  164. return SportsDataset(os.path.join(self.config.path, relative_path), self.config)
  165. elif self.config.name.startswith("lastletter"):
  166. return LastLetterDataset(os.path.join(self.config.path, relative_path), self.config)
  167. elif self.config.name.startswith("coinflip") or self.config.name.startswith("reverse"):
  168. return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)
  169. else:
  170. raise NotImplementedError
  171. def save_prediction_to_file(self, file, predictions, data):
  172. results = []
  173. for output, item in zip(predictions, data):
  174. output = self.tokenizer.detokenize(output)
  175. prediction = extract_answer(output, self.config.name, self.config.chain_of_thought)
  176. target = self.tokenizer.detokenize(item["targets"])
  177. results.append({"output": output, "prediction": prediction, "answer": target})
  178. file_name = file.split(".")[0]
  179. if not os.path.exists("outputs"):
  180. os.mkdir("outputs")
  181. with open("outputs/" + self.config.name + "_" + datetime.now().strftime(
  182. '%m-%d-%H-%M_') + file_name + ".json", "w") as output:
  183. for result in results:
  184. output.write(json.dumps(result) + "\n")