import string
import re
import functools
from collections import Counter

from SwissArmyTransformer import get_tokenizer


def accuracy_metric(predictions, examples):
    """Percentage of predictions that equal the example labels."""
    count = 0
    num_predictions = max(len(predictions), 1)
    assert len(predictions) == len(examples)
    for prediction, example in zip(predictions, examples):
        count += prediction == example["label"]
    return count * 100.0 / num_predictions


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
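

# For reference, normalization lowercases, strips punctuation and the articles
# "a/an/the", and collapses whitespace, e.g. (illustrative value only):
#   normalize_answer("The Cat sat.")  ->  "cat sat"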


def f1_score(prediction, ground_truth):
    """Token-level F1 between the normalized prediction and a single ground truth."""
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    """True if the normalized prediction equals the normalized ground truth."""
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Best `metric_fn` score of `prediction` over all acceptable ground truths (0.0 if none)."""
    if not ground_truths:
        return 0.0
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)
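

# Illustrative values for the scoring helpers above (worked examples, not from
# the original source):
#   f1_score("a cat sat", "the cat sat on the mat")
#       -> precision 2/2, recall 2/4, F1 = 2 * 1.0 * 0.5 / 1.5 ≈ 0.667
#   exact_match_score("The cat!", "cat")
#       -> True (articles and punctuation are stripped before comparison)
#   metric_max_over_ground_truths(f1_score, "cat sat", ["the dog", "a cat sat"])
#       -> 1.0 (best score over all acceptable answers)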


def qa_evaluate(predictions, examples, metric):
    """Decode token-id predictions and targets, then average `metric` over the dataset as a percentage."""
    assert len(examples) == len(predictions)
    tokenizer = get_tokenizer()
    score = 0.0
    for example, prediction in zip(examples, predictions):
        # Decode token ids back to text before scoring.
        ground_truths = [tokenizer.tokenizer.decode(target) for target in example["targets"]]
        prediction = tokenizer.tokenizer.decode(prediction)
        if ground_truths:
            score += metric_max_over_ground_truths(metric, prediction, ground_truths)
    score = 100.0 * score / len(predictions)
    return score


# Metric callables sharing a common (predictions, examples) signature.
qa_exact_match = functools.partial(qa_evaluate, metric=exact_match_score)
qa_f1 = functools.partial(qa_evaluate, metric=f1_score)

DEFAULT_METRICS = {"EM": qa_exact_match, "F1": qa_f1, "Accuracy": accuracy_metric}
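

# A minimal smoke test with made-up data; it is not part of the original
# evaluation pipeline and only exercises the tokenizer-free metrics, since
# qa_evaluate needs a configured SwissArmyTransformer tokenizer.
if __name__ == "__main__":
    dummy_predictions = [0, 1, 1]
    dummy_examples = [{"label": 0}, {"label": 1}, {"label": 0}]
    # Two of the three labels match, so this prints roughly 66.67.
    print(DEFAULT_METRICS["Accuracy"](dummy_predictions, dummy_examples))

    # Best F1 of a prediction against several acceptable answers (prints 1.0).
    print(metric_max_over_ground_truths(f1_score, "the cat sat", ["a cat sat", "the dog"]))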