task.py 1.1 KB

1234567891011121314151617181920212223242526272829303132333435
  1. import os
  2. import re
  3. from datetime import datetime
  4. from functools import partial
  5. from evaluation import qa_evaluate, GenerationTask
  # First "the/The answer is X." sentence; the closing period must be followed
  # by a newline OR the end of the text (a generation often stops right after
  # the final period, which the previous "\.\n"-only lookahead rejected).
  _ANSWER_PATTERN = re.compile(r"(?<=[tT]he answer is ).*?(?=\.(?:\n|\Z))")


  def extract_answer(prediction):
      """Extract the final answer from a chain-of-thought generation.

      Searches for the first occurrence of "the answer is X." (with either
      "the" or "The") and returns X.  The sentence's closing period must be
      followed by a newline or by the end of the text; X itself cannot span
      lines.

      Args:
          prediction: Model output text to search.

      Returns:
          The extracted answer string, or "" when no such sentence is found.
      """
      match = _ANSWER_PATTERN.search(prediction)
      return match.group(0) if match else ""
  def exact_match_score(prediction, ground_truth):
      """Score one example: True iff the extracted answer equals ``ground_truth``.

      The answer is pulled out of ``prediction`` with ``extract_answer`` and
      compared with plain ``==``; no case, whitespace or punctuation
      normalisation is applied.
      """
      extracted = extract_answer(prediction)
      return extracted == ground_truth
  class BBHGeneration(GenerationTask):
      """BBH generation task scored by exact match on the extracted answer."""

      @property
      def metrics(self):
          """Mapping of metric name to scoring callable.

          "Accuracy" is ``qa_evaluate`` with ``exact_match_score`` bound as its
          ``metric`` keyword (presumably qa_evaluate applies it per example and
          aggregates — confirm in ``evaluation``).
          """
          return {"Accuracy": partial(qa_evaluate, metric=exact_match_score)}

      def __init__(self, model, tokenizer, config):
          super().__init__(model, tokenizer, config)
          # Captured once so every prediction file written by this instance
          # lands in the same timestamped output directory.
          self.start_time = datetime.now()

      def save_prediction_to_file(self, file, prediction, data):
          """Write predictions to ``outputs_<start_time>/<config.name>/<file>.predict``.

          Args:
              file: Base name (without extension) of the output file.
              prediction: Iterable of items passed to
                  ``self.tokenizer.detokenize``; one output line per item.
              data: Not used by this implementation (presumably kept for
                  signature compatibility with the base class — confirm).
          """
          filename = os.path.join(f"outputs_{self.start_time}", self.config.name, f"{file}.predict")
          os.makedirs(os.path.dirname(filename), exist_ok=True)
          # Explicit UTF-8: detokenized text may contain non-ASCII characters
          # and the platform default encoding is locale-dependent.
          # The handle gets its own name instead of shadowing the `file` arg.
          with open(filename, "w", encoding="utf-8") as output:
              for item in prediction:
                  # str() of a one-element list uses repr() of the element, so
                  # (assuming detokenize returns str) embedded newlines are
                  # escaped and each prediction stays on a single line.
                  output.write(str([self.tokenizer.detokenize(item)]) + "\n")