evaluation.py 859 B

1234567891011121314151617181920212223242526272829
  1. from typing import List, Union
  2. import itertools
  3. import numpy as np
  4. def estimate_pass_at_k(
  5. num_samples: Union[int, List[int], np.ndarray],
  6. num_correct: Union[List[int], np.ndarray],
  7. k: int
  8. ) -> np.ndarray:
  9. """
  10. Estimates pass@k of each problem and returns them in an array.
  11. """
  12. def estimator(n: int, c: int, k: int) -> float:
  13. """
  14. Calculates 1 - comb(n - c, k) / comb(n, k).
  15. """
  16. if n - c < k:
  17. return 1.0
  18. return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
  19. if isinstance(num_samples, int):
  20. num_samples_it = itertools.repeat(num_samples, len(num_correct))
  21. else:
  22. assert len(num_samples) == len(num_correct)
  23. num_samples_it = iter(num_samples)
  24. return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])