"""Chain-of-thought evaluation task: few-shot prompting with extracted-answer accuracy."""
import json
import os
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, Dict, List, Optional, Union

from evaluation.tasks import GenerationTask, GenerationTaskConfig, GenerationTaskDataset
from evaluation.utils import print_rank_0
  9. @dataclass
  10. class ChainOfThoughtConfig(GenerationTaskConfig):
  11. prompt_path: str = None
  12. def read_examples(prompt_path):
  13. examples = []
  14. item = {"question": None, "answer": None}
  15. with open(prompt_path) as file:
  16. for line in file:
  17. line = line.strip()
  18. if line.startswith("Q:"):
  19. question = line[3:]
  20. item["question"] = question
  21. elif line.startswith("A:"):
  22. answer = line[3:]
  23. item["answer"] = answer
  24. examples.append(item)
  25. item = {"question": None, "answer": None}
  26. else:
  27. raise NotImplementedError
  28. return examples
  29. def build_prompt(examples):
  30. prompts = []
  31. for item in examples:
  32. question, answer = item["question"], item["answer"]
  33. prompts.append(f"Question: {question} Answer: {answer}")
  34. prompt = " ".join(prompts)
  35. return prompt
  36. def extract_answer(prediction, task_name):
  37. if task_name == "gsm8k":
  38. prediction = prediction.lower()
  39. match = re.search(r'(?<=the answer is )\d+', prediction)
  40. if match:
  41. answer = match.group(0)
  42. else:
  43. answer = ""
  44. else:
  45. raise NotImplementedError(task_name)
  46. return answer
  47. class ChainOfThoughtDataset(GenerationTaskDataset):
  48. def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
  49. self.labeled_examples = read_examples(config.prompt_path)
  50. self.labeled_prompt = build_prompt(self.labeled_examples)
  51. print_rank_0(self.labeled_prompt)
  52. self.printed_count = 0
  53. super().__init__(path, config)
  54. def process_single_item(self, item, **kwargs):
  55. question = item["question"]
  56. targets = item["answer"].split("####")[1].strip()
  57. text = self.labeled_prompt + f" Question: {question} Answer:"
  58. text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
  59. if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
  60. text_length = self.config.max_seq_length - self.config.max_gen_length - 2
  61. text = text[len(text) - text_length: len(text)]
  62. if self.printed_count < 3:
  63. print_rank_0(self.tokenizer.detokenize(text))
  64. self.printed_count += 1
  65. return [{"text": text, "targets": targets, **kwargs}]
  66. class ChainOfThoughtTask(GenerationTask):
  67. config: ChainOfThoughtConfig
  68. @classmethod
  69. def config_class(cls):
  70. return ChainOfThoughtConfig
  71. @property
  72. def metrics(self) -> Dict[str, Callable]:
  73. return {'acuracy': self.extracted_accuracy_metric}
  74. def extracted_accuracy_metric(self, predictions, examples):
  75. count = 0
  76. num_predictions = max(len(predictions), 1)
  77. assert len(predictions) == len(examples)
  78. for prediction, example in zip(predictions, examples):
  79. output = self.tokenizer.detokenize(prediction)
  80. prediction = extract_answer(output, self.config.name).strip()
  81. target = self.tokenizer.detokenize(example["targets"]).strip()
  82. count += prediction == target
  83. return count * 100.0 / num_predictions
  84. def build_dataset(self, relative_path, split):
  85. return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)
  86. def save_prediction_to_file(self, file, predictions, data):
  87. results = []
  88. for output, item in zip(predictions, data):
  89. output = self.tokenizer.detokenize(output)
  90. prediction = extract_answer(output, self.config.name)
  91. target = self.tokenizer.detokenize(item["targets"])
  92. results.append({"output": output, "prediction": prediction, "answer": target})
  93. file_name = file.split(".")[0]
  94. if not os.path.exists("outputs"):
  95. os.mkdir("outputs")
  96. with open("outputs/" + self.config.name + "_" + datetime.now().strftime(
  97. '%m-%d-%H-%M_') + file_name + ".json", "w") as output:
  98. for result in results:
  99. output.write(json.dumps(result) + "\n")