task.py
"""Chain-of-thought evaluation tasks (GSM8K and sports-plausibility QA)."""
import os
import json
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Union, List, Dict, Callable

from evaluation.tasks import GenerationTask, GenerationTaskDataset, GenerationTaskConfig
from evaluation.utils import print_rank_0


@dataclass
class ChainOfThoughtConfig(GenerationTaskConfig):
    # Path to the few-shot prompt file with labeled question/answer examples.
    prompt_path: str = None
    # When False, the reasoning chain is stripped from the few-shot answers.
    chain_of_thought: bool = True


def read_examples(prompt_path):
    """Read labeled few-shot examples from a prompt file of "Q: ..." / "A: ..." lines."""
    examples = []
    item = {"question": None, "answer": None}
    with open(prompt_path) as file:
        for line in file:
            line = line.strip()
            if line.startswith("Q:"):
                question = line[3:]
                item["question"] = question
            elif line.startswith("A:"):
                answer = line[3:]
                item["answer"] = answer
                examples.append(item)
                item = {"question": None, "answer": None}
            else:
                # Any line that is neither a question nor an answer is unexpected.
                raise NotImplementedError
    return examples
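

# A minimal sketch of the prompt-file format that read_examples expects, inferred from
# the parser above; the concrete file behind config.prompt_path is an assumption here:
#
#   Q: Tom has 3 apples and buys 2 more. How many apples does he have?
#   A: Tom starts with 3 apples and buys 2 more, so he has 3 + 2 = 5. The answer is 5.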


def build_prompt(examples, task_name, chain_of_thought=True):
    """Concatenate the labeled examples into a single few-shot prompt string."""
    prompts = []
    for item in examples:
        question, answer = item["question"], item["answer"]
        if not chain_of_thought:
            # Without chain of thought, keep only the final extracted answer.
            answer = extract_answer(answer, task_name)
        prompts.append(f"Question: {question} Answer: {answer}")
    prompt = " ".join(prompts)
    return prompt
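

# For a prompt file like the sketch above, build_prompt yields one flat string of the
# form "Question: ... Answer: ... Question: ... Answer: ...", with all labeled examples
# joined by single spaces; with chain_of_thought=False each answer collapses to just
# the extracted final answer (e.g. "5").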


def extract_answer(prediction, task_name, chain_of_thought=True):
    """Pull the final answer out of a model prediction for the given task."""
    if task_name.startswith("gsm8k"):
        prediction = prediction.lower()
        if chain_of_thought:
            # Chain-of-thought outputs are expected to end with "the answer is <number>".
            pattern = r'(?<=the answer is )\d+'
        else:
            pattern = r'\d+'
        match = re.search(pattern, prediction)
        if match:
            answer = match.group(0)
        else:
            answer = ""
    elif task_name.startswith("sports"):
        prediction = prediction.lower()
        if chain_of_thought:
            pattern = r'(?<=the answer is )(yes|no)'
        else:
            pattern = r'yes|no'
        match = re.search(pattern, prediction)
        if match:
            answer = match.group(0)
        else:
            # Fall back to "no" when no explicit yes/no is found.
            answer = "no"
    else:
        raise NotImplementedError(task_name)
    return answer
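

# Illustrative behaviour of extract_answer; the sample completions below are invented
# for illustration and are not repository data:
#
#   extract_answer("... so 3 + 2 = 5. The answer is 5.", "gsm8k")                 -> "5"
#   extract_answer("Scoring a goal is part of soccer. The answer is yes.", "sports") -> "yes"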


class ChainOfThoughtDataset(GenerationTaskDataset):
    def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
        self.labeled_examples = read_examples(config.prompt_path)
        self.labeled_prompt = build_prompt(
            self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought
        )
        print_rank_0(self.labeled_prompt)
        self.printed_count = 0
        super().__init__(path, config)

    def process_single_item(self, item, **kwargs):
        question, targets = item["question"], item["targets"]
        text = self.labeled_prompt + f" Question: {question} Answer:"
        text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
        # Left-truncate the prompt so the generation budget (plus a 2-token margin)
        # still fits within max_seq_length.
        if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
            text_length = self.config.max_seq_length - self.config.max_gen_length - 2
            text = text[len(text) - text_length:]
        # Print the first few processed prompts for inspection.
        if self.printed_count < 3:
            print_rank_0(self.tokenizer.detokenize(text))
            self.printed_count += 1
        return [{"text": text, "targets": targets, **kwargs}]


class GSM8KDataset(ChainOfThoughtDataset):
    def process_single_item(self, item, **kwargs):
        # GSM8K answers end with "#### <final number>"; keep only that number as the target.
        item["targets"] = item["answer"].split("####")[1].strip()
        return super().process_single_item(item, **kwargs)


class SportsDataset(ChainOfThoughtDataset):
    def process_single_file(self, path):
        with open(path) as file:
            dataset = json.load(file)
        for item in dataset["examples"]:
            sentence = item["input"]
            item["question"] = f'Is the following sentence plausible? "{sentence}."'
            # Map the plausibility score to a yes/no target.
            item["targets"] = "yes" if item["target_scores"]["plausible"] == 1 else "no"
            self.data.extend(self.process_single_item(item))


class ChainOfThoughtTask(GenerationTask):
    config: ChainOfThoughtConfig

    @classmethod
    def config_class(cls):
        return ChainOfThoughtConfig

    @property
    def metrics(self) -> Dict[str, Callable]:
        return {"accuracy": self.extracted_accuracy_metric}

    def extracted_accuracy_metric(self, predictions, examples):
        count = 0
        num_predictions = max(len(predictions), 1)
        assert len(predictions) == len(examples)
        for prediction, example in zip(predictions, examples):
            output = self.tokenizer.detokenize(prediction)
            # Compare the extracted answer against the detokenized target.
            prediction = extract_answer(output, self.config.name, self.config.chain_of_thought).strip()
            target = self.tokenizer.detokenize(example["targets"]).strip()
            count += prediction == target
        return count * 100.0 / num_predictions

    def build_dataset(self, relative_path, split):
        if self.config.name.startswith("gsm8k"):
            return GSM8KDataset(os.path.join(self.config.path, relative_path), self.config)
        elif self.config.name.startswith("sports"):
            return SportsDataset(os.path.join(self.config.path, relative_path), self.config)
        else:
            raise NotImplementedError

    def save_prediction_to_file(self, file, predictions, data):
        results = []
        for output, item in zip(predictions, data):
            output = self.tokenizer.detokenize(output)
            prediction = extract_answer(output, self.config.name, self.config.chain_of_thought)
            target = self.tokenizer.detokenize(item["targets"])
            results.append({"output": output, "prediction": prediction, "answer": target})
        file_name = file.split(".")[0]
        if not os.path.exists("outputs"):
            os.mkdir("outputs")
        # Write one JSON record per line to outputs/<task>_<timestamp>_<file>.json.
        output_path = "outputs/" + self.config.name + "_" + datetime.now().strftime("%m-%d-%H-%M_") + file_name + ".json"
        with open(output_path, "w") as output:
            for result in results:
                output.write(json.dumps(result) + "\n")
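

if __name__ == "__main__":
    # A minimal sketch of how the prompt helpers above fit together, using a temporary
    # prompt file; the example question/answer is invented for illustration only.
    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
        tmp.write("Q: Tom has 3 apples and buys 2 more. How many apples does he have?\n")
        tmp.write("A: He starts with 3 and buys 2 more, so 3 + 2 = 5. The answer is 5.\n")
        prompt_path = tmp.name

    examples = read_examples(prompt_path)
    print(build_prompt(examples, "gsm8k", chain_of_thought=True))   # full reasoning chain
    print(build_prompt(examples, "gsm8k", chain_of_thought=False))  # answer only ("5")
    os.unlink(prompt_path)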