task.py

import os
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np
from tqdm import tqdm
from rouge_score import rouge_scorer
from bleurt import score

from evaluation import GenerationTask
from evaluation.utils import print_rank_0

# BLEURT runs on TensorFlow; let TF grow GPU memory on demand so it can
# coexist with the PyTorch model on the same device.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"


class WEB(GenerationTask):
    def __init__(self, model, tokenizer, config_path):
        super().__init__(model, tokenizer, config_path)
        # Placeholder: point this at a downloaded BLEURT checkpoint directory.
        self.bleurt_checkpoint = "BLEURT CHECKPOINT PATH"

    def WEBMetric(self, predictions, examples):
        metrics_dict = defaultdict(list)
        scorer_rouge = rouge_scorer.RougeScorer(["rouge2", "rougeL"], use_stemmer=True)
        scorer_bleurt = score.BleurtScorer(self.bleurt_checkpoint)
        for text, target in tqdm(zip(predictions, examples), total=len(predictions)):
            text_de = self.tokenizer.detokenize(text)
            target_de = self.tokenizer.detokenize(target["targets"][0])
            # RougeScorer.score expects (target, prediction): the reference
            # comes first, otherwise precision and recall are swapped.
            scores_rouge = scorer_rouge.score(target_de, text_de)
            # BleurtScorer.score returns one float per (reference, candidate) pair.
            scores_bleurt = scorer_bleurt.score(references=[target_de], candidates=[text_de])
            for rouge_type in ("rouge2", "rougeL"):
                metrics_dict[f"{rouge_type}_precision"].append(scores_rouge[rouge_type].precision)
                metrics_dict[f"{rouge_type}_recall"].append(scores_rouge[rouge_type].recall)
                metrics_dict[f"{rouge_type}_fmeasure"].append(scores_rouge[rouge_type].fmeasure)
            metrics_dict["bleurt"].append(scores_bleurt[0])
        return metrics_dict

    @property
    def metrics(self):
        return {"e2e": self.WEBMetric}

    def predict_single_batch(self, batch) -> List[List[int]]:
        output = self.model.generate_text(batch, self.strategy, return_all_beams=False)
        return output

    def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
        pass

    def report_group_metrics(
        self,
        group_name,
        result_dict_group: Dict[str, Tuple[Dict[str, float], int]],
        level=1,
    ):
        print("report")
        for entry in result_dict_group.values():
            group_results = entry[0]  # each entry is a (result_dict, example_count) pair
            for result in group_results.values():
                for key, values in result.items():
                    # Report the mean of each per-example metric list on rank 0.
                    print_rank_0(key, np.mean(values))
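

# ---------------------------------------------------------------------------
# Minimal standalone sketch of the two scorers WEBMetric relies on, useful for
# sanity-checking them outside the evaluation harness. The strings and the
# BLEURT checkpoint path below are illustrative assumptions, not part of this
# repo; download a checkpoint (e.g. BLEURT-20) and update the path to run it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    reference = "a cat was sitting on the mat"
    prediction = "the cat sat on the mat"

    # rouge_score: reference (target) first, prediction second; each entry in
    # the result dict carries precision, recall, and fmeasure.
    demo_rouge = rouge_scorer.RougeScorer(["rouge2", "rougeL"], use_stemmer=True)
    rouge_result = demo_rouge.score(reference, prediction)
    print("rougeL fmeasure:", rouge_result["rougeL"].fmeasure)

    # BLEURT: scores batches of (reference, candidate) pairs and returns a
    # list of floats, one per pair.
    demo_bleurt = score.BleurtScorer("path/to/bleurt/checkpoint")  # placeholder path
    print("bleurt:", demo_bleurt.score(references=[reference], candidates=[prediction])[0])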