123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- import re
- import math
- import string
- import functools
- import numpy as np
- from typing import List
- from collections import Counter
- from collections import defaultdict
- from SwissArmyTransformer import get_tokenizer
- from .utils import print_rank_0
def accuracy_metric(predictions, examples):
    """Return the percentage of predictions that equal their example's label.

    Args:
        predictions: Sequence of predicted labels, one per example.
        examples: Sequence of mappings, each providing a ``"label"`` entry.

    Returns:
        Accuracy on a 0-100 scale; ``0.0`` when there are no predictions.

    Raises:
        AssertionError: If ``predictions`` and ``examples`` differ in length.
    """
    assert len(predictions) == len(examples)
    # Clamp the denominator so an empty evaluation set does not divide by zero.
    denominator = max(len(predictions), 1)
    hits = sum(
        guess == sample["label"] for guess, sample in zip(predictions, examples)
    )
    return hits * 100.0 / denominator
def F1_metric(predictions, examples):
    """Return the micro-averaged F1 of ``predictions`` as a percentage.

    The gold labels are read from each example's ``"label"`` entry and the
    score is computed by ``sklearn.metrics.f1_score`` with ``average="micro"``.

    Args:
        predictions: Sequence of predicted labels, one per example.
        examples: Sequence of mappings, each providing a ``"label"`` entry.

    Returns:
        The sklearn score multiplied by 100.

    Raises:
        AssertionError: If ``predictions`` and ``examples`` differ in length.
    """
    assert len(predictions) == len(examples)
    # Imported lazily, after the length check; inside this function the name
    # refers to sklearn's scorer, not the module-level token-overlap f1_score.
    from sklearn.metrics import f1_score

    gold_labels = [sample["label"] for _, sample in zip(predictions, examples)]
    return f1_score(gold_labels, predictions, average="micro") * 100.0
def precision_metric(predictions, examples):
    """Return the micro-averaged precision of ``predictions`` as a percentage.

    The gold labels are read from each example's ``"label"`` entry and the
    score is computed by ``sklearn.metrics.precision_score`` with
    ``average="micro"``.

    Args:
        predictions: Sequence of predicted labels, one per example.
        examples: Sequence of mappings, each providing a ``"label"`` entry.

    Returns:
        The sklearn score multiplied by 100.

    Raises:
        AssertionError: If ``predictions`` and ``examples`` differ in length.
    """
    assert len(predictions) == len(examples)
    # Imported lazily, after the length check, so sklearn is only required
    # when this metric is actually evaluated.
    from sklearn.metrics import precision_score

    gold_labels = [sample["label"] for _, sample in zip(predictions, examples)]
    return precision_score(gold_labels, predictions, average="micro") * 100.0
def recall_metric(predictions, examples):
    """Return the micro-averaged recall of ``predictions`` as a percentage.

    The gold labels are read from each example's ``"label"`` entry and the
    score is computed by ``sklearn.metrics.recall_score`` with
    ``average="micro"``.

    Args:
        predictions: Sequence of predicted labels, one per example.
        examples: Sequence of mappings, each providing a ``"label"`` entry.

    Returns:
        The sklearn score multiplied by 100.

    Raises:
        AssertionError: If ``predictions`` and ``examples`` differ in length.
    """
    assert len(predictions) == len(examples)
    # Imported lazily, after the length check, so sklearn is only required
    # when this metric is actually evaluated.
    from sklearn.metrics import recall_score

    gold_labels = [sample["label"] for _, sample in zip(predictions, examples)]
    return recall_score(gold_labels, predictions, average="micro") * 100.0
def normalize_answer(s):
    """Canonicalise an answer string so two answers can be compared.

    The steps run in this order: lower-case the text, drop every character
    found in ``string.punctuation``, replace the stand-alone words ``a``,
    ``an`` and ``the`` with a space, then collapse all whitespace runs into
    single spaces (which also trims both ends).

    Args:
        s: The raw answer text.

    Returns:
        The normalised text.
    """
    lowered = s.lower()

    punctuation = set(string.punctuation)
    unpunctuated = "".join(char for char in lowered if char not in punctuation)

    # Articles are removed after punctuation, so "a." is treated as "a".
    without_articles = re.sub(r"\b(a|an|the)\b", " ", unpunctuated)

    return " ".join(without_articles.split())
def f1_score(prediction, ground_truth):
    """Return the token-overlap F1 between a prediction and one reference.

    Both strings are passed through ``normalize_answer`` and split on
    whitespace; the overlap is the multiset intersection of the two token
    lists.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        The harmonic mean of token precision and recall, or ``0`` when the
        two answers share no tokens.
    """
    predicted_words = normalize_answer(prediction).split()
    reference_words = normalize_answer(ground_truth).split()

    # Counter & Counter keeps each token with the smaller of its two counts.
    shared = Counter(predicted_words) & Counter(reference_words)
    overlap = sum(shared.values())
    if not overlap:
        return 0

    precision = 1.0 * overlap / len(predicted_words)
    recall = 1.0 * overlap / len(reference_words)
    return (2 * precision * recall) / (precision + recall)
def exact_match_score(prediction, ground_truth):
    """Return whether two answers are identical after ``normalize_answer``.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        ``True`` when the normalised strings are equal, otherwise ``False``.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_reference = normalize_answer(ground_truth)
    return normalized_prediction == normalized_reference
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Return the best score of ``prediction`` against any reference answer.

    Args:
        metric_fn: Callable invoked as ``metric_fn(prediction, ground_truth)``.
        prediction: The predicted answer, passed through to ``metric_fn``.
        ground_truths: Collection of reference answers.

    Returns:
        The largest value ``metric_fn`` produced, or ``0.0`` when
        ``ground_truths`` is empty or otherwise falsy.
    """
    if not ground_truths:
        return 0.0
    # Score every reference first, then pick the maximum.
    return max([metric_fn(prediction, reference) for reference in ground_truths])
def qa_evaluate(predictions, examples, metric):
    """Average a QA metric over a set of examples, as a percentage.

    Each prediction and each entry of the example's ``"targets"`` is decoded
    with ``get_tokenizer().tokenizer.decode``; the example's score is the best
    ``metric`` value over its decoded targets. Examples without targets add
    nothing to the total but still count in the denominator.

    Args:
        predictions: One encoded prediction per example (presumably token-id
            sequences -- confirm against the caller).
        examples: Sequence of mappings, each providing a ``"targets"`` entry
            holding the encoded reference answers.
        metric: Callable invoked as ``metric(prediction_text, target_text)``.

    Returns:
        The mean best-match score multiplied by 100; ``0.0`` when there are
        no predictions.

    Raises:
        AssertionError: If ``predictions`` and ``examples`` differ in length.
    """
    assert len(examples) == len(predictions)

    num_examples = len(predictions)
    # An empty evaluation set used to raise ZeroDivisionError at the final
    # division; report 0.0 instead, matching accuracy_metric. len() is used
    # rather than truthiness so array-like inputs are handled as well.
    if num_examples == 0:
        return 0.0

    # Resolve the decode callable once instead of on every iteration.
    decode = get_tokenizer().tokenizer.decode

    score = 0.0
    for example, prediction in zip(examples, predictions):
        ground_truths = [decode(target) for target in example["targets"]]
        if not ground_truths:
            # Nothing to compare against: the example scores zero.
            continue
        score += metric_max_over_ground_truths(
            metric, decode(prediction), ground_truths
        )
    return 100.0 * score / num_examples
# Ready-made QA metrics: qa_evaluate with its ``metric`` argument pre-bound to
# the exact-match scorer and to the token-overlap F1 scorer respectively.
qa_exact_match, qa_f1 = (
    functools.partial(qa_evaluate, metric=scorer)
    for scorer in (exact_match_score, f1_score)
)
def calculate_perplexity(loss: List[float], data):
    """Turn a list of loss values into a perplexity.

    The losses are summed and divided by the ``"num_original_tokens"`` entry
    of the first element of ``data``; only ``data[0]`` is consulted. The
    resulting exponent is capped at 20 before exponentiation.

    Args:
        loss: Loss values to sum.
        data: Indexable collection whose first element provides
            ``"num_original_tokens"``.

    Returns:
        ``exp(min(20, sum(loss) / num_original_tokens))``.
    """
    total_loss = np.sum(loss)
    token_count = data[0]["num_original_tokens"]
    # Cap the exponent so the result never exceeds exp(20).
    exponent = min(20, total_loss / token_count)
    return math.exp(exponent)
def special_for_dataset(predictions, examples):
    """Placeholder metric: emit a notice via ``print_rank_0`` and return True.

    Both arguments are accepted only so the signature matches the other
    metric functions; neither is inspected.

    Args:
        predictions: Ignored.
        examples: Ignored.

    Returns:
        Always ``True``.
    """
    notice = "Metrics not found, maybe dataset special metric or metric name error"
    print_rank_0(notice)
    return True
# Registry mapping metric names to their implementations. Looking up a name
# that is not registered yields (and stores) special_for_dataset.
DEFAULT_METRICS = defaultdict(
    lambda: special_for_dataset,
    {
        "EM": qa_exact_match,
        "F1": qa_f1,
        "Accuracy": accuracy_metric,
        "PPL": calculate_perplexity,
        "Precision": precision_metric,
        "Recall": recall_metric,
        "F1_mul": F1_metric,
    },
)
|