Implement sports understanding

duzx16 · 2 years ago · commit 2decee7b30

5 changed files with 83 additions and 11 deletions
  1. generation/strategies.py      + 0 - 1
  2. tasks/cot/gsm8k.yaml          + 1 - 0
  3. tasks/cot/sports.yaml         + 14 - 0
  4. tasks/cot/sports_prompt.txt   + 16 - 0
  5. tasks/cot/task.py             + 52 - 10

+ 0 - 1
generation/strategies.py

@@ -16,7 +16,6 @@ class BaseStrategy:
             end_tokens = []
         self.end_tokens = end_tokens
         self.deterministic = deterministic
-        print(self.deterministic)
         self._is_done = np.zeros(self.batch_size, dtype=np.bool)

     @property

+ 1 - 0
tasks/cot/gsm8k.yaml

@@ -10,4 +10,5 @@ deterministic: true
 max_gen_length: 128
 use_task_mask: true
 save_prediction: true
+chain_of_thought: true
 micro_batch_size: 4

+ 14 - 0
tasks/cot/sports.yaml

@@ -0,0 +1,14 @@
+name: 'sports'
+type: 'gen'
+module: "tasks.cot.task.ChainOfThoughtTask"
+path: 'commonsense'
+file-pattern:
+  test: "sports.json"
+sampling_strategy: "BaseStrategy"
+prompt_path: "tasks/cot/sports_prompt.txt"
+deterministic: true
+max_gen_length: 128
+use_task_mask: true
+save_prediction: true
+chain_of_thought: true
+micro_batch_size: 4

+ 16 - 0
tasks/cot/sports_prompt.txt

@@ -0,0 +1,16 @@
+Q: Is the following sentence plausible? "Kyle Palmieri was called for slashing."
+A: Kyle Palmieri is a hockey player. Being called for slashing is part of hockey. So the answer is yes.
+Q: Is the following sentence plausible? "Joao Moutinho caught the screen pass in the NFC championship."
+A: Joao Moutinho is a soccer player. The NFC championship is part of American football, not soccer. So the answer is no.
+Q: Is the following sentence plausible? "Carson Wentz set the pick and roll."
+A: Carson Wentz is an American football player. Pick and roll is part of basketball, not football. So the answer is no.
+Q: Is the following sentence plausible? "Jonas Valanciunas beat the buzzer."
+A: Jonas Valanciunas is a basketball player. Beating the buzzer is part of basketball. So the answer is yes.
+Q: Is the following sentence plausible? "Jamel Murray was perfect from the line."
+A: Jamal Murray is a basketball player. Being perfect from the line is part of basketball. So the answer is yes.
+Q: Is the following sentence plausible? "Sam Darnold passed the puck."
+A: Sam Darnold is an American football player. Passing the puck is part of hockey, not American football. So the answer is no.
+Q: Is the following sentence plausible? "Draymond Green threw a touchdown."
+A: Draymond Green is a basketball player. Throwing a touchdown is part of football, not basketball. So the answer is no.
+Q: Is the following sentence plausible? "Malcolm Brogdon banked the shot in."
+A: Malcolm Brogdon is a basketball player. Banking the shot in is part of basketball. So the answer is yes.
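
Note: read_examples (referenced in tasks/cot/task.py below) is not part of this diff; the sketch below only illustrates how a few-shot file in the Q:/A: format above could be parsed into the {"question": ..., "answer": ...} dicts that build_prompt consumes. The function name and details are assumptions, not the repository's implementation.

    # Hypothetical parser for the Q:/A: prompt format shown above.
    def parse_prompt_file(path):
        examples, question = [], None
        with open(path) as f:
            for line in f:
                line = line.strip()
                if line.startswith("Q: "):
                    question = line[len("Q: "):]
                elif line.startswith("A: ") and question is not None:
                    examples.append({"question": question, "answer": line[len("A: "):]})
                    question = None
        return examples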

+ 52 - 10
tasks/cot/task.py

@@ -11,6 +11,7 @@ from dataclasses import dataclass
 @dataclass
 class ChainOfThoughtConfig(GenerationTaskConfig):
     prompt_path: str = None
+    chain_of_thought: bool = True


 def read_examples(prompt_path):
@@ -32,23 +33,40 @@ def read_examples(prompt_path):
     return examples


-def build_prompt(examples):
+def build_prompt(examples, task_name, chain_of_thought=True):
     prompts = []
     for item in examples:
         question, answer = item["question"], item["answer"]
+        if not chain_of_thought:
+            answer = extract_answer(answer, task_name)
         prompts.append(f"Question: {question} Answer: {answer}")
     prompt = " ".join(prompts)
     return prompt


-def extract_answer(prediction, task_name):
-    if task_name == "gsm8k":
+def extract_answer(prediction, task_name, chain_of_thought=True):
+    if task_name.startswith("gsm8k"):
         prediction = prediction.lower()
-        match = re.search(r'(?<=the answer is )\d+', prediction)
+        if chain_of_thought:
+            pattern = r'(?<=the answer is )\d+'
+        else:
+            pattern = r'\d+'
+        match = re.search(pattern, prediction)
         if match:
             answer = match.group(0)
         else:
             answer = ""
+    elif task_name.startswith("sports"):
+        prediction = prediction.lower()
+        if chain_of_thought:
+            pattern = r'(?<=the answer is )(yes|no)'
+        else:
+            pattern = r'yes|no'
+        match = re.search(pattern, prediction)
+        if match:
+            answer = match.group(0)
+        else:
+            answer = "no"
     else:
         raise NotImplementedError(task_name)
     return answer
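
Note: a standalone illustration of the two sports patterns added above; the model outputs here are made up, not taken from the repository.

    import re

    output = "draymond green is a basketball player. throwing a touchdown is part of football. so the answer is no."
    # With chain of thought, only a yes/no that follows "the answer is " counts.
    print(re.search(r'(?<=the answer is )(yes|no)', output).group(0))  # "no"
    # Without chain of thought, the first yes/no anywhere in the lowercased output is taken.
    print(re.search(r'yes|no', "yes, that sentence is plausible.").group(0))  # "yes"

When chain_of_thought is false, build_prompt also runs the labeled answers through extract_answer, so the few-shot prompt shows only the final yes/no (or the bare number for gsm8k) instead of the full rationale.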
@@ -58,14 +76,13 @@ class ChainOfThoughtDataset(GenerationTaskDataset):

     def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
         self.labeled_examples = read_examples(config.prompt_path)
-        self.labeled_prompt = build_prompt(self.labeled_examples)
+        self.labeled_prompt = build_prompt(self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought)
         print_rank_0(self.labeled_prompt)
         self.printed_count = 0
         super().__init__(path, config)

     def process_single_item(self, item, **kwargs):
-        question = item["question"]
-        targets = item["answer"].split("####")[1].strip()
+        question, targets = item["question"], item["targets"]
         text = self.labeled_prompt + f" Question: {question} Answer:"
         text, targets = self.tokenizer.tokenize(text), self.tokenizer.tokenize(targets)
         if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
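
Note: process_single_item now expects a precomputed item["targets"]; the GSM8K-specific "####" split moves into GSM8KDataset in the next hunk. GSM8K reference answers end with "#### <number>", so the split works like this (the answer string below is made up in that format):

    answer = "There are 3 boxes with 4 apples each, so 3 * 4 = 12 apples. #### 12"
    targets = answer.split("####")[1].strip()  # "12"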
@@ -77,6 +94,26 @@ class ChainOfThoughtDataset(GenerationTaskDataset):
         return [{"text": text, "targets": targets, **kwargs}]


+class GSM8KDataset(ChainOfThoughtDataset):
+    def process_single_item(self, item, **kwargs):
+        item["targets"] = item["answer"].split("####")[1].strip()
+        return super().process_single_item(item)
+
+
+class SportsDataset(ChainOfThoughtDataset):
+    def process_single_file(self, path):
+        with open(path) as file:
+            dataset = json.load(file)
+        for item in dataset["examples"]:
+            sentence = item["input"]
+            item["question"] = f'Is the following sentence plausible? \"{sentence}.\"'
+            if item["target_scores"]["plausible"] == 1:
+                item["targets"] = "yes"
+            else:
+                item["targets"] = "no"
+            self.data.extend(self.process_single_item(item))
+
+
 class ChainOfThoughtTask(GenerationTask):
     config: ChainOfThoughtConfig

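
Note: sports.json itself is not part of this commit; the item below is a hypothetical example in the shape SportsDataset.process_single_file reads (dataset["examples"], item["input"], item["target_scores"]["plausible"]).

    item = {
        "input": "Sam Darnold passed the puck",
        "target_scores": {"plausible": 0},
    }
    question = f'Is the following sentence plausible? "{item["input"]}."'
    targets = "yes" if item["target_scores"]["plausible"] == 1 else "no"
    # question -> 'Is the following sentence plausible? "Sam Darnold passed the puck."'
    # targets  -> 'no'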
@@ -94,19 +131,24 @@ class ChainOfThoughtTask(GenerationTask):
         assert len(predictions) == len(examples)
         for prediction, example in zip(predictions, examples):
             output = self.tokenizer.detokenize(prediction)
-            prediction = extract_answer(output, self.config.name).strip()
+            prediction = extract_answer(output, self.config.name, self.config.chain_of_thought).strip()
             target = self.tokenizer.detokenize(example["targets"]).strip()
             count += prediction == target
         return count * 100.0 / num_predictions

     def build_dataset(self, relative_path, split):
-        return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)
+        if self.config.name.startswith("gsm8k"):
+            return GSM8KDataset(os.path.join(self.config.path, relative_path), self.config)
+        elif self.config.name.startswith("sports"):
+            return SportsDataset(os.path.join(self.config.path, relative_path), self.config)
+        else:
+            raise NotImplementedError

     def save_prediction_to_file(self, file, predictions, data):
         results = []
         for output, item in zip(predictions, data):
             output = self.tokenizer.detokenize(output)
-            prediction = extract_answer(output, self.config.name)
+            prediction = extract_answer(output, self.config.name, self.config.chain_of_thought)
             target = self.tokenizer.detokenize(item["targets"])
             results.append({"output": output, "prediction": prediction, "answer": target})
         file_name = file.split(".")[0]