metrics.py

import re
import math
import string
import functools
import numpy as np
from typing import Tuple, List
from collections import Counter
from SwissArmyTransformer import get_tokenizer


def accuracy_metric(predictions, examples):
    """Percentage of predictions that exactly match the example's label."""
    count = 0
    num_predictions = max(len(predictions), 1)
    assert len(predictions) == len(examples)
    for prediction, example in zip(predictions, examples):
        count += prediction == example["label"]
    return count * 100.0 / num_predictions


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    """SQuAD-style token-level F1 between a prediction and one ground truth."""
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    """SQuAD-style exact match after normalization."""
    return normalize_answer(prediction) == normalize_answer(ground_truth)
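

# Worked example (added for illustration; the sample strings are not from the
# original file): with prediction "the Eiffel Tower" and ground truth
# "Eiffel Tower in Paris", normalization yields ["eiffel", "tower"] versus
# ["eiffel", "tower", "in", "paris"], so num_same = 2, precision = 2/2 = 1.0,
# recall = 2/4 = 0.5, and f1 = 2 * 1.0 * 0.5 / (1.0 + 0.5) ≈ 0.667, while
# exact_match_score returns False.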


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Best score of `metric_fn` over all acceptable ground truths."""
    if not ground_truths:
        return 0.0
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def qa_evaluate(predictions, examples, metric):
    """Decode token-id predictions/targets and average `metric` over the dataset (0-100 scale)."""
    assert len(examples) == len(predictions)
    tokenizer = get_tokenizer()
    score = 0.0
    for example, prediction in zip(examples, predictions):
        ground_truths = [tokenizer.tokenizer.decode(target) for target in example["targets"]]
        prediction = tokenizer.tokenizer.decode(prediction)
        if ground_truths:
            score += metric_max_over_ground_truths(metric, prediction, ground_truths)
    score = 100.0 * score / len(predictions)
    return score


# QA metrics specialized from qa_evaluate.
qa_exact_match = functools.partial(qa_evaluate, metric=exact_match_score)
qa_f1 = functools.partial(qa_evaluate, metric=f1_score)


def calculate_perplexity(loss: List[float], data):
    # exp(total loss / num_original_tokens), capped at exp(20) to avoid overflow.
    return math.exp(min(20, np.sum(loss) / data[0]["num_original_tokens"]))


DEFAULT_METRICS = {"EM": qa_exact_match, "F1": qa_f1, "Accuracy": accuracy_metric, "PPL": calculate_perplexity}
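

# Minimal usage sketch (an illustrative addition, not part of the original
# module): exercises the tokenizer-free metrics on toy data. The toy inputs
# and this __main__ guard are assumptions for demonstration only.
if __name__ == "__main__":
    toy_predictions = [0, 1, 1]
    toy_examples = [{"label": 0}, {"label": 1}, {"label": 0}]
    # Two of the three labels match, so this prints roughly 66.67.
    print("Accuracy:", DEFAULT_METRICS["Accuracy"](toy_predictions, toy_examples))

    # Best F1 over several acceptable answers for one free-form prediction (~0.667).
    print("Max F1:", metric_max_over_ground_truths(
        f1_score, "the Eiffel Tower", ["Eiffel Tower in Paris", "the Louvre"]))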