|
@@ -12,6 +12,7 @@ from dataclasses import dataclass
|
|
class ChainOfThoughtConfig(GenerationTaskConfig):
|
|
class ChainOfThoughtConfig(GenerationTaskConfig):
|
|
prompt_path: str = None
|
|
prompt_path: str = None
|
|
chain_of_thought: bool = True
|
|
chain_of_thought: bool = True
|
|
|
|
+ prompt_type: str = None
|
|
|
|
|
|
|
|
|
|
def read_examples(prompt_path):
|
|
def read_examples(prompt_path):
|
|
@@ -33,14 +34,20 @@ def read_examples(prompt_path):
|
|
return examples
|
|
return examples
|
|
|
|
|
|
|
|
|
|
def build_prompt(examples, task_name, chain_of_thought=True, prompt_type=None):
    """Concatenate labeled examples into a single few-shot prompt string.

    Args:
        examples: Iterable of mappings, each providing "question" and
            "answer" keys.
        task_name: Forwarded to ``extract_answer`` when ``chain_of_thought``
            is false; not used otherwise.
        chain_of_thought: When true, each example's answer is used verbatim.
            When false, it is replaced by
            ``extract_answer(answer, task_name)``.
        prompt_type: Layout selector. ``"number"`` prefixes every example
            with its 1-based position ("1. ", "2. ", ...). ``"return"``
            joins the examples with ``" <n>"`` instead of a single space.
            Any other value, including ``None``, gives the plain layout.

    Returns:
        The joined prompt string; an empty string when ``examples`` is empty.
    """
    numbered = prompt_type == "number"
    # "<n>" is emitted literally. NOTE(review): presumably the tokenizer's
    # newline marker -- confirm against the tokenizer in use.
    separator = " <n>" if prompt_type == "return" else " "

    prompts = []
    for index, item in enumerate(examples, start=1):
        question, answer = item["question"], item["answer"]
        if not chain_of_thought:
            answer = extract_answer(answer, task_name)
        prefix = f"{index}. " if numbered else ""
        prompts.append(f"{prefix}Question: {question} Answer: {answer}")
    return separator.join(prompts)
|
|
|
|
|
|
|
|
|
|
@@ -73,17 +80,24 @@ def extract_answer(prediction, task_name, chain_of_thought=True):
|
|
|
|
|
|
|
|
|
|
class ChainOfThoughtDataset(GenerationTaskDataset):
|
|
class ChainOfThoughtDataset(GenerationTaskDataset):
|
|
|
|
+ config: ChainOfThoughtConfig
|
|
|
|
|
|
def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
|
|
def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
|
|
self.labeled_examples = read_examples(config.prompt_path)
|
|
self.labeled_examples = read_examples(config.prompt_path)
|
|
- self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought)
|
|
|
|
|
|
+ self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought,
|
|
|
|
+ prompt_type=config.prompt_type)
|
|
print_rank_0(self.labeled_prompt)
|
|
print_rank_0(self.labeled_prompt)
|
|
self.printed_count = 0
|
|
self.printed_count = 0
|
|
super().__init__(path, config)
|
|
super().__init__(path, config)
|
|
|
|
|
|
def process_single_item(self, item, **kwargs):
|
|
def process_single_item(self, item, **kwargs):
|
|
question, targets = item["question"], item["targets"]
|
|
question, targets = item["question"], item["targets"]
|
|
- text = self.labeled_prompt + f" Question: {question} Answer:"
|
|
|
|
|
|
+ if self.config.prompt_type == "number":
|
|
|
|
+ text = self.labeled_prompt + f" {len(self.labeled_examples) + 1}. Question: {question} Answer:"
|
|
|
|
+ elif self.config.prompt_type == "return":
|
|
|
|
+ text = self.labeled_prompt + f" <n>Question: {question} Answer:"
|
|
|
|
+ else:
|
|
|
|
+ text = self.labeled_prompt + f" Question: {question} Answer:"
|
|
text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
|
|
text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
|
|
if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
|
|
if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
|
|
text_length = self.config.max_seq_length - self.config.max_gen_length - 2
|
|
text_length = self.config.max_seq_length - self.config.max_gen_length - 2
|