2
0
Xiao Xia, 2 years ago
parent
commit
4250ca1b44

+ 1 - 0
.gitignore

@@ -3,3 +3,4 @@ __pycache__
 samples
 .DS_Store
 .idea
+outputs

+ 10 - 0
tasks/bbh/bbh-cot.yaml

@@ -0,0 +1,10 @@
# Evaluation config for Big-bench Hard with chain-of-thought prompting.
name: 'Big-bench Hard (CoT)'
type: 'gen'  # generation-style task (free-form text output)
path: 'bbh_cot/cot'  # dataset subdirectory — presumably relative to a data root; confirm against the loader
module: 'tasks.bbh.task.BBHGeneration'  # task implementation class
file-pattern:
  test: "**/*.json*"  # glob matching test-split data files
micro-batch-size: 8
max_gen_length: 400  # cap on generated tokens per sample — TODO confirm unit
save-prediction: True  # write per-file predictions (see BBHGeneration.save_prediction_to_file)
use_task_mask: True  # NOTE(review): semantics defined by the model/config consumer — verify

+ 10 - 0
tasks/bbh/bbh-generation.yaml

@@ -0,0 +1,10 @@
# Evaluation config for Big-bench Hard, direct-answer generation (no CoT).
name: 'Big-bench Hard Generation'
type: 'gen'  # generation-style task (free-form text output)
path: 'bbh_cot/direct/generation'  # dataset subdirectory — presumably relative to a data root; confirm against the loader
module: 'tasks.bbh.task.BBHGeneration'  # task implementation class
file-pattern:
  test: "**/*.json*"  # glob matching test-split data files
micro-batch-size: 8
max_gen_length: 400  # cap on generated tokens per sample — TODO confirm unit
save-prediction: True  # write per-file predictions (see BBHGeneration.save_prediction_to_file)
use_task_mask: True  # NOTE(review): semantics defined by the model/config consumer — verify

+ 8 - 0
tasks/bbh/bbh-multichoice.yaml

@@ -0,0 +1,8 @@
# Evaluation config for Big-bench Hard, direct multiple-choice scoring.
name: 'Big-bench Hard Multichoice'
type: 'mul'  # multiple-choice task (ranked options, no free-form generation)
path: 'bbh_cot/direct/multichoice'  # dataset subdirectory — presumably relative to a data root; confirm against the loader
module: 'tasks.bbh.task.BBHMultichoice'  # task implementation class
file-pattern:
  test: "**/*.json*"  # glob matching test-split data files
micro-batch-size: 8
use_task_mask: True  # NOTE(review): semantics defined by the model/config consumer — verify

+ 35 - 0
tasks/bbh/task.py

@@ -0,0 +1,35 @@
+import os
+import re
+
+from datetime import datetime
+from functools import partial
+
+from evaluation import qa_evaluate, GenerationTask
+
def extract_answer(prediction):
    """Extract the final answer from a model generation.

    Finds the first occurrence of "the answer is <answer>." (either
    capitalization) and returns ``<answer>`` without the trailing period.
    Returns ``""`` when no such sentence is present.
    """
    # The closing period must be followed by a newline OR end-of-string.
    # (A lookahead of only "\.\n" would miss an answer sentence that ends
    # the generation without a trailing newline.)
    match = re.search(r"(?:the|The) answer is (.*?)\.(?:\n|$)", prediction)
    return match.group(1) if match else ""
+
def exact_match_score(prediction, ground_truth):
    """Return True iff the answer extracted from *prediction* equals *ground_truth*."""
    extracted = extract_answer(prediction)
    return extracted == ground_truth
+
class BBHGeneration(GenerationTask):
    """BBH generation task: scores predictions by exact match on the
    extracted "the answer is ..." sentence, and optionally saves the raw
    detokenized generations to a per-run output directory."""

    @property
    def metrics(self):
        # qa_evaluate is driven by exact_match_score, which compares the
        # extracted answer string against the gold answer.
        return {"Accuracy": partial(qa_evaluate, metric=exact_match_score)}

    def __init__(self, model, tokenizer, config):
        super(BBHGeneration, self).__init__(model, tokenizer, config)
        # Captured once at construction so every file written during this
        # run lands under the same outputs_<timestamp> directory.
        self.start_time = datetime.now()

    def save_prediction_to_file(self, file, prediction, data):
        """Write detokenized predictions, one per line, to
        outputs_<timestamp>/<task-name>/<file>.predict."""
        # strftime keeps the directory name filesystem-safe: str(datetime)
        # contains spaces and colons, and colons are invalid in Windows paths.
        timestamp = self.start_time.strftime("%Y-%m-%d_%H-%M-%S")
        filename = os.path.join(f"outputs_{timestamp}", self.config.name, f"{file}.predict")
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        # Named "fout" so the handle does not shadow the `file` parameter.
        with open(filename, "w") as fout:
            for item in prediction:
                fout.write(str([self.tokenizer.detokenize(item)]) + "\n")