Browse Source

Merge pull request #40 from duzx16/cot

Cot
Aohan Zeng 2 years ago
parent
commit
0938e1dc10

+ 0 - 1
generation/strategies.py

@@ -16,7 +16,6 @@ class BaseStrategy:
             end_tokens = []
         self.end_tokens = end_tokens
         self.deterministic = deterministic
-        print(self.deterministic)
         self._is_done = np.zeros(self.batch_size, dtype=np.bool)
 
     @property

+ 14 - 0
tasks/cot/coinflip.yaml

@@ -0,0 +1,14 @@
+name: 'coinflip'
+type: 'gen'
+module: "tasks.cot.task.ChainOfThoughtTask"
+path: 'symbolic'
+file-pattern:
+  test: "coinflip.jsonl"
+sampling_strategy: "BaseStrategy"
+prompt_path: "tasks/cot/coinflip_prompt.txt"
+deterministic: true
+max_gen_length: 64
+use_task_mask: true
+save_prediction: true
+chain_of_thought: true
+micro_batch_size: 4

+ 16 - 0
tasks/cot/coinflip_prompt.txt

@@ -0,0 +1,16 @@
+Q: A coin is heads up. Ka flips the coin. Sherrie flips the coin. Is the coin still heads up?
+A: The coin was flipped by Ka and Sherrie. So the coin was flipped 2 times, which is an even number. The coin started heads up, so after an even number of flips, it will still be heads up. So the answer is yes.
+Q: A coin is heads up. Jamey flips the coin. Teressa flips the coin. Is the coin still heads up?
+A: The coin was flipped by Jamey and Teressa. So the coin was flipped 2 times, which is an even number. The coin started heads up, so after an even number of flips, it will still be heads up. So the answer is yes.
+Q: A coin is heads up. Maybelle flips the coin. Shalonda does not flip the coin. Is the coin still heads up?
+A: The coin was flipped by Maybelle. So the coin was flipped 1 time, which is an odd number. The coin started heads up, so after an odd number of flips, it will be tails up. So the answer is no.
+Q: A coin is heads up. Millicent does not flip the coin. Conception flips the coin. Is the coin still heads up?
+A: The coin was flipped by Conception. So the coin was flipped 1 time, which is an odd number. The coin started heads up, so after an odd number of flips, it will be tails up. So the answer is no.
+Q: A coin is heads up. Sal flips the coin. Raymond does not flip the coin. Is the coin still heads up?
+A: The coin was flipped by Sal. So the coin was flipped 1 time, which is an odd number. The coin started heads up, so after an odd number of flips, it will be tails up. So the answer is no.
+Q: A coin is heads up. Conception flips the coin. Kristian does not flip the coin. Is the coin still heads up?
+A: The coin was flipped by Conception. So the coin was flipped 1 time, which is an odd number. The coin started heads up, so after an odd number of flips, it will be tails up. So the answer is no.
+Q: A coin is heads up. Inga does not flip the coin. Elanor does not flip the coin. Is the coin still heads up?
+A: The coin was flipped by no one. So the coin was flipped O times. The coin started heads up, and it was not flipped, so it is still heads up. So the answer is yes.
+Q: A coin is heads up. Ryan flips the coin. Shaunda flips the coin. Is the coin still heads up?
+A: The coin was flipped by Ryan and Shaunda. So the coin was flipped 2 times, which is an even number. The coin started heads up, so after an even number of flips, it will still be heads up. So the answer is yes.

+ 1 - 0
tasks/cot/gsm8k.yaml

@@ -10,4 +10,5 @@ deterministic: true
 max_gen_length: 128
 use_task_mask: true
 save_prediction: true
+chain_of_thought: true
 micro_batch_size: 4

+ 1 - 1
tasks/cot/gsm8k_prompt.txt

@@ -13,4 +13,4 @@ A: There were originally 9 computers. For each of 4 days, 5 more computers were
 Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?
 A: Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33.
 Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
-A: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8
+A: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8.

+ 15 - 0
tasks/cot/lastletter.yaml

@@ -0,0 +1,15 @@
+name: 'lastletter'
+type: 'gen'
+module: "tasks.cot.task.ChainOfThoughtTask"
+path: 'symbolic'
+file-pattern:
+  test: "lastletter.jsonl"
+sampling_strategy: "BaseStrategy"
+prompt_path: "tasks/cot/lastletter_prompt.txt"
+deterministic: true
+unidirectional: true
+max_gen_length: 64
+use_task_mask: true
+save_prediction: true
+chain_of_thought: true
+micro_batch_size: 4

+ 8 - 0
tasks/cot/lastletter_prompt.txt

@@ -0,0 +1,8 @@
+Q: Take the last letters of the words in "Elon Musk" and concatenate them.
+A: The last letter of "Elon" is "n". The last letter of "Musk" is "k". Concatenating them is "nk". The answer is nk.
+Q: Take the last letters of the words in "Larry Page" and concatenate them.
+A: The last letter of "Larry" is "y". The last letter of "Page" is "e". Concatenating them is "ye". The answer is ye.
+Q: Take the last letters of the words in "Sergey Brin" and concatenate them.
+A: The last letter of "Sergey" is "y". The last letter of "Brin" is "n". Concatenating them is "yn". The answer is yn.
+Q: Take the last letters of the words in "Bill Gates" and concatenate them.
+A: The last letter of "Bill" is "l". The last letter of "Gates" is "s". Concatenating them is "ls". The answer is ls.

+ 14 - 0
tasks/cot/reverse.yaml

@@ -0,0 +1,14 @@
+name: 'reverse'
+type: 'gen'
+module: "tasks.cot.task.ChainOfThoughtTask"
+path: 'symbolic'
+file-pattern:
+  test: "reverse_5.jsonl"
+sampling_strategy: "BaseStrategy"
+prompt_path: "tasks/cot/reverse_prompt.txt"
+deterministic: true
+max_gen_length: 128
+use_task_mask: true
+save_prediction: true
+chain_of_thought: true
+micro_batch_size: 4

+ 16 - 0
tasks/cot/reverse_prompt.txt

@@ -0,0 +1,16 @@
+Q: Reverse the sequence "cigar, umbrella, key, gum, alarm".
+A: First is cigar. Second is umbrella. Third is key. Fourth is gum. Fifth is alarm. Now to reverse, change the order to: Fifth is alarm. Fourth is gum. Third is key. Second is umbrella. First is cigar. So the answer is "alarm, gum, key, umbrella, cigar"
+Q: Reverse the sequence "player, passport, umbrella, bottle, watch".
+A: First is player. Second is passport. Third is umbrella. Fourth is bottle. Fifth is watch. Now to reverse, change the order to: Fifth is watch. Fourth is bottle. Third is umbrella. Second is passport. First is player. So the answer is "watch, bottle, umbrella, passport, player"
+Q: Reverse the sequence "coin, postcard, case, pen, wallet".
+A: First is coin. Second is postcard. Third is case. Fourth is pen. Fifth is wallet. Now to reverse, change the order to: Fifth is wallet. Fourth is pen. Third is case. Second is postcard. First is coin. So the answer is "wallet, pen, case, postcard, coin".
+Q: Reverse the sequence "laptop, lipstick, pen, bin, clock".
+A: First is laptop. Second is lipstick. Third is pen. Fourth is bin. Fifth is clock. Now to reverse, change the order to: Fifth is clock. Fourth is bin. Third is pen. Second is lipstick. First is laptop. So the answer is "clock, bin, pen, lipstick, laptop"
+Q: Reverse the sequence "key, pen, screen, file, cigar".
+A: First is key. Second is pen. Third is screen. Fourth is file. Fifth is cigar. Now to reverse, change the order to: Fifth is cigar. Fourth is file. Third is screen. Second is pen. First is key. So the answer is "cigar, file, screen, pen, key".
+Q: Reverse the sequence "card, stamp, book, water, glasses"
+A: First is card. Second is stamp. Third is book. Fourth is water. Fifth is glasses. Now to reverse, change the order to: Fifth is glasses. Fourth is water. Third is book. Second is stamp. First is card. The answer is "glasses, water, book, stamp, card".
+Q: Reverse the sequence "clock, coin, bottle, head, postcard".
+A: First is clock. Second is coin. Third is bottle. Fourth is head. Fifth is postcard. Now to reverse, change the order to: Fifth is postcard. Fourth is head. Third is bottle. Second is coin. First is clock. So the answer is "postcard, head, bottle, coin, clock".
+Q: Reverse the sequence "battery, glasses, lighter, water, scissors".
+A: First is battery. Second is glasses. Third is lighter. Fourth is water. Fifth is scissors. Now to reverse, change the order to: Fifth is scissors. Fourth is water. Third is lighter. Second is glasses. First is battery. So the answer is "scissors, water, lighter, glasses, battery".

+ 16 - 0
tasks/cot/sports.yaml

@@ -0,0 +1,16 @@
+name: 'sports'
+type: 'gen'
+module: "tasks.cot.task.ChainOfThoughtTask"
+path: 'commonsense'
+file-pattern:
+  test: "sports.json"
+sampling_strategy: "BaseStrategy"
+prompt_path: "tasks/cot/sports_prompt.txt"
+deterministic: true
+unidirectional: true
+max_gen_length: 128
+use_task_mask: true
+save_prediction: true
+chain_of_thought: true
+prompt_type: 'number'
+micro_batch_size: 4

+ 16 - 0
tasks/cot/sports_prompt.txt

@@ -0,0 +1,16 @@
+Q: Is the following sentence plausible? "Kyle Palmieri was called for slashing."
+A: Kyle Palmieri is a hockey player. Being called for slashing is part of hockey. So the answer is yes.
+Q: Is the following sentence plausible? "Joao Moutinho caught the screen pass in the NFC championship."
+A: Joao Moutinho is a soccer player. The NFC championship is part of American football, not soccer. So the answer is no.
+Q: Is the following sentence plausible? "Carson Wentz set the pick and roll."
+A: Carson Wentz is an American football player. Pick and roll is part of basketball, not football. So the answer is no.
+Q: Is the following sentence plausible? "Jonas Valanciunas beat the buzzer."
+A: Jonas Valanciunas is a basketball player. Beating the buzzer is part of basketball. So the answer is yes.
+Q: Is the following sentence plausible? "Jamel Murray was perfect from the line."
+A: Jamal Murray is a basketball player. Being perfect from the line is part of basketball. So the answer is yes.
+Q: Is the following sentence plausible? "Sam Darnold passed the puck."
+A: Sam Darnold is an American football player. Passing the puck is part of hockey, not American football. So the answer is no.
+Q: Is the following sentence plausible? "Draymond Green threw a touchdown."
+A: Draymond Green is a basketball player. Throwing a touchdown is part of football, not basketball. So the answer is no.
+Q: Is the following sentence plausible? "Malcolm Brogdon banked the shot in."
+A: Malcolm Brogdon is a basketball player. Banking the shot in is part of basketball. So the answer is yes.

+ 105 - 14
tasks/cot/task.py

@@ -11,6 +11,8 @@ from dataclasses import dataclass
 @dataclass
 class ChainOfThoughtConfig(GenerationTaskConfig):
     prompt_path: str = None
+    chain_of_thought: bool = True
+    prompt_type: str = None
 
 
 def read_examples(prompt_path):
@@ -32,19 +34,64 @@ def read_examples(prompt_path):
     return examples
 
 
-def build_prompt(examples):
+def build_prompt(examples, task_name, chain_of_thought=True, prompt_type=None):
     prompts = []
-    for item in examples:
+    for i, item in enumerate(examples):
         question, answer = item["question"], item["answer"]
-        prompts.append(f"Question: {question} Answer: {answer}")
-    prompt = " ".join(prompts)
+        if not chain_of_thought:
+            answer = extract_answer(answer, task_name)
+        if prompt_type == "number":
+            prompts.append(f"{i+1}. Question: {question} Answer: {answer}")
+        else:
+            prompts.append(f"Question: {question} Answer: {answer}")
+    if prompt_type == "return":
+        prompt = " <n>".join(prompts)
+    else:
+        prompt = " ".join(prompts)
     return prompt
 
 
-def extract_answer(prediction, task_name):
-    if task_name == "gsm8k":
+def extract_answer(prediction, task_name, chain_of_thought=True):
+    if task_name.startswith("gsm8k"):
+        prediction = prediction.lower()
+        if chain_of_thought:
+            pattern = r'(?<=the answer is )\d+'
+        else:
+            pattern = r'\d+'
+        match = re.search(pattern, prediction)
+        if match:
+            answer = match.group(0)
+        else:
+            answer = ""
+    elif task_name.startswith("sports") or task_name.startswith("coinflip"):
+        prediction = prediction.lower()
+        if chain_of_thought:
+            pattern = r'(?<=the answer is )(yes|no)'
+        else:
+            pattern = r'yes|no'
+        match = re.search(pattern, prediction)
+        if match:
+            answer = match.group(0)
+        else:
+            answer = "no"
+    elif task_name.startswith("lastletter"):
+        prediction = prediction.lower()
+        if chain_of_thought:
+            pattern = r'(?<=the answer is )[a-z]+'
+        else:
+            pattern = r'[a-z]+'
+        match = re.search(pattern, prediction)
+        if match:
+            answer = match.group(0)
+        else:
+            answer = ""
+    elif task_name.startswith("reverse"):
         prediction = prediction.lower()
-        match = re.search(r'(?<=the answer is )\d+', prediction)
+        if chain_of_thought:
+            pattern = r'(?<=the answer is ")[a-z|,| ]+'
+        else:
+            pattern = r'[a-z|,| ]+'
+        match = re.search(pattern, prediction)
         if match:
             answer = match.group(0)
         else:
@@ -55,18 +102,25 @@ def extract_answer(prediction, task_name):
 
 
 class ChainOfThoughtDataset(GenerationTaskDataset):
+    config: ChainOfThoughtConfig
 
     def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
         self.labeled_examples = read_examples(config.prompt_path)
-        self.labeled_prompt = build_prompt(self.labeled_examples)
+        self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought,
+                                           prompt_type=config.prompt_type)
         print_rank_0(self.labeled_prompt)
         self.printed_count = 0
         super().__init__(path, config)
+        print_rank_0(len(self.tokenizer.tokenize(self.labeled_prompt)))
 
     def process_single_item(self, item, **kwargs):
-        question = item["question"]
-        targets = item["answer"].split("####")[1].strip()
-        text = self.labeled_prompt + f" Question: {question} Answer:"
+        question, targets = item["question"], item["targets"]
+        if self.config.prompt_type == "number":
+            text = self.labeled_prompt + f" {len(self.labeled_examples) + 1}. Question: {question} Answer:"
+        elif self.config.prompt_type == "return":
+            text = self.labeled_prompt + f" <n>Question: {question} Answer:"
+        else:
+            text = self.labeled_prompt + f" Question: {question} Answer:"
         text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
         if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
             text_length = self.config.max_seq_length - self.config.max_gen_length - 2
@@ -77,6 +131,34 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
         return [{"text": text, "targets": targets, **kwargs}]
 
 
+class GSM8KDataset(ChainOfThoughtDataset):
+    def process_single_item(self, item, **kwargs):
+        item["targets"] = item["answer"].split("####")[1].strip()
+        return super().process_single_item(item, **kwargs)
+
+
+class SportsDataset(ChainOfThoughtDataset):
+    def process_single_file(self, path):
+        with open(path) as file:
+            dataset = json.load(file)
+        for item in dataset["examples"]:
+            sentence = item["input"]
+            item["question"] = f'Is the following sentence plausible? \"{sentence}.\"'
+            if item["target_scores"]["plausible"] == 1:
+                item["targets"] = "yes"
+            else:
+                item["targets"] = "no"
+            self.data.extend(self.process_single_item(item))
+
+
+class LastLetterDataset(ChainOfThoughtDataset):
+    def process_single_item(self, item, **kwargs):
+        first_name, last_name = item["first_name"], item["last_name"]
+        question = f'Take the last letters of the words in "{first_name} {last_name}" and concatenate them.'
+        item["question"] = question
+        return super().process_single_item(item, **kwargs)
+
+
 class ChainOfThoughtTask(GenerationTask):
     config: ChainOfThoughtConfig
 
@@ -94,19 +176,28 @@ class ChainOfThoughtTask(GenerationTask):
         assert len(predictions) == len(examples)
         for prediction, example in zip(predictions, examples):
             output = self.tokenizer.detokenize(prediction)
-            prediction = extract_answer(output, self.config.name).strip()
+            prediction = extract_answer(output, self.config.name, self.config.chain_of_thought).strip()
             target = self.tokenizer.detokenize(example["targets"]).strip()
             count += prediction == target
         return count * 100.0 / num_predictions
 
     def build_dataset(self, relative_path, split):
-        return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)
+        if self.config.name.startswith("gsm8k"):
+            return GSM8KDataset(os.path.join(self.config.path, relative_path), self.config)
+        elif self.config.name.startswith("sports"):
+            return SportsDataset(os.path.join(self.config.path, relative_path), self.config)
+        elif self.config.name.startswith("lastletter"):
+            return LastLetterDataset(os.path.join(self.config.path, relative_path), self.config)
+        elif self.config.name.startswith("coinflip") or self.config.name.startswith("reverse"):
+            return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)
+        else:
+            raise NotImplementedError
 
     def save_prediction_to_file(self, file, predictions, data):
         results = []
         for output, item in zip(predictions, data):
             output = self.tokenizer.detokenize(output)
-            prediction = extract_answer(output, self.config.name)
+            prediction = extract_answer(output, self.config.name, self.config.chain_of_thought)
             target = self.tokenizer.detokenize(item["targets"])
             results.append({"output": output, "prediction": prediction, "answer": target})
         file_name = file.split(".")[0]