
metrics.py

import re
import math
import string
import functools

import torch
import numpy as np
from typing import Tuple, List
from collections import Counter, defaultdict

from SwissArmyTransformer import get_tokenizer

from .utils import print_rank_0


def accuracy_metric(predictions, examples):
    """Percentage of predictions that exactly match the example labels."""
    count = 0
    num_predictions = max(len(predictions), 1)  # guard against division by zero
    assert len(predictions) == len(examples)
    for prediction, example in zip(predictions, examples):
        count += prediction == example["label"]
    return count * 100.0 / num_predictions


def F1_metric(predictions, examples):
    """Micro-averaged F1 over example labels (classification tasks)."""
    assert len(predictions) == len(examples)
    # Imported locally so sklearn's f1_score does not clash with the
    # token-level f1_score defined below.
    from sklearn.metrics import f1_score
    truth = [example["label"] for example in examples]
    return f1_score(truth, predictions, average="micro") * 100.0


def precision_metric(predictions, examples):
    """Micro-averaged precision over example labels."""
    assert len(predictions) == len(examples)
    from sklearn.metrics import precision_score
    truth = [example["label"] for example in examples]
    return precision_score(truth, predictions, average="micro") * 100.0


def recall_metric(predictions, examples):
    """Micro-averaged recall over example labels."""
    assert len(predictions) == len(examples)
    from sklearn.metrics import recall_score
    truth = [example["label"] for example in examples]
    return recall_score(truth, predictions, average="micro") * 100.0


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    """Token-level F1 between a prediction and one ground truth (as used for extractive QA)."""
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
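

# Worked example (illustrative, not part of the original file): for prediction
# "The cat sat." and ground truth "a cat sat down", normalize_answer yields
# "cat sat" and "cat sat down"; the token overlap is 2, so precision = 2/2,
# recall = 2/3 and F1 = 0.8.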


def exact_match_score(prediction, ground_truth):
    """Exact match between prediction and ground truth after normalization."""
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Best score of `metric_fn` between the prediction and any ground truth."""
    if not ground_truths:
        return 0.0
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def qa_evaluate(predictions, examples, metric):
    """Decode token-id predictions and targets, then average `metric` over examples."""
    assert len(examples) == len(predictions)
    tokenizer = get_tokenizer()
    score = 0.0
    for example, prediction in zip(examples, predictions):
        ground_truths = [tokenizer.tokenizer.decode(target) for target in example["targets"]]
        prediction = tokenizer.tokenizer.decode(prediction)
        if ground_truths:
            score += metric_max_over_ground_truths(metric, prediction, ground_truths)
    score = 100.0 * score / len(predictions)
    return score


qa_exact_match = functools.partial(qa_evaluate, metric=exact_match_score)
qa_f1 = functools.partial(qa_evaluate, metric=f1_score)


def calculate_perplexity(loss: List[float], data):
    """Perplexity from summed losses per original token; the exponent is capped at 20 to avoid overflow."""
    return math.exp(min(20, np.sum(loss) / data[0]["num_original_tokens"]))
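

# Illustrative numbers: a summed loss of 138.6 over 100 original tokens gives
# exp(1.386) ≈ 4.0; exp(20) (~4.9e8) is the ceiling when the average loss is very large.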


def special_for_dataset(predictions, examples):
    """Fallback used when no metric is registered under the requested name."""
    print_rank_0("Metrics not found, maybe dataset special metric or metric name error")
    return True


DEFAULT_METRICS = defaultdict(lambda: special_for_dataset)
DEFAULT_METRICS.update(
    {
        "EM": qa_exact_match,
        "F1": qa_f1,
        "Accuracy": accuracy_metric,
        "PPL": calculate_perplexity,
        "Precision": precision_metric,
        "Recall": recall_metric,
        "F1_mul": F1_metric,
    }
)
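
# Illustrative usage (hypothetical toy data; an evaluation loop would normally
# supply real predictions and example dicts):
#
#     metric_fn = DEFAULT_METRICS["Accuracy"]
#     metric_fn([0, 1, 1], [{"label": 0}, {"label": 1}, {"label": 0}])  # -> ~66.67
#
# Unknown metric names fall back to special_for_dataset via the defaultdict.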