Browse Source

Merge branch 'cot'

Sengxian 2 years ago
parent
commit
a22632be22

+ 2 - 0
evaluation/configs.py

@@ -26,6 +26,7 @@ class BaseConfig(YAMLWizard):
     unidirectional: bool = False  # Whether to use unidirectional attention
     max_seq_length: int = 2048  # Max sequence length
     file_pattern: str | Dict[str, str] = "**/*.json*"  # Organize data file in groups
+    save_prediction: bool = False
 
     micro_batch_size: int = 1  # 'gen' task only support mbs = 1 for now
 
@@ -49,6 +50,7 @@ class GenerationTaskConfig(BaseConfig):
     no_repeat_ngram_size: int = 3
     min_gen_length: int = 0
     max_gen_length: int = 128
+    end_tokens: List[str] = field(default_factory=lambda: [])
 
 
 @dataclass

+ 50 - 51
evaluation/dataset.py

@@ -12,7 +12,6 @@ from itertools import accumulate
 from bisect import bisect_right
 
 from SwissArmyTransformer import get_tokenizer
-from SwissArmyTransformer.mpu import get_model_parallel_rank
 
 from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig, LanguageModelTaskConfig
 from .utils import get_tokenized_input
@@ -58,7 +57,6 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
     def has_collate_fn(self) -> bool:
         return False
 
-    @staticmethod
     def collate_fn(self, samples):
         return None
 
@@ -66,10 +64,10 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
         with open(os.path.join(path), "r", encoding="utf-8") as file:
             for line in file:
                 item = json.loads(line)
-                self.data.append(self.process_single_item(item))
+                self.data.extend(self.process_single_item(item))
 
     @abstractmethod
-    def process_single_item(self, item) -> dict:
+    def process_single_item(self, item, **kwargs) -> List[dict]:
         pass
 
     def __len__(self):
@@ -79,19 +77,18 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
 class GenerationTaskDataset(EvaluationDataset):
     config: GenerationTaskConfig
 
-    def process_single_item(self, item):
+    def process_single_item(self, item, **kwargs):
         text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
         if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
             text_length = self.config.max_seq_length - self.config.max_gen_length - 2
             text = text[len(text) - text_length : len(text)]
-        return {"text": text, "targets": targets}
+        return [{"text": text, "targets": targets, **kwargs}]
 
     @property
     def has_collate_fn(self) -> bool:
         return True
 
-    @staticmethod
-    def collate_fn(samples):
+    def collate_fn(self, samples):
         TILE = 32
         length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE
 
@@ -105,8 +102,8 @@ class GenerationTaskDataset(EvaluationDataset):
             token_batch.append(token)
             position_id_batch.append(position_id)
             attention_mask_batch.append(attention_mask)
-            context_length_batch.append(sample["context_length"])
-            target_position_id_batch.append(sample["target_position_id"])
+            context_length_batch.append(sample['context_length'])
+            target_position_id_batch.append(sample['target_position_id'])
         return {
             "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
             "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
@@ -141,7 +138,7 @@ class GenerationTaskDataset(EvaluationDataset):
         position_id = np.arange(0, context_length, dtype=np.int64)
         target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
         if not use_task_mask:
-            position_id[context_length - 1 :] = mask_position
+            position_id[context_length - 1:] = mask_position
             target_position_id[:] = mask_position
 
         attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
@@ -159,12 +156,13 @@ class GenerationTaskDataset(EvaluationDataset):
 
     def __getitem__(self, idx):
         item = self.data[idx]
-        return self.build_generation_sample(
+        sample = self.build_generation_sample(
             item["text"],
             max_gen_length=self.config.max_gen_length,
             use_task_mask=self.config.use_task_mask,
             unidirectional=self.config.unidirectional,
         )
+        return sample
 
 
 class MultiChoiceTaskDataset(EvaluationDataset):
@@ -178,15 +176,13 @@ class MultiChoiceTaskDataset(EvaluationDataset):
     def has_collate_fn(self) -> bool:
         return True
 
-    @staticmethod
-    def collate_fn(samples):
+    def collate_fn(self, samples):
         TILE = 32
         length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE
 
         token_batch, position_id_batch, attention_mask_batch = [], [], []
         choices_batch, choice_target_ids_batch = [], []
 
-        is_single_token = True
         for sample in samples:
             token, position_id, attention_mask = pad_batch(
                 sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
@@ -196,8 +192,6 @@ class MultiChoiceTaskDataset(EvaluationDataset):
             attention_mask_batch.append(attention_mask)
             choices_batch.append(sample["choices"])
             choice_target_ids_batch.append(sample["choice_target_ids"])
-            if isinstance(sample["choice_target_ids"], list):
-                is_single_token = False
 
         return {
             "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
@@ -205,10 +199,10 @@ class MultiChoiceTaskDataset(EvaluationDataset):
             "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
             "choices": choices_batch,
             "choice_target_ids": choice_target_ids_batch,
-            "is_single_token": is_single_token,
+            "is_single_token": self.is_single_token,
         }
 
-    def process_single_item(self, item):
+    def process_single_item(self, item, **kwargs):
         text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]
 
         tgt_seq_length = sum([len(choice) for choice in choices])
@@ -228,16 +222,16 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         if tgt_seq_length != 1:
             self.is_single_token = False
 
-        return {
+        return [{
             "text": text,
             "choices": choices,
             "label": label,
-        }
+            **kwargs
+        }]
 
     @staticmethod
-    def build_multiple_choice_sample(
-        text, choices, is_single_token, unified_multitask_encoding=False, use_task_mask=False
-    ):
+    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False,
+                                     unidirectional=False, use_task_mask=False):
         tokenizer = get_tokenizer()
 
         sop_id = tokenizer.get_command("sop")
@@ -250,48 +244,51 @@ class MultiChoiceTaskDataset(EvaluationDataset):
 
         blank_filling = mask_id in text
         if not blank_filling:
-            mask_position = len(token)
-            token = np.concatenate((token, [mask_id]))
-            target = np.concatenate((target, [mask_id]))
-            position_id = np.concatenate((position_id, [mask_position]))
+            if unidirectional:
+                assert use_task_mask
+                token = np.concatenate(([mask_id, sop_id], token[:-1]))
+                target = np.concatenate(([mask_id, sop_id], target[:-1]))
+                position_id = np.arange(len(token), dtype=np.int64)
+                mask_position = len(token)
+            else:
+                mask_position = len(token)
+                token = np.concatenate((token, [mask_id]))
+                target = np.concatenate((target, [mask_id]))
+                position_id = np.concatenate((position_id, [mask_position]))
         else:
+            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
+            assert not use_task_mask, "Unidirectional attention doesn't support task mask"
             mask_position = text.index(mask_id)
 
         division = len(token)
         attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
+        if unidirectional:
+            attention_mask[0] = np.tril(attention_mask[0])
 
         for choice in choices:
-            if len(choice) == 0:
-                if get_model_parallel_rank() == 0:
-                    print("Empty choice found")
-                choice = [0]
-            if use_task_mask == False:
-                position_id = np.concatenate(
-                    (
-                        position_id,
-                        [mask_position] * len(choice)
-                        if blank_filling or not unified_multitask_encoding
-                        else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
-                    )
-                )
-            else:
-                position_id = np.concatenate(
-                    (
-                        position_id,
-                        np.arange(division, division + len(choice), dtype=np.int64),
-                    )
+            if not choice:
+                choice = [tokenizer.get_command('eop')]
+            position_id = np.concatenate(
+                (
+                    position_id,
+                    [mask_position] * len(choice)
+                    if (blank_filling or not unified_multitask_encoding) and not use_task_mask
+                    else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
                 )
-
+            )
             choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
             attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
-            token = np.concatenate((token, [sop_id], choice[:-1]))
+            if unidirectional:
+                token = np.concatenate((token, [text[-1]], choice[:-1]))
+            else:
+                token = np.concatenate((token, [sop_id], choice[:-1]))
             target = np.concatenate((target, choice))
 
             if is_single_token:
                 break
 
         attention_mask = block_diag(*attention_mask)
-        attention_mask[: len(token), :division] = 1
+        attention_mask[division:, :division] = 1
 
         if is_single_token:
             choices = np.array(choices, dtype=np.int64).squeeze().tolist()
@@ -307,13 +304,15 @@ class MultiChoiceTaskDataset(EvaluationDataset):
 
     def __getitem__(self, idx):
         item = self.data[idx]
-        return self.build_multiple_choice_sample(
+        sample = self.build_multiple_choice_sample(
             item["text"],
             item["choices"],
             is_single_token=self.is_single_token,
             unified_multitask_encoding=self.config.use_multitask_encoding,
+            unidirectional=self.config.unidirectional,
             use_task_mask=self.config.use_task_mask,
         )
+        return sample
 
 
 class LanguageModelTaskDataset(EvaluationDataset):

+ 11 - 0
evaluation/tasks.py

@@ -41,6 +41,10 @@ class BaseTask(ABC):
 
         self.file_groups = self.get_file_groups()
         self.verbose = dist.get_rank() == 0
+        self.save_prediction = config.save_prediction
+
+    def save_prediction_to_file(self, file, prediction, data):
+        pass
 
     def get_file_groups(self):
         pattern_group = {}
@@ -84,9 +88,12 @@ class BaseTask(ABC):
                     for _, batch in enumerate(dataloader):
                         prediction.append(self.predict_single_batch(batch))
 
+
                 prediction = gather_result(prediction, len(dataset), self.config.micro_batch_size)
                 result_dict = {key: metric(prediction, dataset.data) for key, metric in self.metrics.items()}
                 result_dict_group[file] = (result_dict, len(dataset))
+                if torch.distributed.get_rank() == 0 and self.save_prediction:
+                    self.save_prediction_to_file(file, prediction, dataset.data)
 
                 if self.verbose:
                     self.report_single_metrics(file, result_dict)
@@ -169,6 +176,10 @@ class GenerationTask(BaseTask, ABC):
         super(GenerationTask, self).__init__(model, tokenizer, config)
 
         end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
+        if self.config.end_tokens:
+            for token in self.config.end_tokens:
+                end_tokens.append(self.tokenizer.tokenize(token)[-1])
+            print_rank_0(f"End tokens {end_tokens}")
         if self.config.sampling_strategy == "BaseStrategy":
             self.strategy = BaseStrategy(batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1,
                                          end_tokens=end_tokens)

+ 1 - 1
scripts/evaluate.sh

@@ -6,7 +6,7 @@ main_dir=$(dirname $script_dir)
 
 source "${main_dir}/configs/model_glm_130b.sh"
 
-DATA_PATH="<your evaluation dataset base directory>"
+DATA_PATH="/zhangpai21/workspace/zxdu"
 
 ARGS="${main_dir}/evaluate.py \
        --mode inference \

+ 14 - 0
tasks/cot/coinflip.yaml

@@ -0,0 +1,14 @@
+name: 'coinflip'
+type: 'gen'
+module: "tasks.cot.task.ChainOfThoughtTask"
+path: 'symbolic'
+file-pattern:
+  test: "coinflip.jsonl"
+sampling_strategy: "BaseStrategy"
+prompt_path: "tasks/cot/coinflip_prompt.txt"
+deterministic: true
+max_gen_length: 64
+use_task_mask: true
+save_prediction: true
+chain_of_thought: true
+micro_batch_size: 4

+ 16 - 0
tasks/cot/coinflip_prompt.txt

@@ -0,0 +1,16 @@
+Q: A coin is heads up. Ka flips the coin. Sherrie flips the coin. Is the coin still heads up?
+A: The coin was flipped by Ka and Sherrie. So the coin was flipped 2 times, which is an even number. The coin started heads up, so after an even number of flips, it will still be heads up. So the answer is yes.
+Q: A coin is heads up. Jamey flips the coin. Teressa flips the coin. Is the coin still heads up?
+A: The coin was flipped by Jamey and Teressa. So the coin was flipped 2 times, which is an even number. The coin started heads up, so after an even number of flips, it will still be heads up. So the answer is yes.
+Q: A coin is heads up. Maybelle flips the coin. Shalonda does not flip the coin. Is the coin still heads up?
+A: The coin was flipped by Maybelle. So the coin was flipped 1 time, which is an odd number. The coin started heads up, so after an odd number of flips, it will be tails up. So the answer is no.
+Q: A coin is heads up. Millicent does not flip the coin. Conception flips the coin. Is the coin still heads up?
+A: The coin was flipped by Conception. So the coin was flipped 1 time, which is an odd number. The coin started heads up, so after an odd number of flips, it will be tails up. So the answer is no.
+Q: A coin is heads up. Sal flips the coin. Raymond does not flip the coin. Is the coin still heads up?
+A: The coin was flipped by Sal. So the coin was flipped 1 time, which is an odd number. The coin started heads up, so after an odd number of flips, it will be tails up. So the answer is no.
+Q: A coin is heads up. Conception flips the coin. Kristian does not flip the coin. Is the coin still heads up?
+A: The coin was flipped by Conception. So the coin was flipped 1 time, which is an odd number. The coin started heads up, so after an odd number of flips, it will be tails up. So the answer is no.
+Q: A coin is heads up. Inga does not flip the coin. Elanor does not flip the coin. Is the coin still heads up?
+A: The coin was flipped by no one. So the coin was flipped O times. The coin started heads up, and it was not flipped, so it is still heads up. So the answer is yes.
+Q: A coin is heads up. Ryan flips the coin. Shaunda flips the coin. Is the coin still heads up?
+A: The coin was flipped by Ryan and Shaunda. So the coin was flipped 2 times, which is an even number. The coin started heads up, so after an even number of flips, it will still be heads up. So the answer is yes.

+ 14 - 0
tasks/cot/gsm8k.yaml

@@ -0,0 +1,14 @@
+name: 'gsm8k'
+type: 'gen'
+module: "tasks.cot.task.ChainOfThoughtTask"
+path: 'grade_school_math/data'
+file-pattern:
+  test: "test.jsonl"
+sampling_strategy: "BaseStrategy"
+prompt_path: "tasks/cot/gsm8k_prompt.txt"
+deterministic: true
+max_gen_length: 128
+use_task_mask: true
+save_prediction: true
+chain_of_thought: true
+micro_batch_size: 4

+ 16 - 0
tasks/cot/gsm8k_prompt.txt

@@ -0,0 +1,16 @@
+Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
+A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6
+Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
+A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.
+Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
+A: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39.
+Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?
+A: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The answer is 8.
+Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?
+A: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The answer is 9.
+Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?
+A: There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The answer is 29.
+Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?
+A: Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33.
+Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
+A: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8.

+ 15 - 0
tasks/cot/lastletter.yaml

@@ -0,0 +1,15 @@
+name: 'lastletter'
+type: 'gen'
+module: "tasks.cot.task.ChainOfThoughtTask"
+path: 'symbolic'
+file-pattern:
+  test: "lastletter.jsonl"
+sampling_strategy: "BaseStrategy"
+prompt_path: "tasks/cot/lastletter_prompt.txt"
+deterministic: true
+unidirectional: true
+max_gen_length: 64
+use_task_mask: true
+save_prediction: true
+chain_of_thought: true
+micro_batch_size: 4

+ 8 - 0
tasks/cot/lastletter_prompt.txt

@@ -0,0 +1,8 @@
+Q: Take the last letters of the words in "Elon Musk" and concatenate them.
+A: The last letter of "Elon" is "n". The last letter of "Musk" is "k". Concatenating them is "nk". The answer is nk.
+Q: Take the last letters of the words in "Larry Page" and concatenate them.
+A: The last letter of "Larry" is "y". The last letter of "Page" is "e". Concatenating them is "ye". The answer is ye.
+Q: Take the last letters of the words in "Sergey Brin" and concatenate them.
+A: The last letter of "Sergey" is "y". The last letter of "Brin" is "n". Concatenating them is "yn". The answer is yn.
+Q: Take the last letters of the words in "Bill Gates" and concatenate them.
+A: The last letter of "Bill" is "l". The last letter of "Gates" is "s". Concatenating them is "ls". The answer is ls.

+ 14 - 0
tasks/cot/reverse.yaml

@@ -0,0 +1,14 @@
+name: 'reverse'
+type: 'gen'
+module: "tasks.cot.task.ChainOfThoughtTask"
+path: 'symbolic'
+file-pattern:
+  test: "reverse_5.jsonl"
+sampling_strategy: "BaseStrategy"
+prompt_path: "tasks/cot/reverse_prompt.txt"
+deterministic: true
+max_gen_length: 128
+use_task_mask: true
+save_prediction: true
+chain_of_thought: true
+micro_batch_size: 4

+ 16 - 0
tasks/cot/reverse_prompt.txt

@@ -0,0 +1,16 @@
+Q: Reverse the sequence "cigar, umbrella, key, gum, alarm".
+A: First is cigar. Second is umbrella. Third is key. Fourth is gum. Fifth is alarm. Now to reverse, change the order to: Fifth is alarm. Fourth is gum. Third is key. Second is umbrella. First is cigar. So the answer is "alarm, gum, key, umbrella, cigar"
+Q: Reverse the sequence "player, passport, umbrella, bottle, watch".
+A: First is player. Second is passport. Third is umbrella. Fourth is bottle. Fifth is watch. Now to reverse, change the order to: Fifth is watch. Fourth is bottle. Third is umbrella. Second is passport. First is player. So the answer is "watch, bottle, umbrella, passport, player"
+Q: Reverse the sequence "coin, postcard, case, pen, wallet".
+A: First is coin. Second is postcard. Third is case. Fourth is pen. Fifth is wallet. Now to reverse, change the order to: Fifth is wallet. Fourth is pen. Third is case. Second is postcard. First is coin. So the answer is "wallet, pen, case, postcard, coin".
+Q: Reverse the sequence "laptop, lipstick, pen, bin, clock".
+A: First is laptop. Second is lipstick. Third is pen. Fourth is bin. Fifth is clock. Now to reverse, change the order to: Fifth is clock. Fourth is bin. Third is pen. Second is lipstick. First is laptop. So the answer is "clock, bin, pen, lipstick, laptop"
+Q: Reverse the sequence "key, pen, screen, file, cigar".
+A: First is key. Second is pen. Third is screen. Fourth is file. Fifth is cigar. Now to reverse, change the order to: Fifth is cigar. Fourth is file. Third is screen. Second is pen. First is key. So the answer is "cigar, file, screen, pen, key".
+Q: Reverse the sequence "card, stamp, book, water, glasses"
+A: First is card. Second is stamp. Third is book. Fourth is water. Fifth is glasses. Now to reverse, change the order to: Fifth is glasses. Fourth is water. Third is book. Second is stamp. First is card. The answer is "glasses, water, book, stamp, card".
+Q: Reverse the sequence "clock, coin, bottle, head, postcard".
+A: First is clock. Second is coin. Third is bottle. Fourth is head. Fifth is postcard. Now to reverse, change the order to: Fifth is postcard. Fourth is head. Third is bottle. Second is coin. First is clock. So the answer is "postcard, head, bottle, coin, clock".
+Q: Reverse the sequence "battery, glasses, lighter, water, scissors".
+A: First is battery. Second is glasses. Third is lighter. Fourth is water. Fifth is scissors. Now to reverse, change the order to: Fifth is scissors. Fourth is water. Third is lighter. Second is glasses. First is battery. So the answer is "scissors, water, lighter, glasses, battery".

+ 16 - 0
tasks/cot/sports.yaml

@@ -0,0 +1,16 @@
+name: 'sports'
+type: 'gen'
+module: "tasks.cot.task.ChainOfThoughtTask"
+path: 'commonsense'
+file-pattern:
+  test: "sports.json"
+sampling_strategy: "BaseStrategy"
+prompt_path: "tasks/cot/sports_prompt.txt"
+deterministic: true
+unidirectional: true
+max_gen_length: 128
+use_task_mask: true
+save_prediction: true
+chain_of_thought: true
+prompt_type: 'number'
+micro_batch_size: 4

+ 16 - 0
tasks/cot/sports_prompt.txt

@@ -0,0 +1,16 @@
+Q: Is the following sentence plausible? "Kyle Palmieri was called for slashing."
+A: Kyle Palmieri is a hockey player. Being called for slashing is part of hockey. So the answer is yes.
+Q: Is the following sentence plausible? "Joao Moutinho caught the screen pass in the NFC championship."
+A: Joao Moutinho is a soccer player. The NFC championship is part of American football, not soccer. So the answer is no.
+Q: Is the following sentence plausible? "Carson Wentz set the pick and roll."
+A: Carson Wentz is an American football player. Pick and roll is part of basketball, not football. So the answer is no.
+Q: Is the following sentence plausible? "Jonas Valanciunas beat the buzzer."
+A: Jonas Valanciunas is a basketball player. Beating the buzzer is part of basketball. So the answer is yes.
+Q: Is the following sentence plausible? "Jamel Murray was perfect from the line."
+A: Jamal Murray is a basketball player. Being perfect from the line is part of basketball. So the answer is yes.
+Q: Is the following sentence plausible? "Sam Darnold passed the puck."
+A: Sam Darnold is an American football player. Passing the puck is part of hockey, not American football. So the answer is no.
+Q: Is the following sentence plausible? "Draymond Green threw a touchdown."
+A: Draymond Green is a basketball player. Throwing a touchdown is part of football, not basketball. So the answer is no.
+Q: Is the following sentence plausible? "Malcolm Brogdon banked the shot in."
+A: Malcolm Brogdon is a basketball player. Banking the shot in is part of basketball. So the answer is yes.

+ 209 - 0
tasks/cot/task.py

@@ -0,0 +1,209 @@
+import os
+import json
+import re
+from typing import Union, List, Dict, Callable
+from datetime import datetime
+from evaluation.tasks import GenerationTask, GenerationTaskDataset, GenerationTaskConfig
+from evaluation.utils import print_rank_0
+from dataclasses import dataclass
+
+
+@dataclass
+class ChainOfThoughtConfig(GenerationTaskConfig):
+    prompt_path: str = None
+    chain_of_thought: bool = True
+    prompt_type: str = None
+
+
+def read_examples(prompt_path):
+    examples = []
+    item = {"question": None, "answer": None}
+    with open(prompt_path) as file:
+        for line in file:
+            line = line.strip()
+            if line.startswith("Q:"):
+                question = line[3:]
+                item["question"] = question
+            elif line.startswith("A:"):
+                answer = line[3:]
+                item["answer"] = answer
+                examples.append(item)
+                item = {"question": None, "answer": None}
+            else:
+                raise NotImplementedError
+    return examples
+
+
+def build_prompt(examples, task_name, chain_of_thought=True, prompt_type=None):
+    prompts = []
+    for i, item in enumerate(examples):
+        question, answer = item["question"], item["answer"]
+        if not chain_of_thought:
+            answer = extract_answer(answer, task_name)
+        if prompt_type == "number":
+            prompts.append(f"{i+1}. Question: {question} Answer: {answer}")
+        else:
+            prompts.append(f"Question: {question} Answer: {answer}")
+    if prompt_type == "return":
+        prompt = " <n>".join(prompts)
+    else:
+        prompt = " ".join(prompts)
+    return prompt
+
+
+def extract_answer(prediction, task_name, chain_of_thought=True):
+    if task_name.startswith("gsm8k"):
+        prediction = prediction.lower()
+        if chain_of_thought:
+            pattern = r'(?<=the answer is )\d+'
+        else:
+            pattern = r'\d+'
+        match = re.search(pattern, prediction)
+        if match:
+            answer = match.group(0)
+        else:
+            answer = ""
+    elif task_name.startswith("sports") or task_name.startswith("coinflip"):
+        prediction = prediction.lower()
+        if chain_of_thought:
+            pattern = r'(?<=the answer is )(yes|no)'
+        else:
+            pattern = r'yes|no'
+        match = re.search(pattern, prediction)
+        if match:
+            answer = match.group(0)
+        else:
+            answer = "no"
+    elif task_name.startswith("lastletter"):
+        prediction = prediction.lower()
+        if chain_of_thought:
+            pattern = r'(?<=the answer is )[a-z]+'
+        else:
+            pattern = r'[a-z]+'
+        match = re.search(pattern, prediction)
+        if match:
+            answer = match.group(0)
+        else:
+            answer = ""
+    elif task_name.startswith("reverse"):
+        prediction = prediction.lower()
+        if chain_of_thought:
+            pattern = r'(?<=the answer is ")[a-z|,| ]+'
+        else:
+            pattern = r'[a-z|,| ]+'
+        match = re.search(pattern, prediction)
+        if match:
+            answer = match.group(0)
+        else:
+            answer = ""
+    else:
+        raise NotImplementedError(task_name)
+    return answer
+
+
+class ChainOfThoughtDataset(GenerationTaskDataset):
+    config: ChainOfThoughtConfig
+
+    def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
+        self.labeled_examples = read_examples(config.prompt_path)
+        self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought,
+                                           prompt_type=config.prompt_type)
+        # print_rank_0(self.labeled_prompt)
+        self.printed_count = 0
+        super().__init__(path, config)
+        # print_rank_0(len(self.tokenizer.tokenize(self.labeled_prompt)))
+
+    def process_single_item(self, item, **kwargs):
+        question, targets = item["question"], item["targets"]
+        if self.config.prompt_type == "number":
+            text = self.labeled_prompt + f" {len(self.labeled_examples) + 1}. Question: {question} Answer:"
+        elif self.config.prompt_type == "return":
+            text = self.labeled_prompt + f" <n>Question: {question} Answer:"
+        else:
+            text = self.labeled_prompt + f" Question: {question} Answer:"
+        text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
+        if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
+            text_length = self.config.max_seq_length - self.config.max_gen_length - 2
+            text = text[len(text) - text_length: len(text)]
+        # if self.printed_count < 3:
+        #     print_rank_0(self.tokenizer.detokenize(text))
+        #     self.printed_count += 1
+        return [{"text": text, "targets": targets, **kwargs}]
+
+
+class GSM8KDataset(ChainOfThoughtDataset):
+    def process_single_item(self, item, **kwargs):
+        item["targets"] = item["answer"].split("####")[1].strip()
+        return super().process_single_item(item, **kwargs)
+
+
+class SportsDataset(ChainOfThoughtDataset):
+    def process_single_file(self, path):
+        with open(path) as file:
+            dataset = json.load(file)
+        for item in dataset["examples"]:
+            sentence = item["input"]
+            item["question"] = f'Is the following sentence plausible? \"{sentence}.\"'
+            if item["target_scores"]["plausible"] == 1:
+                item["targets"] = "yes"
+            else:
+                item["targets"] = "no"
+            self.data.extend(self.process_single_item(item))
+
+
+class LastLetterDataset(ChainOfThoughtDataset):
+    def process_single_item(self, item, **kwargs):
+        first_name, last_name = item["first_name"], item["last_name"]
+        question = f'Take the last letters of the words in "{first_name} {last_name}" and concatenate them.'
+        item["question"] = question
+        return super().process_single_item(item, **kwargs)
+
+
+class ChainOfThoughtTask(GenerationTask):
+    config: ChainOfThoughtConfig
+
+    @classmethod
+    def config_class(cls):
+        return ChainOfThoughtConfig
+
+    @property
+    def metrics(self) -> Dict[str, Callable]:
+        return {'acuracy': self.extracted_accuracy_metric}
+
+    def extracted_accuracy_metric(self, predictions, examples):
+        count = 0
+        num_predictions = max(len(predictions), 1)
+        assert len(predictions) == len(examples)
+        for prediction, example in zip(predictions, examples):
+            output = self.tokenizer.detokenize(prediction)
+            prediction = extract_answer(output, self.config.name, self.config.chain_of_thought).strip()
+            target = self.tokenizer.detokenize(example["targets"]).strip()
+            count += prediction == target
+        return count * 100.0 / num_predictions
+
+    def build_dataset(self, relative_path):
+        if self.config.name.startswith("gsm8k"):
+            return GSM8KDataset(os.path.join(self.config.path, relative_path), self.config)
+        elif self.config.name.startswith("sports"):
+            return SportsDataset(os.path.join(self.config.path, relative_path), self.config)
+        elif self.config.name.startswith("lastletter"):
+            return LastLetterDataset(os.path.join(self.config.path, relative_path), self.config)
+        elif self.config.name.startswith("coinflip") or self.config.name.startswith("reverse"):
+            return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)
+        else:
+            raise NotImplementedError
+
+    def save_prediction_to_file(self, file, predictions, data):
+        results = []
+        for output, item in zip(predictions, data):
+            output = self.tokenizer.detokenize(output)
+            prediction = extract_answer(output, self.config.name, self.config.chain_of_thought)
+            target = self.tokenizer.detokenize(item["targets"])
+            results.append({"output": output, "prediction": prediction, "answer": target})
+        file_name = file.split(".")[0]
+        if not os.path.exists("outputs"):
+            os.mkdir("outputs")
+        with open("outputs/" + self.config.name + "_" + datetime.now().strftime(
+                '%m-%d-%H-%M_') + file_name + ".json", "w") as output:
+            for result in results:
+                output.write(json.dumps(result) + "\n")