Browse Source

Add prompt type

duzx16 2 years ago
parent
commit
69a3657011
1 changed file with 20 additions and 6 deletions
  1. 20 6
      tasks/cot/task.py

+ 20 - 6
tasks/cot/task.py

@@ -12,6 +12,7 @@ from dataclasses import dataclass
 class ChainOfThoughtConfig(GenerationTaskConfig):
 class ChainOfThoughtConfig(GenerationTaskConfig):
     prompt_path: str = None
     prompt_path: str = None
     chain_of_thought: bool = True
     chain_of_thought: bool = True
+    prompt_type: str = None
 
 
 
 
 def read_examples(prompt_path):
 def read_examples(prompt_path):
@@ -33,14 +34,20 @@ def read_examples(prompt_path):
     return examples
     return examples
 
 
 
 
-def build_prompt(examples, task_name, chain_of_thought=True):
+def build_prompt(examples, task_name, chain_of_thought=True, prompt_type=None):
     prompts = []
     prompts = []
-    for item in examples:
+    for i, item in enumerate(examples):
         question, answer = item["question"], item["answer"]
         question, answer = item["question"], item["answer"]
         if not chain_of_thought:
         if not chain_of_thought:
             answer = extract_answer(answer, task_name)
             answer = extract_answer(answer, task_name)
-        prompts.append(f"Question: {question} Answer: {answer}")
-    prompt = " ".join(prompts)
+        if prompt_type == "number":
+            prompts.append(f"{i+1}. Question: {question} Answer: {answer}")
+        else:
+            prompts.append(f"Question: {question} Answer: {answer}")
+    if prompt_type == "return":
+        prompt = " <n>".join(prompts)
+    else:
+        prompt = " ".join(prompts)
     return prompt
     return prompt
 
 
 
 
@@ -73,17 +80,24 @@ def extract_answer(prediction, task_name, chain_of_thought=True):
 
 
 
 
 class ChainOfThoughtDataset(GenerationTaskDataset):
 class ChainOfThoughtDataset(GenerationTaskDataset):
+    config: ChainOfThoughtConfig
 
 
     def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
     def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
         self.labeled_examples = read_examples(config.prompt_path)
         self.labeled_examples = read_examples(config.prompt_path)
-        self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought)
+        self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought,
+                                           prompt_type=config.prompt_type)
         print_rank_0(self.labeled_prompt)
         print_rank_0(self.labeled_prompt)
         self.printed_count = 0
         self.printed_count = 0
         super().__init__(path, config)
         super().__init__(path, config)
 
 
     def process_single_item(self, item, **kwargs):
     def process_single_item(self, item, **kwargs):
         question, targets = item["question"], item["targets"]
         question, targets = item["question"], item["targets"]
-        text = self.labeled_prompt + f" Question: {question} Answer:"
+        if self.config.prompt_type == "number":
+            text = self.labeled_prompt + f" {len(self.labeled_examples) + 1}. Question: {question} Answer:"
+        elif self.config.prompt_type == "return":
+            text = self.labeled_prompt + f" <n>Question: {question} Answer:"
+        else:
+            text = self.labeled_prompt + f" Question: {question} Answer:"
         text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
         text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
         if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
         if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
             text_length = self.config.max_seq_length - self.config.max_gen_length - 2
             text_length = self.config.max_seq_length - self.config.max_gen_length - 2