metric.py

import tqdm
import numpy as np
from collections import defaultdict, Counter
from concurrent.futures import ThreadPoolExecutor, as_completed

from evaluation.utils import print_rank_0
from .human_eval.data import read_problems
from .human_eval.evaluation import estimate_pass_at_k
from .human_eval.execution import check_correctness


class HumanEvalEvaluator:
    """Scores generated HumanEval completions for functional correctness."""

    def __init__(
        self,
        language,
        problem_file,
        tokenizer,
        n_workers: int = 4,
        timeout: float = 3.0,
    ):
        self.language = language
        self.n_workers = n_workers
        self.timeout = timeout
        self.problems = read_problems(problem_file)
        self.tokenizer = tokenizer
        # Filled in by evaluate_functional_correctness().
        self.total = None
        self.correct = None
        self.results = {}

    def evaluate_pass_k(self, prediction, data, k):
        # Run the correctness checks once, then average pass@k over all problems.
        if self.total is None or self.correct is None or self.results is None:
            self.evaluate_functional_correctness(prediction, data)
        return estimate_pass_at_k(self.total, self.correct, k).mean()
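
    # estimate_pass_at_k (vendored from OpenAI's human-eval) computes the
    # unbiased estimator of Chen et al. (2021):
    #     pass@k = E[1 - C(n - c, k) / C(n, k)]
    # per task, where n samples were drawn and c of them passed; .mean()
    # averages the per-task estimates.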

    def evaluate_functional_correctness(self, prediction, data):
        # Check the generated samples against the test suites.
        with ThreadPoolExecutor(max_workers=self.n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            results = defaultdict(list)

            print_rank_0("Reading samples...")
            for i, sample in enumerate(tqdm.tqdm(data)):
                task_id = sample["task_id"]
                # prediction[i] holds the generated token ids for sample i.
                completion = self.tokenizer.tokenizer.decode(prediction[i])
                args = (self.problems[task_id], completion, self.timeout, completion_id[task_id])
                future = executor.submit(check_correctness, *args)
                futures.append(future)
                completion_id[task_id] += 1
                n_samples += 1
            assert len(completion_id) == len(self.problems), "Some problems were not attempted."

            print_rank_0("Running test suites...")
            for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
                result = future.result()
                results[result["task_id"]].append((result["completion_id"], result))

        # Tally per-task sample and pass counts for the pass@k estimator.
        total, correct = [], []
        for result in results.values():
            result.sort()
            passed = [r[1]["passed"] for r in result]
            total.append(len(passed))
            correct.append(sum(passed))
        self.total = np.array(total)
        self.correct = np.array(correct)
        self.results = results
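

# A minimal usage sketch. Assumptions (none of these names are defined in
# this file): `megatron_tokenizer` is any wrapper exposing the underlying
# tokenizer as `.tokenizer` (matching the `self.tokenizer.tokenizer.decode(...)`
# call above), `predictions[i]` is the list of generated token ids for
# `samples[i]`, every sample dict carries the "task_id" it completes, and
# the problem-file path is hypothetical.
#
#     evaluator = HumanEvalEvaluator(
#         language="python",
#         problem_file="data/HumanEval.jsonl.gz",
#         tokenizer=megatron_tokenizer,
#     )
#     pass_at_1 = evaluator.evaluate_pass_k(predictions, samples, k=1)
#     pass_at_10 = evaluator.evaluate_pass_k(predictions, samples, k=10)  # reuses cached counts
#     print_rank_0(f"pass@1 = {pass_at_1:.4f}, pass@10 = {pass_at_10:.4f}")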