# task.py
from typing import Dict, List, Tuple

import numpy as np

from evaluation import MultiChoiceTask
  4. categories = {
  5. "STEM": [
  6. "Abstract Algebra",
  7. "Anatomy",
  8. "Astronomy",
  9. "College Biology",
  10. "College Chemistry",
  11. "College Computer Science",
  12. "College Mathematics",
  13. "College Physics",
  14. "Computer Security",
  15. "Conceptual Physics",
  16. "Electrical Engineering",
  17. "Elementary Mathematics",
  18. "High School Biology",
  19. "High School Chemistry",
  20. "High School Computer Science",
  21. "High School Mathematics",
  22. "High School Physics",
  23. "High School Statistics",
  24. "Machine Learning",
  25. ],
  26. "Other": [
  27. "Business Ethics",
  28. "Clinical Knowledge",
  29. "College Medicine",
  30. "Global Facts",
  31. "Human Aging",
  32. "Management",
  33. "Marketing",
  34. "Medical Genetics",
  35. "Miscellaneous",
  36. "Nutrition",
  37. "Professional Accounting",
  38. "Professional Medicine",
  39. "Virology",
  40. ],
  41. "Social Sciences": [
  42. "Econometrics",
  43. "High School Geography",
  44. "High School Government and Politics",
  45. "High School Macroeconomics",
  46. "High School Microeconomics",
  47. "High School Psychology",
  48. "Human Sexuality",
  49. "Professional Psychology",
  50. "Public Relations",
  51. "Security Studies",
  52. "Sociology",
  53. "US Foreign Policy",
  54. ],
  55. "Humanities": [
  56. "Formal Logic",
  57. "High School European History",
  58. "High School US History",
  59. "High School World History",
  60. "International Law",
  61. "Jurisprudence",
  62. "Logical Fallacies",
  63. "Moral Disputes",
  64. "Moral Scenarios",
  65. "Philosophy",
  66. "Prehistory",
  67. "Professional Law",
  68. "World Religions",
  69. ],
  70. }
  71. class MMLU(MultiChoiceTask):
  72. def report_overall_metrics(self, result_dict_all: Dict[str, Tuple[Dict[str, float], int]]):
  73. self.report_group_metrics("Overall", result_dict_all, level=0)