"""E2E generation-task evaluation: scores model outputs against references
with ROUGE-2 / ROUGE-L and BLEURT."""
  1. from string import punctuation
  2. from functools import partial
  3. from typing import List
  4. from SwissArmyTransformer import mpu
  5. import numpy as np
  6. import torch
  7. import os
  8. from tqdm import tqdm
  9. from evaluation import qa_evaluate, GenerationTask
  10. from collections import defaultdict
  11. from typing import Dict, Tuple
  12. from rouge_score import rouge_scorer
  13. from bleurt import score
  14. from evaluation.utils import (
  15. print_rank_0,
  16. get_tokenized_input,
  17. )
# Ask TensorFlow (pulled in by the BLEURT scorer) to grow GPU memory on demand
# instead of pre-allocating the whole device.
# NOTE(review): this only takes effect if set before TF initializes — verify
# that `bleurt` has not already initialized TF at import time above.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
  19. class E2E(GenerationTask):
  20. def __init__(self, model, tokenizer, config_path):
  21. super(E2E, self).__init__(model, tokenizer, config_path)
  22. self.bleurt_checkpoint = "BLEURT CHECKPOINT PATH"
  23. def E2EMetric(self, predictions, examples):
  24. metrics_dict = defaultdict(lambda: [])
  25. scorer_rouge = rouge_scorer.RougeScorer(["rouge2", "rougeL"], use_stemmer=True)
  26. scorer_bleurt = score.BleurtScorer(self.bleurt_checkpoint)
  27. for text, target in tqdm(zip(predictions, examples)):
  28. text_de = self.tokenizer.detokenize(text)
  29. target_de = self.tokenizer.detokenize(target["targets"][0])
  30. scores_rouge = scorer_rouge.score(text_de, target_de)
  31. scores_bleurt = scorer_bleurt.score(
  32. references=[target_de], candidates=[text_de]
  33. )
  34. rouge2_precision = scores_rouge["rouge2"].precision
  35. rouge2_recall = scores_rouge["rouge2"].recall
  36. rouge2_fmeasure = scores_rouge["rouge2"].fmeasure
  37. rougeL_precision = scores_rouge["rougeL"].precision
  38. rougeL_recall = scores_rouge["rougeL"].recall
  39. rougeL_fmeasure = scores_rouge["rougeL"].fmeasure
  40. metrics_dict["rouge2_precision"].append(rouge2_precision)
  41. metrics_dict["rouge2_recall"].append(rouge2_recall)
  42. metrics_dict["rouge2_fmeasure"].append(rouge2_fmeasure)
  43. metrics_dict["rougeL_precision"].append(rougeL_precision)
  44. metrics_dict["rougeL_recall"].append(rougeL_recall)
  45. metrics_dict["rougeL_fmeasure"].append(rougeL_fmeasure)
  46. metrics_dict["bleurt"].append(scores_bleurt[0])
  47. return metrics_dict
  48. @property
  49. def metrics(self):
  50. return {"e2e": self.E2EMetric}
  51. def predict_single_batch(self, batch) -> List[List[int]]:
  52. output = self.model.generate_text(batch, self.strategy, return_all_beams=False)
  53. return output
  54. def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
  55. pass
  56. def report_group_metrics(
  57. self,
  58. group_name,
  59. result_dict_group: Dict[str, Tuple[Dict[str, float], int]],
  60. level=1,
  61. ):
  62. print("report")
  63. for tmp1 in result_dict_group.values():
  64. tmp1 = tmp1[0]
  65. for result in tmp1.values():
  66. for key, values in result.items():
  67. print_rank_0(key, np.mean(values))