metrics.py

import string
import re
import functools
from collections import Counter

from SwissArmyTransformer import get_tokenizer


def accuracy_metric(predictions, examples):
    """Percentage of predictions that equal their example's "label" field."""
    count = 0
    num_predictions = max(len(predictions), 1)
    assert len(predictions) == len(examples)
    for prediction, example in zip(predictions, examples):
        count += prediction == example["label"]
    return count * 100.0 / num_predictions


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
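
# Worked example for normalize_answer above (illustrative, not part of the
# original file): lower-casing, punctuation stripping, article removal and
# whitespace collapsing give
#   normalize_answer("The Eiffel Tower!")  ->  "eiffel tower"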


def f1_score(prediction, ground_truth):
    """Token-level F1 between one prediction and one ground-truth answer."""
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    """True if prediction and ground truth are identical after normalization."""
    return normalize_answer(prediction) == normalize_answer(ground_truth)
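
# Worked example for f1_score above (illustrative, not part of the original
# file): "the Eiffel Tower" vs. "Eiffel Tower, Paris" normalize to
# ["eiffel", "tower"] and ["eiffel", "tower", "paris"], so num_same = 2,
# precision = 2/2 = 1.0, recall = 2/3, and
# F1 = 2 * 1.0 * (2/3) / (1.0 + 2/3) = 0.8.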


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score the prediction against every ground truth and keep the best score."""
    if not ground_truths:
        return 0.0
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def qa_evaluate(predictions, examples, metric):
    """Decode token-id predictions and targets, then average `metric` over all examples."""
    assert len(examples) == len(predictions)
    tokenizer = get_tokenizer()
    score = 0.0
    for example, prediction in zip(examples, predictions):
        ground_truths = [tokenizer.tokenizer.decode(target) for target in example["targets"]]
        prediction = tokenizer.tokenizer.decode(prediction)
        if ground_truths:
            score += metric_max_over_ground_truths(metric, prediction, ground_truths)
    score = 100.0 * score / len(predictions)
    return score


qa_exact_match = functools.partial(qa_evaluate, metric=exact_match_score)
qa_f1 = functools.partial(qa_evaluate, metric=f1_score)

DEFAULT_METRICS = {"EM": qa_exact_match, "F1": qa_f1, "Accuracy": accuracy_metric}
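
# Minimal usage sketch (illustrative, not part of the original module). Only the
# tokenizer-free metrics are exercised here; qa_exact_match and qa_f1 additionally
# require get_tokenizer() to have been initialized by SwissArmyTransformer and
# examples that carry token-id "targets".
#
#   predictions = [0, 1, 1]
#   examples = [{"label": 0}, {"label": 1}, {"label": 0}]
#   accuracy_metric(predictions, examples)                # -> 66.66... (2 of 3 correct)
#
#   metric_max_over_ground_truths(
#       exact_match_score, "Paris!", ["paris", "lyon"]
#   )                                                     # -> True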