Kaynağa Gözat

Add reverse task

duzx16 2 yıl önce
ebeveyn
işleme
fdae7439d5
3 değiştirilmiş dosya ile 43 ekleme ve 1 silme
  1. 14 0
      tasks/cot/reverse.yaml
  2. 16 0
      tasks/cot/reverse_prompt.txt
  3. 13 1
      tasks/cot/task.py

+ 14 - 0
tasks/cot/reverse.yaml

@@ -0,0 +1,14 @@
+name: 'reverse'
+type: 'gen'
+module: "tasks.cot.task.ChainOfThoughtTask"
+path: 'symbolic'
+file-pattern:
+  test: "reverse_5.jsonl"
+sampling_strategy: "BaseStrategy"
+prompt_path: "tasks/cot/reverse_prompt.txt"
+deterministic: true
+max_gen_length: 128
+use_task_mask: true
+save_prediction: true
+chain_of_thought: true
+micro_batch_size: 4

+ 16 - 0
tasks/cot/reverse_prompt.txt

@@ -0,0 +1,16 @@
+Q: Reverse the sequence "cigar, umbrella, key, gum, alarm".
+A: First is cigar. Second is umbrella. Third is key. Fourth is gum. Fifth is alarm. Now to reverse, change the order to: Fifth is alarm. Fourth is gum. Third is key. Second is umbrella. First is cigar. So the answer is "alarm, gum, key, umbrella, cigar".
+Q: Reverse the sequence "player, passport, umbrella, bottle, watch".
+A: First is player. Second is passport. Third is umbrella. Fourth is bottle. Fifth is watch. Now to reverse, change the order to: Fifth is watch. Fourth is bottle. Third is umbrella. Second is passport. First is player. So the answer is "watch, bottle, umbrella, passport, player".
+Q: Reverse the sequence "coin, postcard, case, pen, wallet".
+A: First is coin. Second is postcard. Third is case. Fourth is pen. Fifth is wallet. Now to reverse, change the order to: Fifth is wallet. Fourth is pen. Third is case. Second is postcard. First is coin. So the answer is "wallet, pen, case, postcard, coin".
+Q: Reverse the sequence "laptop, lipstick, pen, bin, clock".
+A: First is laptop. Second is lipstick. Third is pen. Fourth is bin. Fifth is clock. Now to reverse, change the order to: Fifth is clock. Fourth is bin. Third is pen. Second is lipstick. First is laptop. So the answer is "clock, bin, pen, lipstick, laptop".
+Q: Reverse the sequence "key, pen, screen, file, cigar".
+A: First is key. Second is pen. Third is screen. Fourth is file. Fifth is cigar. Now to reverse, change the order to: Fifth is cigar. Fourth is file. Third is screen. Second is pen. First is key. So the answer is "cigar, file, screen, pen, key".
+Q: Reverse the sequence "card, stamp, book, water, glasses".
+A: First is card. Second is stamp. Third is book. Fourth is water. Fifth is glasses. Now to reverse, change the order to: Fifth is glasses. Fourth is water. Third is book. Second is stamp. First is card. So the answer is "glasses, water, book, stamp, card".
+Q: Reverse the sequence "clock, coin, bottle, head, postcard".
+A: First is clock. Second is coin. Third is bottle. Fourth is head. Fifth is postcard. Now to reverse, change the order to: Fifth is postcard. Fourth is head. Third is bottle. Second is coin. First is clock. So the answer is "postcard, head, bottle, coin, clock".
+Q: Reverse the sequence "battery, glasses, lighter, water, scissors".
+A: First is battery. Second is glasses. Third is lighter. Fourth is water. Fifth is scissors. Now to reverse, change the order to: Fifth is scissors. Fourth is water. Third is lighter. Second is glasses. First is battery. So the answer is "scissors, water, lighter, glasses, battery".

+ 13 - 1
tasks/cot/task.py

@@ -85,6 +85,17 @@ def extract_answer(prediction, task_name, chain_of_thought=True):
             answer = match.group(0)
         else:
             answer = ""
+    elif task_name.startswith("reverse"):
+        prediction = prediction.lower()
+        if chain_of_thought:
+            pattern = r'(?<=the answer is ")[a-z|,| ]+'
+        else:
+            pattern = r'[a-z|,| ]+'
+        match = re.search(pattern, prediction)
+        if match:
+            answer = match.group(0)
+        else:
+            answer = ""
     else:
         raise NotImplementedError(task_name)
     return answer
@@ -100,6 +111,7 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
         print_rank_0(self.labeled_prompt)
         self.printed_count = 0
         super().__init__(path, config)
+        print_rank_0(len(self.tokenizer.tokenize(self.labeled_prompt)))
 
     def process_single_item(self, item, **kwargs):
         question, targets = item["question"], item["targets"]
@@ -176,7 +188,7 @@ class ChainOfThoughtTask(GenerationTask):
             return SportsDataset(os.path.join(self.config.path, relative_path), self.config)
         elif self.config.name.startswith("lastletter"):
             return LastLetterDataset(os.path.join(self.config.path, relative_path), self.config)
-        elif self.config.name.startswith("coinflip"):
+        elif self.config.name.startswith("coinflip") or self.config.name.startswith("reverse"):
             return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)
         else:
             raise NotImplementedError