task.py

import os
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np
from bleurt import score
from rouge_score import rouge_scorer
from tqdm import tqdm

from evaluation import GenerationTask
from evaluation.utils import print_rank_0


class WEB(GenerationTask):
    """Generation task scored with ROUGE-2, ROUGE-L, and BLEURT."""

    def __init__(self, model, tokenizer, config_path):
        super().__init__(model, tokenizer, config_path)
        # Must point at a downloaded BLEURT checkpoint before evaluation runs.
        self.bleurt_checkpoint = "BLEURT-CHECKPOINT PATH"

    def WEBMetric(self, predictions, examples):
        """Collect per-example ROUGE-2/ROUGE-L and BLEURT scores."""
        # Let TensorFlow (used internally by BLEURT) grow GPU memory on demand
        # instead of reserving it all up front.
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
        metrics_dict = defaultdict(list)
        scorer_rouge = rouge_scorer.RougeScorer(["rouge2", "rougeL"], use_stemmer=True)
        scorer_bleurt = score.BleurtScorer(self.bleurt_checkpoint)
        for text, target in tqdm(zip(predictions, examples), total=len(predictions)):
            text_de = self.tokenizer.detokenize(text)
            target_de = self.tokenizer.detokenize(target["targets"][0])
            # RougeScorer.score expects (target, prediction) in that order;
            # swapping them silently exchanges precision and recall.
            scores_rouge = scorer_rouge.score(target_de, text_de)
            scores_bleurt = scorer_bleurt.score(references=[target_de], candidates=[text_de])
            for rouge_type in ("rouge2", "rougeL"):
                rouge = scores_rouge[rouge_type]
                metrics_dict[f"{rouge_type}_precision"].append(rouge.precision)
                metrics_dict[f"{rouge_type}_recall"].append(rouge.recall)
                metrics_dict[f"{rouge_type}_fmeasure"].append(rouge.fmeasure)
            metrics_dict["bleurt"].append(scores_bleurt[0])
        return metrics_dict
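
    # A minimal sanity check for the scorer conventions assumed above (the
    # example strings are illustrative, not from the dataset):
    #   scorer = rouge_scorer.RougeScorer(["rouge2"], use_stemmer=True)
    #   scorer.score("the cat sat", "the cat sat")["rouge2"].fmeasure  # -> 1.0
    # BleurtScorer.score(references=[...], candidates=[...]) returns a list of
    # floats, one per (reference, candidate) pair, hence scores_bleurt[0] above.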

    @property
    def metrics(self):
        return {"e2e": self.WEBMetric}

    def predict_single_batch(self, batch) -> List[List[int]]:
        # Only the top beam is returned for scoring.
        output = self.model.generate_text(batch, self.strategy, return_all_beams=False)
        return output

    def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
        # Per-file reporting is intentionally skipped; aggregation happens in
        # report_group_metrics.
        pass

    def report_group_metrics(self, group_name, result_dict_group: Dict[str, Tuple[Dict[str, float], int]], level=1):
        print_rank_0(f"Report for group: {group_name}")
        for result_dict, _count in result_dict_group.values():
            for metrics_dict in result_dict.values():
                for key, values in metrics_dict.items():
                    # Average the per-example scores collected by WEBMetric.
                    print_rank_0(f"{key}: {np.mean(values):.4f}")
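

# --- Usage sketch (an assumption, not part of the original file) ---
# How this task might be wired up; `load_model_and_tokenizer` and the config
# path are hypothetical placeholders for whatever the surrounding evaluation
# framework actually provides, and its entry-point method name may differ.
#
#   model, tokenizer = load_model_and_tokenizer(args)       # hypothetical loader
#   task = WEB(model, tokenizer, "tasks/web/config.yaml")   # illustrative path
#   task.bleurt_checkpoint = "/path/to/bleurt/checkpoint"   # real BLEURT dir
#   metrics = task.metrics["e2e"](predictions, examples)    # direct metric call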