
e2e+web-nlg+wiki-lingua

xuyifanbupt 2 years ago
parent commit 0c5621329f

+ 11 - 0
tasks/clean-e2e-nlg/clean-e2e-nlg.yaml

@@ -0,0 +1,11 @@
+name: "e2eNLGEgen"
+type: "gen"
+path: "clean-e2e-nlg/"
+module: "tasks.clean-e2e-nlg.task.E2E"
+file-pattern:
+  test: "**/test.jsonl"
+
+num_beams: 16
+max_gen_length: 64
+use_task_mask: true
+micro-batch-size: 16
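
The `file-pattern` glob points the harness at `test.jsonl` files under `clean-e2e-nlg/`. The record schema is defined by the harness's generation dataset loader, not by this commit; a hypothetical E2E-style record (the field names here are assumptions) could be produced like this:

import json

# Hypothetical record layout; the real field names come from the
# evaluation harness's dataset loader, not from this commit.
record = {
    "inputs_pretokenized": "name[The Eagle], eatType[coffee shop], priceRange[cheap]",
    "targets_pretokenized": ["The Eagle is a cheap coffee shop."],
}
print(json.dumps(record))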

+ 57 - 0
tasks/clean-e2e-nlg/task.py

@@ -0,0 +1,57 @@
+import os
+from collections import defaultdict
+from typing import Dict, List, Tuple
+
+import numpy as np
+from tqdm import tqdm
+
+from rouge_score import rouge_scorer
+from bleurt import score
+
+from evaluation import GenerationTask
+from evaluation.utils import print_rank_0
+
+
+class E2E(GenerationTask):
+    def __init__(self, model, tokenizer, config_path):
+        super().__init__(model, tokenizer, config_path)
+        # Placeholder: point this at a downloaded BLEURT checkpoint.
+        self.bleurt_checkpoint = "BLEURT-CHECKPOINT PATH"
+
+    def E2EMetric(self, predictions, examples):
+        metrics_dict = defaultdict(list)
+        # Let TensorFlow (used by BLEURT) grow GPU memory on demand.
+        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
+        scorer_rouge = rouge_scorer.RougeScorer(["rouge2", "rougeL"], use_stemmer=True)
+        scorer_bleurt = score.BleurtScorer(self.bleurt_checkpoint)
+        for text, target in tqdm(zip(predictions, examples)):
+            text_de = self.tokenizer.detokenize(text)
+            target_de = self.tokenizer.detokenize(target["targets"][0])
+
+            # RougeScorer.score expects (target, prediction) in that order.
+            scores_rouge = scorer_rouge.score(target_de, text_de)
+            scores_bleurt = scorer_bleurt.score(references=[target_de], candidates=[text_de])
+            for rouge_type in ("rouge2", "rougeL"):
+                metrics_dict[f"{rouge_type}_precision"].append(scores_rouge[rouge_type].precision)
+                metrics_dict[f"{rouge_type}_recall"].append(scores_rouge[rouge_type].recall)
+                metrics_dict[f"{rouge_type}_fmeasure"].append(scores_rouge[rouge_type].fmeasure)
+            metrics_dict["bleurt"].append(scores_bleurt[0])
+
+        return metrics_dict
+
+    @property
+    def metrics(self):
+        return {"e2e": self.E2EMetric}
+
+    def predict_single_batch(self, batch) -> List[List[int]]:
+        return self.model.generate_text(batch, self.strategy, return_all_beams=False)
+
+    def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
+        pass
+
+    def report_group_metrics(self, group_name, result_dict_group: Dict[str, Tuple[Dict[str, float], int]], level=1):
+        # Average each per-example score list over the whole group.
+        for result_dict, _count in result_dict_group.values():
+            for result in result_dict.values():
+                for key, values in result.items():
+                    print_rank_0(f"{key}: {np.mean(values)}")

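The metric code above can be exercised standalone; a minimal sketch, assuming rouge_score and bleurt are installed and a BLEURT checkpoint (e.g. BLEURT-20) has been downloaded to a local path (the path and example strings below are placeholders):

import os

from rouge_score import rouge_scorer
from bleurt import score

# Same setting as in the task: let TensorFlow grow GPU memory on demand.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

prediction = "The Eagle is a cheap coffee shop."           # placeholder
reference = "The Eagle is a coffee shop with cheap food."  # placeholder

# RougeScorer.score takes (target, prediction) in that order.
rouge = rouge_scorer.RougeScorer(["rouge2", "rougeL"], use_stemmer=True)
rouge_scores = rouge.score(reference, prediction)
print("rougeL f-measure:", rouge_scores["rougeL"].fmeasure)

# BleurtScorer.score returns one float per (reference, candidate) pair.
bleurt = score.BleurtScorer("path/to/BLEURT-20")  # placeholder checkpoint path
print("bleurt:", bleurt.score(references=[reference], candidates=[prediction])[0])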
+ 57 - 0
tasks/web-nlg/task.py

@@ -0,0 +1,57 @@
+import os
+from collections import defaultdict
+from typing import Dict, List, Tuple
+
+import numpy as np
+from tqdm import tqdm
+
+from rouge_score import rouge_scorer
+from bleurt import score
+
+from evaluation import GenerationTask
+from evaluation.utils import print_rank_0
+
+
+class WEB(GenerationTask):
+    def __init__(self, model, tokenizer, config_path):
+        super().__init__(model, tokenizer, config_path)
+        # Placeholder: point this at a downloaded BLEURT checkpoint.
+        self.bleurt_checkpoint = "BLEURT-CHECKPOINT PATH"
+
+    def WEBMetric(self, predictions, examples):
+        metrics_dict = defaultdict(list)
+        # Let TensorFlow (used by BLEURT) grow GPU memory on demand.
+        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
+        scorer_rouge = rouge_scorer.RougeScorer(["rouge2", "rougeL"], use_stemmer=True)
+        scorer_bleurt = score.BleurtScorer(self.bleurt_checkpoint)
+        for text, target in tqdm(zip(predictions, examples)):
+            text_de = self.tokenizer.detokenize(text)
+            target_de = self.tokenizer.detokenize(target["targets"][0])
+
+            # RougeScorer.score expects (target, prediction) in that order.
+            scores_rouge = scorer_rouge.score(target_de, text_de)
+            scores_bleurt = scorer_bleurt.score(references=[target_de], candidates=[text_de])
+            for rouge_type in ("rouge2", "rougeL"):
+                metrics_dict[f"{rouge_type}_precision"].append(scores_rouge[rouge_type].precision)
+                metrics_dict[f"{rouge_type}_recall"].append(scores_rouge[rouge_type].recall)
+                metrics_dict[f"{rouge_type}_fmeasure"].append(scores_rouge[rouge_type].fmeasure)
+            metrics_dict["bleurt"].append(scores_bleurt[0])
+
+        return metrics_dict
+
+    @property
+    def metrics(self):
+        return {"web-nlg": self.WEBMetric}
+
+    def predict_single_batch(self, batch) -> List[List[int]]:
+        return self.model.generate_text(batch, self.strategy, return_all_beams=False)
+
+    def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
+        pass
+
+    def report_group_metrics(self, group_name, result_dict_group: Dict[str, Tuple[Dict[str, float], int]], level=1):
+        # Average each per-example score list over the whole group.
+        for result_dict, _count in result_dict_group.values():
+            for result in result_dict.values():
+                for key, values in result.items():
+                    print_rank_0(f"{key}: {np.mean(values)}")

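report_group_metrics above reduces each per-example score list to an unweighted mean; a minimal standalone sketch of that reduction, with dummy values:

from collections import defaultdict

import numpy as np

metrics = defaultdict(list)
metrics["rougeL_fmeasure"] += [0.42, 0.58]  # dummy per-example scores
metrics["bleurt"] += [0.31, 0.47]           # dummy per-example scores
for key, values in metrics.items():
    print(f"{key}: {np.mean(values):.4f}")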
+ 11 - 0
tasks/web-nlg/web-nlg.yaml

@@ -0,0 +1,11 @@
+name: "web-nlg"
+type: "gen"
+path: "web-nlg/"
+module: "tasks.web-nlg.task.WEB"
+file-pattern:
+  test: "**/test.jsonl"
+
+num_beams: 16
+max_gen_length: 64
+use_task_mask: true
+micro-batch-size: 16

+ 57 - 0
tasks/wiki-lingua/task.py

@@ -0,0 +1,57 @@
+import os
+from collections import defaultdict
+from typing import Dict, List, Tuple
+
+import numpy as np
+from tqdm import tqdm
+
+from rouge_score import rouge_scorer
+from bleurt import score
+
+from evaluation import GenerationTask
+from evaluation.utils import print_rank_0
+
+
+class WIKI(GenerationTask):
+    def __init__(self, model, tokenizer, config_path):
+        super().__init__(model, tokenizer, config_path)
+        # Placeholder: point this at a downloaded BLEURT checkpoint.
+        self.bleurt_checkpoint = "BLEURT-CHECKPOINT PATH"
+
+    def WIKIMetric(self, predictions, examples):
+        metrics_dict = defaultdict(list)
+        # Let TensorFlow (used by BLEURT) grow GPU memory on demand.
+        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
+        scorer_rouge = rouge_scorer.RougeScorer(["rouge2", "rougeL"], use_stemmer=True)
+        scorer_bleurt = score.BleurtScorer(self.bleurt_checkpoint)
+        for text, target in tqdm(zip(predictions, examples)):
+            text_de = self.tokenizer.detokenize(text)
+            target_de = self.tokenizer.detokenize(target["targets"][0])
+
+            # RougeScorer.score expects (target, prediction) in that order.
+            scores_rouge = scorer_rouge.score(target_de, text_de)
+            scores_bleurt = scorer_bleurt.score(references=[target_de], candidates=[text_de])
+            for rouge_type in ("rouge2", "rougeL"):
+                metrics_dict[f"{rouge_type}_precision"].append(scores_rouge[rouge_type].precision)
+                metrics_dict[f"{rouge_type}_recall"].append(scores_rouge[rouge_type].recall)
+                metrics_dict[f"{rouge_type}_fmeasure"].append(scores_rouge[rouge_type].fmeasure)
+            metrics_dict["bleurt"].append(scores_bleurt[0])
+
+        return metrics_dict
+
+    @property
+    def metrics(self):
+        return {"wiki-lingua": self.WIKIMetric}
+
+    def predict_single_batch(self, batch) -> List[List[int]]:
+        return self.model.generate_text(batch, self.strategy, return_all_beams=False)
+
+    def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
+        pass
+
+    def report_group_metrics(self, group_name, result_dict_group: Dict[str, Tuple[Dict[str, float], int]], level=1):
+        # Average each per-example score list over the whole group.
+        for result_dict, _count in result_dict_group.values():
+            for result in result_dict.values():
+                for key, values in result.items():
+                    print_rank_0(f"{key}: {np.mean(values)}")

+ 11 - 0
tasks/wiki-lingua/wiki-lingua.yaml

@@ -0,0 +1,11 @@
+name: "wiki-lingua_en2en"
+type: "gen"
+path: "wiki-lingua/"
+module: "tasks.wiki-lingua.task.WIKI"
+file-pattern:
+  test: "**/test.jsonl"
+
+num_beams: 16
+max_gen_length: 64
+use_task_mask: true
+micro-batch-size: 8
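
The three task classes added in this commit differ only in their names and metric keys; a possible consolidation into one shared base class (a sketch, not part of this commit; Data2TextTask is a hypothetical name):

from collections import defaultdict
from typing import Dict, List

from tqdm import tqdm
from rouge_score import rouge_scorer
from bleurt import score

from evaluation import GenerationTask


class Data2TextTask(GenerationTask):
    """Hypothetical shared base; subclasses only choose a metric key."""

    metric_name = "gen"
    bleurt_checkpoint = "BLEURT-CHECKPOINT PATH"  # placeholder

    def compute_metrics(self, predictions, examples) -> Dict[str, List[float]]:
        rouge = rouge_scorer.RougeScorer(["rouge2", "rougeL"], use_stemmer=True)
        bleurt = score.BleurtScorer(self.bleurt_checkpoint)
        metrics = defaultdict(list)
        for pred, example in tqdm(zip(predictions, examples)):
            candidate = self.tokenizer.detokenize(pred)
            reference = self.tokenizer.detokenize(example["targets"][0])
            scores = rouge.score(reference, candidate)  # (target, prediction)
            for name in ("rouge2", "rougeL"):
                metrics[f"{name}_fmeasure"].append(scores[name].fmeasure)
            metrics["bleurt"].append(
                bleurt.score(references=[reference], candidates=[candidate])[0]
            )
        return metrics

    @property
    def metrics(self):
        return {self.metric_name: self.compute_metrics}


class E2E(Data2TextTask):
    metric_name = "e2e"


class WEB(Data2TextTask):
    metric_name = "web-nlg"


class WIKI(Data2TextTask):
    metric_name = "wiki-lingua"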