# metrics.py
import re
import math
import string
import functools
import numpy as np
from typing import List
from collections import Counter
from collections import defaultdict
from SwissArmyTransformer import get_tokenizer

from .utils import print_rank_0


def accuracy_metric(predictions, examples):
    """Percentage of predictions that exactly match the example labels."""
    count = 0
    num_predictions = max(len(predictions), 1)  # avoid division by zero on empty input
    assert len(predictions) == len(examples)
    for prediction, example in zip(predictions, examples):
        count += prediction == example["label"]
    return count * 100.0 / num_predictions


def F1_metric(predictions, examples):
    """Micro-averaged F1 over the predicted labels, as a percentage."""
    assert len(predictions) == len(examples)
    # Deferred import: sklearn is only needed for this metric. The local name
    # shadows the token-overlap f1_score defined below, but only inside this function.
    from sklearn.metrics import f1_score

    truth = [example["label"] for example in examples]
    return f1_score(truth, predictions, average="micro") * 100.0


def precision_metric(predictions, examples):
    """Micro-averaged precision over the predicted labels, as a percentage."""
    assert len(predictions) == len(examples)
    from sklearn.metrics import precision_score

    truth = [example["label"] for example in examples]
    return precision_score(truth, predictions, average="micro") * 100.0


def recall_metric(predictions, examples):
    """Micro-averaged recall over the predicted labels, as a percentage."""
    assert len(predictions) == len(examples)
    from sklearn.metrics import recall_score

    truth = [example["label"] for example in examples]
    return recall_score(truth, predictions, average="micro") * 100.0

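
# Illustrative note: with micro averaging, true/false positives and negatives are
# pooled across classes before the ratio is taken, so for single-label multi-class
# predictions the micro precision, recall and F1 above all coincide with accuracy.
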
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    """Token-overlap F1 between a predicted answer and a single reference string."""
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    """True if prediction and reference are identical after normalization."""
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Best score of metric_fn between the prediction and any of the ground truths."""
    if not ground_truths:
        return 0.0
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)

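
# Worked example (hypothetical strings): for the prediction "The cat sat." with
# references ["a cat sat down", "cat"], normalization strips articles and
# punctuation, leaving prediction tokens ["cat", "sat"]:
#   vs "cat sat down": 2 shared tokens -> precision 2/2, recall 2/3, F1 = 0.8
#   vs "cat":          1 shared token  -> precision 1/2, recall 1/1, F1 ≈ 0.67
# metric_max_over_ground_truths(f1_score, prediction, references) keeps the best, 0.8.
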
def qa_evaluate(predictions, examples, metric):
    """Decode token-id predictions and targets, then average ``metric`` over the dataset."""
    assert len(examples) == len(predictions)
    tokenizer = get_tokenizer()
    score = 0.0
    for example, prediction in zip(examples, predictions):
        ground_truths = [tokenizer.tokenizer.decode(target) for target in example["targets"]]
        prediction = tokenizer.tokenizer.decode(prediction)
        if ground_truths:
            score += metric_max_over_ground_truths(metric, prediction, ground_truths)
    score = 100.0 * score / max(len(predictions), 1)  # guard empty input, mirroring accuracy_metric
    return score


qa_exact_match = functools.partial(qa_evaluate, metric=exact_match_score)
qa_f1 = functools.partial(qa_evaluate, metric=f1_score)


def calculate_perplexity(loss: List[float], data):
    # Perplexity = exp(total loss / number of original tokens); the exponent is
    # capped at 20 to avoid overflow for badly diverged models.
    return math.exp(min(20, np.sum(loss) / data[0]["num_original_tokens"]))

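
# For instance (hypothetical numbers): a summed loss of 693.15 over 100 original
# tokens gives exp(6.9315) ≈ 1024; the min(20, ...) cap clips anything above exp(20) ≈ 4.85e8.
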
def special_for_dataset(predictions, examples):
    """Fallback used when no registered metric matches the requested name."""
    print_rank_0("Metric not found: maybe the dataset uses a special metric, or the metric name is wrong")
    return True


DEFAULT_METRICS = defaultdict(lambda: special_for_dataset)
DEFAULT_METRICS.update(
    {
        "EM": qa_exact_match,
        "F1": qa_f1,
        "Accuracy": accuracy_metric,
        "PPL": calculate_perplexity,
        "Precision": precision_metric,
        "Recall": recall_metric,
        "F1_mul": F1_metric,
    }
)
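

if __name__ == "__main__":
    # Minimal usage sketch (assumes examples are dicts with a "label" key, as
    # accuracy_metric expects, and that this runs inside the package so the
    # relative import of print_rank_0 resolves). Unknown metric names fall back
    # to special_for_dataset via the defaultdict.
    demo_predictions = [1, 0, 1]
    demo_examples = [{"label": 1}, {"label": 0}, {"label": 0}]
    print(DEFAULT_METRICS["Accuracy"](demo_predictions, demo_examples))  # ≈ 66.67 (2 of 3 correct)
    DEFAULT_METRICS["NoSuchMetric"](demo_predictions, demo_examples)  # warns and returns True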