# task.py — chain-of-thought evaluation task definitions.
import json
import os
import re
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Callable, Dict, List, Union

from evaluation.tasks import GenerationTask, GenerationTaskConfig, GenerationTaskDataset
from evaluation.utils import print_rank_0
@dataclass
class ChainOfThoughtConfig(GenerationTaskConfig):
    """Configuration for chain-of-thought generation tasks."""

    # Path to the few-shot prompt file of alternating "Q:" / "A:" lines
    # (parsed by read_examples). Defaults to None despite the `str`
    # annotation, so the field is effectively optional.
    prompt_path: str = None
    # When False, reasoning chains are stripped from few-shot answers
    # (build_prompt keeps only the extracted final answer).
    chain_of_thought: bool = True
    # One of None, "number" (examples numbered "1. Question: ...") or
    # "return" (examples joined with a " <n>" separator).
    prompt_type: str = None
  14. def read_examples(prompt_path):
  15. examples = []
  16. item = {"question": None, "answer": None}
  17. with open(prompt_path) as file:
  18. for line in file:
  19. line = line.strip()
  20. if line.startswith("Q:"):
  21. question = line[3:]
  22. item["question"] = question
  23. elif line.startswith("A:"):
  24. answer = line[3:]
  25. item["answer"] = answer
  26. examples.append(item)
  27. item = {"question": None, "answer": None}
  28. else:
  29. raise NotImplementedError
  30. return examples
  31. def build_prompt(examples, task_name, chain_of_thought=True, prompt_type=None):
  32. prompts = []
  33. for i, item in enumerate(examples):
  34. question, answer = item["question"], item["answer"]
  35. if not chain_of_thought:
  36. answer = extract_answer(answer, task_name)
  37. if prompt_type == "number":
  38. prompts.append(f"{i+1}. Question: {question} Answer: {answer}")
  39. else:
  40. prompts.append(f"Question: {question} Answer: {answer}")
  41. if prompt_type == "return":
  42. prompt = " <n>".join(prompts)
  43. else:
  44. prompt = " ".join(prompts)
  45. return prompt
  46. def extract_answer(prediction, task_name, chain_of_thought=True):
  47. if task_name.startswith("gsm8k"):
  48. prediction = prediction.lower()
  49. if chain_of_thought:
  50. pattern = r'(?<=the answer is )\d+'
  51. else:
  52. pattern = r'\d+'
  53. match = re.search(pattern, prediction)
  54. if match:
  55. answer = match.group(0)
  56. else:
  57. answer = ""
  58. elif task_name.startswith("sports"):
  59. prediction = prediction.lower()
  60. if chain_of_thought:
  61. pattern = r'(?<=the answer is )(yes|no)'
  62. else:
  63. pattern = r'yes|no'
  64. match = re.search(pattern, prediction)
  65. if match:
  66. answer = match.group(0)
  67. else:
  68. answer = "no"
  69. else:
  70. raise NotImplementedError(task_name)
  71. return answer
class ChainOfThoughtDataset(GenerationTaskDataset):
    """Generation dataset that prepends a fixed few-shot CoT prompt to every item."""

    config: ChainOfThoughtConfig

    def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
        """Build the shared few-shot prompt, then defer to the base loader."""
        # Few-shot examples parsed from config.prompt_path ("Q:"/"A:" lines).
        self.labeled_examples = read_examples(config.prompt_path)
        # One flat prompt string shared by every item in the dataset.
        self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought,
                                           prompt_type=config.prompt_type)
        print_rank_0(self.labeled_prompt)
        # Caps the debug printing in process_single_item at the first 3 items.
        self.printed_count = 0
        # Base class presumably reads the data file(s) and calls
        # process_single_item — confirm against GenerationTaskDataset.
        super().__init__(path, config)

    def process_single_item(self, item, **kwargs):
        """Append the item's question to the few-shot prompt and tokenize.

        Returns a single-element list holding the tokenized prompt and
        targets (plus any extra kwargs passed through).
        """
        question, targets = item["question"], item["targets"]
        # Mirror the example formatting chosen by config.prompt_type so the
        # final question looks like the labeled examples before it.
        if self.config.prompt_type == "number":
            text = self.labeled_prompt + f" {len(self.labeled_examples) + 1}. Question: {question} Answer:"
        elif self.config.prompt_type == "return":
            text = self.labeled_prompt + f" <n>Question: {question} Answer:"
        else:
            text = self.labeled_prompt + f" Question: {question} Answer:"
        text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
        # Left-truncate so prompt + generation budget + 2 (special tokens,
        # presumably — confirm) fit in max_seq_length; keeping the tail
        # preserves the actual question nearest the generation point.
        if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
            text_length = self.config.max_seq_length - self.config.max_gen_length - 2
            text = text[len(text) - text_length: len(text)]
        if self.printed_count < 3:
            print_rank_0(self.tokenizer.detokenize(text))
            self.printed_count += 1
        return [{"text": text, "targets": targets, **kwargs}]
  97. class GSM8KDataset(ChainOfThoughtDataset):
  98. def process_single_item(self, item, **kwargs):
  99. item["targets"] = item["answer"].split("####")[1].strip()
  100. return super().process_single_item(item)
  101. class SportsDataset(ChainOfThoughtDataset):
  102. def process_single_file(self, path):
  103. with open(path) as file:
  104. dataset = json.load(file)
  105. for item in dataset["examples"]:
  106. sentence = item["input"]
  107. item["question"] = f'Is the following sentence plausible? \"{sentence}.\"'
  108. if item["target_scores"]["plausible"] == 1:
  109. item["targets"] = "yes"
  110. else:
  111. item["targets"] = "no"
  112. self.data.extend(self.process_single_item(item))
  113. class ChainOfThoughtTask(GenerationTask):
  114. config: ChainOfThoughtConfig
  115. @classmethod
  116. def config_class(cls):
  117. return ChainOfThoughtConfig
  118. @property
  119. def metrics(self) -> Dict[str, Callable]:
  120. return {'acuracy': self.extracted_accuracy_metric}
  121. def extracted_accuracy_metric(self, predictions, examples):
  122. count = 0
  123. num_predictions = max(len(predictions), 1)
  124. assert len(predictions) == len(examples)
  125. for prediction, example in zip(predictions, examples):
  126. output = self.tokenizer.detokenize(prediction)
  127. prediction = extract_answer(output, self.config.name, self.config.chain_of_thought).strip()
  128. target = self.tokenizer.detokenize(example["targets"]).strip()
  129. count += prediction == target
  130. return count * 100.0 / num_predictions
  131. def build_dataset(self, relative_path, split):
  132. if self.config.name.startswith("gsm8k"):
  133. return GSM8KDataset(os.path.join(self.config.path, relative_path), self.config)
  134. elif self.config.name.startswith("sports"):
  135. return SportsDataset(os.path.join(self.config.path, relative_path), self.config)
  136. else:
  137. raise NotImplementedError
  138. def save_prediction_to_file(self, file, predictions, data):
  139. results = []
  140. for output, item in zip(predictions, data):
  141. output = self.tokenizer.detokenize(output)
  142. prediction = extract_answer(output, self.config.name, self.config.chain_of_thought)
  143. target = self.tokenizer.detokenize(item["targets"])
  144. results.append({"output": output, "prediction": prediction, "answer": target})
  145. file_name = file.split(".")[0]
  146. if not os.path.exists("outputs"):
  147. os.mkdir("outputs")
  148. with open("outputs/" + self.config.name + "_" + datetime.now().strftime(
  149. '%m-%d-%H-%M_') + file_name + ".json", "w") as output:
  150. for result in results:
  151. output.write(json.dumps(result) + "\n")