
Support GLM's 2D positional encoding

Sengxian, 2 years ago
parent
commit b31b6a9c14
11 changed files with 234 additions and 291 deletions
  1. evaluate.py (+1 -1)
  2. evaluation/dataset.py (+20 -110)
  3. evaluation/model.py (+169 -2)
  4. evaluation/tasks.py (+3 -3)
  5. generate.py (+15 -3)
  6. initialize.py (+7 -1)
  7. server.py (+4 -157)
  8. tasks/cot/task.py (+8 -7)
  9. tasks/ethnic/crows-pair/tasks.py (+3 -3)
  10. tasks/ethnic/stereoset/tasks.py (+3 -3)
  11. tasks/language-modeling/pile.py (+1 -1)

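GLM's 2D positional encoding gives every token two position ids: the first is its position in the masked context (generated tokens all reuse the position of the [MASK]/[gMASK] they expand), and the second is its position inside the generated span (0 for context tokens, 1, 2, … inside the span). When args.position_encoding_2d is set, the sample builders below stack the two sequences into a [2, seq_len] array instead of a single arange. A minimal sketch of the layout (hypothetical helper, mirroring the bidirectional branch of build_multiple_choice_sample further down):

```python
import numpy as np

def glm_2d_position_ids(context_length: int, span_length: int) -> np.ndarray:
    """Sketch of GLM's 2D positions for one context plus one generated span."""
    # First row: positions in the context; the [MASK] is appended right after
    # the context, and every span token reuses its position.
    mask_position = context_length
    position_id = np.concatenate(
        (np.arange(context_length + 1), [mask_position] * span_length)
    )
    # Second row: 0 for context tokens (and the [MASK]), 1..span_length inside the span.
    block_position_id = np.concatenate(
        (np.zeros(context_length + 1, dtype=np.int64), np.arange(1, span_length + 1))
    )
    return np.stack((position_id, block_position_id), axis=0)  # shape [2, seq_len]

# e.g. a 4-token context with a 3-token answer span:
# [[0 1 2 3 4 4 4 4]
#  [0 0 0 0 0 1 2 3]]
print(glm_2d_position_ids(4, 3))
```
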
+ 1 - 1
evaluate.py

@@ -56,7 +56,7 @@ def main():
     print_rank_0(f"> Successfully load {len(task_classes)} task{'s' if len(task_classes) > 1 else ''}")
 
     model, tokenizer = initialize_model_and_tokenizer(args)
-    model = ModelForEvaluation(model)
+    model = ModelForEvaluation(model, args.position_encoding_2d)
 
     start = time.time()
     evaluate_all_tasks(args.data_path, model, tokenizer, args.task, task_classes)

+ 20 - 110
evaluation/dataset.py

@@ -15,17 +15,21 @@ from SwissArmyTransformer import get_tokenizer
 
 from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig, LanguageModelTaskConfig
 from .utils import get_tokenized_input
+from .model import ModelForEvaluation
 
 
 def pad_batch(tokens, position_ids, attention_mask, max_seq_length):
+    pad_length = max_seq_length - len(tokens)
     attention_mask = np.pad(
         attention_mask,
-        pad_width=((0, max_seq_length - len(tokens)),),
+        pad_width=((0, pad_length),),
         mode="constant",
         constant_values=0,
     )
-    tokens = np.concatenate((tokens, np.zeros(max_seq_length - len(tokens), dtype=np.int64)))
-    position_ids = np.concatenate((position_ids, np.zeros(max_seq_length - len(position_ids), dtype=np.int64)))
+    tokens = np.concatenate((tokens, np.zeros(pad_length, dtype=np.int64)))
+    position_ids = np.concatenate(
+        (position_ids, np.zeros_like(position_ids[..., -1:], dtype=np.int64).repeat(pad_length, -1)), axis=-1
+    )
     return tokens, position_ids, attention_mask
 
 
@@ -39,8 +43,9 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
     If [MASK] not in context, will append [MASK] after text
     """
 
-    def __init__(self, path: Union[str, List[str]], config: BaseConfig):
+    def __init__(self, path: Union[str, List[str]], model: ModelForEvaluation, config: BaseConfig):
         self.path = path if isinstance(path, list) else [path]
+        self.model = model
         self.config = config
         self.max_seq_length = self.config.max_seq_length
         self.dtype = np.int64
@@ -170,9 +175,9 @@ class GenerationTaskDataset(EvaluationDataset):
 class MultiChoiceTaskDataset(EvaluationDataset):
     config: MultiChoiceTaskConfig
 
-    def __init__(self, path, config: MultiChoiceTaskConfig):
+    def __init__(self, path: Union[str, List[str]], model: ModelForEvaluation, config: BaseConfig):
         self.is_single_token = True  # set to False later in process_single_item func
-        super().__init__(path, config)
+        super().__init__(path, model, config)
 
     @property
     def has_collate_fn(self) -> bool:
@@ -226,88 +231,9 @@ class MultiChoiceTaskDataset(EvaluationDataset):
 
         return [{"text": text, "choices": choices, "label": label, **kwargs}]
 
-    @staticmethod
-    def build_multiple_choice_sample(
-        text,
-        choices,
-        is_single_token,
-        unified_multitask_encoding=False,
-        unidirectional=False,
-        use_task_mask=False,
-    ):
-        tokenizer = get_tokenizer()
-
-        sop_id = tokenizer.get_command("sop")
-        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
-
-        token = np.array(text, dtype=np.int64)
-        target = np.array(text, dtype=np.int64)
-        position_id = np.arange(len(text), dtype=np.int64)
-        choice_target_id = []
-
-        blank_filling = mask_id in text
-        if not blank_filling:
-            if unidirectional:
-                assert use_task_mask
-                token = np.concatenate(([mask_id, sop_id], token[:-1]))
-                target = np.concatenate(([mask_id, sop_id], target[:-1]))
-                position_id = np.arange(len(token), dtype=np.int64)
-                mask_position = len(token)
-            else:
-                mask_position = len(token)
-                token = np.concatenate((token, [mask_id]))
-                target = np.concatenate((target, [mask_id]))
-                position_id = np.concatenate((position_id, [mask_position]))
-        else:
-            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
-            assert not use_task_mask, "Unidirectional attention doesn't support task mask"
-            mask_position = text.index(mask_id)
-
-        division = len(token)
-        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
-        if unidirectional:
-            attention_mask[0] = np.tril(attention_mask[0])
-
-        for choice in choices:
-            if not choice:
-                choice = [tokenizer.get_command("eop")]
-            position_id = np.concatenate(
-                (
-                    position_id,
-                    [mask_position] * len(choice)
-                    if (blank_filling or not unified_multitask_encoding) and not use_task_mask
-                    else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
-                )
-            )
-            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
-            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
-            if unidirectional:
-                token = np.concatenate((token, [text[-1]], choice[:-1]))
-            else:
-                token = np.concatenate((token, [sop_id], choice[:-1]))
-            target = np.concatenate((target, choice))
-
-            if is_single_token:
-                break
-
-        attention_mask = block_diag(*attention_mask)
-        attention_mask[division:, :division] = 1
-
-        if is_single_token:
-            choices = np.array(choices, dtype=np.int64).squeeze().tolist()
-
-        item = {
-            "token": token,
-            "position_id": position_id,
-            "attention_mask": attention_mask,
-            "choices": choices,
-            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
-        }
-        return item
-
     def __getitem__(self, idx):
         item = self.data[idx]
-        sample = self.build_multiple_choice_sample(
+        sample = self.model.build_multiple_choice_sample(
             item["text"],
             item["choices"],
             is_single_token=self.is_single_token,
@@ -358,27 +284,11 @@ class LanguageModelTaskDataset(EvaluationDataset):
         end_idx = start_idx + self.config.max_seq_length - 1  # for additional [gMASK]
         tokens = self.data[document_idx]["raw_text"][start_idx:end_idx]
 
-        mask_id = self.gmask_id if self.config.use_task_mask else self.mask_id
-        sop_id = self.tokenizer.get_command("sop")
-
-        if idx == 0 or self.config.unidirectional:
-            prompt, text = [], tokens
-        else:
-            prompt_length = self.config.max_seq_length - 1 - self.config.generation_length
-            prompt, text = tokens[:prompt_length], tokens[prompt_length:]
-
-        seq_length = len(prompt) + len(text) + 1
-        attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
-        attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1
-
-        gen_length = min(len(text), self.config.generation_length)
-        return {
-            "tokens": np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64),
-            "targets": np.array(prompt + [mask_id] + text, dtype=np.int64),
-            "position_ids": np.arange(0, seq_length, dtype=np.int64),
-            "attention_mask": attention_mask < 0.5,
-            "loss_masks": np.array(
-                [0] * (seq_length - gen_length) + [1] * gen_length,
-                dtype=np.int64,
-            ),
-        }
+        return self.model.build_language_model_sample(
+            tokens,
+            is_first_segment=idx == 0,
+            max_seq_length=self.config.max_seq_length,
+            generation_length=self.config.generation_length,
+            unidirectional=self.config.unidirectional,
+            use_gmask=self.config.use_task_mask,
+        )
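
pad_batch previously assumed 1D position ids; the rewritten version pads only along the last axis, so the same code handles both the classic [seq_len] layout and the stacked [2, seq_len] layout used for 2D positional encoding. A standalone check of the padding expression (sketch, not part of the repo):

```python
import numpy as np

def pad_positions(position_ids: np.ndarray, pad_length: int) -> np.ndarray:
    # Pad with zeros along the last axis only, preserving any leading [2, ...] dim.
    pad = np.zeros_like(position_ids[..., -1:], dtype=np.int64).repeat(pad_length, -1)
    return np.concatenate((position_ids, pad), axis=-1)

print(pad_positions(np.arange(3), 2).shape)  # (5,)   -- 1D position ids
print(pad_positions(np.stack((np.arange(3), np.zeros(3, dtype=np.int64))), 2).shape)  # (2, 5) -- 2D
```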

+ 169 - 2
evaluation/model.py

@@ -1,9 +1,12 @@
+import numpy as np
 import torch
 
 from typing import List, Union
+from scipy.linalg import block_diag
 
 from SwissArmyTransformer.generation.autoregressive_sampling import update_mems, get_masks_and_position_ids_default
 from SwissArmyTransformer.mpu import vocab_parallel_cross_entropy
+from SwissArmyTransformer import get_tokenizer
 
 
 def batch_filling_sequence(
@@ -71,7 +74,9 @@ def batch_filling_sequence(
         if len(tokens.shape) == 3 and num_beams == 1:
             num_beams = tokens.shape[1]
             position_ids = (
-                position_ids.unsqueeze(1).expand(batch_size, num_beams, -1).reshape(batch_size * num_beams, -1)
+                position_ids.unsqueeze(1)
+                .expand((batch_size, num_beams) + position_ids.shape[1:])
+                .reshape((batch_size * num_beams,) + position_ids.shape[1:])
             )
             attention_mask_shape = attention_mask.shape[-3:]
             attention_mask = (
@@ -85,10 +90,11 @@ def batch_filling_sequence(
 
 
 class ModelForEvaluation(torch.nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, position_encoding_2d):
         super().__init__()
 
         self.model = model
+        self.position_encoding_2d = position_encoding_2d
         self.device = next(self.model.parameters()).device
 
     @staticmethod
@@ -99,6 +105,115 @@ class ModelForEvaluation(torch.nn.Module):
             batch["attention_mask"].to(device=device).bool().unsqueeze(1),
         )
 
+    def build_multiple_choice_sample(
+        self,
+        text,
+        choices,
+        is_single_token,
+        unified_multitask_encoding=False,
+        unidirectional=False,
+        use_task_mask=False,
+    ):
+        tokenizer = get_tokenizer()
+
+        sop_id = tokenizer.get_command("sop")
+        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
+
+        token = np.array(text, dtype=np.int64)
+        target = np.array(text, dtype=np.int64)
+        position_id = np.arange(len(text), dtype=np.int64)
+        block_position_id = np.zeros(len(text), dtype=np.int64)
+        choice_target_id = []
+
+        blank_filling = mask_id in text
+        if not blank_filling:
+            if unidirectional:
+                assert use_task_mask, "Unidirectional attention only support gMASK"
+                token = np.concatenate(([mask_id, sop_id], token[:-1]))
+                target = np.concatenate(([mask_id, sop_id], target[:-1]))
+                position_id = np.zeros(len(token), dtype=np.int64)
+                if self.position_encoding_2d:
+                    block_position_id = np.arange(len(token), dtype=np.int64)
+                mask_position = len(token)
+            else:
+                mask_position = len(token)
+                token = np.concatenate((token, [mask_id]))
+                target = np.concatenate((target, [mask_id]))
+                position_id = np.arange(len(token), dtype=np.int64)
+                if self.position_encoding_2d:
+                    block_position_id = np.zeros(len(token), dtype=np.int64)
+        else:
+            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
+            assert not use_task_mask, "Blank filling only support MASK"
+            mask_position = text.index(mask_id)
+
+        division = len(token)
+        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
+        if unidirectional:
+            attention_mask[0] = np.tril(attention_mask[0])
+
+        for choice in choices:
+            if not choice:
+                choice = [tokenizer.get_command("eop")]
+
+            target = np.concatenate((target, choice))
+            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
+            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
+
+            if unidirectional:
+                if self.position_encoding_2d:
+                    position_id = np.concatenate((position_id, [0] * len(choice)))
+                    block_position_id = np.concatenate(
+                        (block_position_id, np.arange(mask_position, mask_position + len(choice), dtype=np.int64))
+                    )
+                else:
+                    position_id = np.concatenate(
+                        (
+                            position_id,
+                            np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
+                        )
+                    )
+
+                token = np.concatenate((token, [text[-1]], choice[:-1]))
+            else:
+                if self.position_encoding_2d:
+                    position_id = np.concatenate((position_id, [mask_position] * len(choice)))
+                    block_position_id = np.concatenate(
+                        (block_position_id, np.arange(1, 1 + len(choice), dtype=np.int64))
+                    )
+                else:
+                    position_id = np.concatenate(
+                        (
+                            position_id,
+                            [mask_position] * len(choice)
+                            if (blank_filling or not unified_multitask_encoding) and not use_task_mask
+                            else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
+                        )
+                    )
+
+                token = np.concatenate((token, [sop_id], choice[:-1]))
+
+            if is_single_token:
+                break
+
+        attention_mask = block_diag(*attention_mask)
+        attention_mask[division:, :division] = 1
+
+        if is_single_token:
+            choices = np.array(choices, dtype=np.int64).squeeze().tolist()
+
+        if self.position_encoding_2d:
+            position_id = np.stack((position_id, block_position_id), axis=0)
+
+        item = {
+            "token": token,
+            "position_id": position_id,
+            "attention_mask": attention_mask,
+            "choices": choices,
+            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
+        }
+        return item
+
     def cond_log_prob(self, batch) -> List[List[float]]:
         """
         @return: Conditional log probability of each option
@@ -115,6 +230,12 @@ class ModelForEvaluation(torch.nn.Module):
         # output: [b, sq, vocab]
         log_probs = []
 
+        # if torch.distributed.get_rank() == 0:
+        #     import pdb
+        #
+        #     pdb.set_trace()
+        # torch.distributed.barrier()
+
         if is_single_token:  # Single token
             for logits, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
                 log_probs.append(logits[choice_target_ids[0], choices].tolist())
@@ -184,6 +305,52 @@ class ModelForEvaluation(torch.nn.Module):
                 output_targets.append(output_target)
         return output_targets
 
+    def build_language_model_sample(
+        self,
+        tokens: List[int],
+        is_first_segment: bool,
+        max_seq_length: int,
+        generation_length: int,
+        unidirectional: bool,
+        use_gmask: bool,
+    ):
+        tokenizer = get_tokenizer()
+        sop_id = tokenizer.get_command("sop")
+        mask_id = tokenizer.get_command("[gMASK]") if use_gmask else tokenizer.get_command("[MASK]")
+
+        if is_first_segment or unidirectional:
+            prompt, text = [], tokens
+        else:
+            prompt_length = max_seq_length - 1 - generation_length
+            prompt, text = tokens[:prompt_length], tokens[prompt_length:]
+
+        seq_length = len(prompt) + len(text) + 1
+        attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
+        attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1
+
+        gen_length = min(len(text), generation_length)
+
+        position_id = np.arange(0, seq_length, dtype=np.int64)
+        if self.position_encoding_2d:
+            position_id = np.concatenate(
+                (np.arange(0, seq_length - gen_length, dtype=np.int64), [seq_length - gen_length - 1] * gen_length)
+            )
+            block_position_id = np.concatenate(
+                ([0] * (seq_length - gen_length - 1), np.arange(0, gen_length + 1, dtype=np.int64))
+            )
+            position_id = np.stack((position_id, block_position_id), axis=0)
+
+        return {
+            "tokens": np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64),
+            "targets": np.array(prompt + [mask_id] + text, dtype=np.int64),
+            "position_ids": position_id,
+            "attention_mask": attention_mask < 0.5,
+            "loss_masks": np.array(
+                [0] * (seq_length - gen_length) + [1] * gen_length,
+                dtype=np.int64,
+            ),
+        }
+
     def calculate_loss(self, batch) -> List[float]:
         tokens, position_ids, attention_mask = self.process_data(batch, self.device)
         targets, loss_masks = (
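
For the language-modeling samples, the 2D layout keeps the first row constant over the scored tail (at the [gMASK] position in the common case) while the second row counts up through it. A hand-worked example with hypothetical sizes, evaluating the expressions above:

```python
import numpy as np

# Hypothetical sizes: a 9-token segment scored with gen_length = 4, so
# seq_length = len(prompt) + len(text) + 1 = 10 (the +1 is the [gMASK] slot).
seq_length, gen_length = 10, 4

position_id = np.concatenate(
    (np.arange(0, seq_length - gen_length), [seq_length - gen_length - 1] * gen_length)
)
block_position_id = np.concatenate(
    ([0] * (seq_length - gen_length - 1), np.arange(0, gen_length + 1))
)
print(np.stack((position_id, block_position_id), axis=0))
# [[0 1 2 3 4 5 5 5 5 5]
#  [0 0 0 0 0 0 1 2 3 4]]
```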

+ 3 - 3
evaluation/tasks.py

@@ -170,7 +170,7 @@ class GenerationTask(BaseTask, ABC):
         return GenerationTaskConfig
 
     def build_dataset(self, relative_path):
-        return GenerationTaskDataset(join(self.config.path, relative_path), self.config)
+        return GenerationTaskDataset(join(self.config.path, relative_path), self.model, self.config)
 
     def save_prediction_to_file(self, file, prediction, data):
         filename = os.path.join("outputs", self.config.name, f"{file}.predict")
@@ -218,7 +218,7 @@ class MultiChoiceTask(BaseTask, ABC):
         return MultiChoiceTaskConfig
 
     def build_dataset(self, relative_path):
-        return MultiChoiceTaskDataset(join(self.config.path, relative_path), self.config)
+        return MultiChoiceTaskDataset(join(self.config.path, relative_path), self.model, self.config)
 
     def predict_single_batch(self, batch) -> List[int]:
         log_probs = self.model.cond_log_prob(batch)
@@ -233,7 +233,7 @@ class LanguageModelTask(BaseTask, ABC):
         return LanguageModelTaskConfig
 
     def build_dataset(self, relative_path):
-        return LanguageModelTaskDataset(join(self.config.path, relative_path), self.config)
+        return LanguageModelTaskDataset(join(self.config.path, relative_path), self.model, self.config)
 
     def predict_single_batch(self, batch) -> List[float]:
         return self.model.calculate_loss(batch)

+ 15 - 3
generate.py

@@ -30,7 +30,7 @@ def isEnglish(s):
         return True
 
 
-def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
+def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False, position_encoding_2d=False):
     context_length = seq.shape[1]
     tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)
     attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
@@ -39,9 +39,20 @@ def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
     attention_mask.unsqueeze_(1)
     attention_mask = (attention_mask < 0.5).bool()
 
-    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
-    if not gmask:
+    if position_encoding_2d:
+        position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
         position_ids[context_length - 1 :] = mask_position
+        block_position_ids = torch.cat(
+            (
+                torch.zeros(context_length - 2, dtype=torch.long, device=tokens.device),
+                torch.arange(tokens.shape[-1] - (context_length - 2), dtype=torch.long, device=tokens.device),
+            )
+        )
+        position_ids = torch.vstack((position_ids, block_position_ids))
+    else:
+        position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
+        if not gmask:
+            position_ids[context_length - 1 :] = mask_position
 
     position_ids = position_ids.unsqueeze(0)
 
@@ -115,6 +126,7 @@ def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], L
                 mask_position=mask_position,
                 max_gen_length=args.out_seq_length,
                 gmask=use_gmask,
+                position_encoding_2d=args.position_encoding_2d,
             ),
         )
         if isinstance(output, torch.Tensor):  # different strategies
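
At generation time the same layout is built in torch: from the [gMASK] onward the first row is pinned to mask_position, while the second row counts 0, 1, 2, … starting at the [gMASK] (so sop sits at block position 1). A worked example with hypothetical sizes, assuming the [gMASK] and sop are the last two context tokens:

```python
import torch

# Hypothetical input: seq = [t0, t1, t2, [gMASK], sop], i.e. context_length = 5,
# mask_position = 3, generating up to max_gen_length = 3 new tokens.
context_length, mask_position, max_gen_length = 5, 3, 3
total_length = context_length + max_gen_length

position_ids = torch.arange(total_length, dtype=torch.long)
position_ids[context_length - 1:] = mask_position
block_position_ids = torch.cat(
    (
        torch.zeros(context_length - 2, dtype=torch.long),
        torch.arange(total_length - (context_length - 2), dtype=torch.long),
    )
)
print(torch.vstack((position_ids, block_position_ids)))
# tensor([[0, 1, 2, 3, 3, 3, 3, 3],
#         [0, 0, 0, 0, 1, 2, 3, 4]])
```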

+ 7 - 1
initialize.py

@@ -100,7 +100,13 @@ def initialize_model_and_tokenizer(args):
     with torch.no_grad():
         _, *_ = model(
             torch.ones(1, args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64),
-            torch.arange(args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64).view(1, -1),
+            torch.arange(args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64)
+            .view(1, 1, -1)
+            .repeat(1, 2, 1)
+            if args.position_encoding_2d
+            else torch.arange(args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64).view(
+                1, -1
+            ),
             torch.randn(
                 1,
                 1,
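
The warm-up forward pass only needs dummy position ids of the right shape; with 2D encoding enabled that shape becomes [1, 2, max_sequence_length] instead of [1, max_sequence_length]. A quick shape check (sketch, hypothetical length):

```python
import torch

max_sequence_length = 16  # hypothetical value
pos_1d = torch.arange(max_sequence_length, dtype=torch.int64).view(1, -1)
pos_2d = torch.arange(max_sequence_length, dtype=torch.int64).view(1, 1, -1).repeat(1, 2, 1)
print(pos_1d.shape, pos_2d.shape)  # torch.Size([1, 16]) torch.Size([1, 2, 16])
```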

+ 4 - 157
server.py

@@ -1,164 +1,11 @@
-import os
+import time
 import torch
-import stat
-import re
-
-from functools import partial
-from typing import List, Tuple
-
-from SwissArmyTransformer import mpu
-from evaluation.model import batch_filling_sequence
-from generation import BeamSearchStrategy, BaseStrategy
-from SwissArmyTransformer.generation.utils import timed_name
-from initialize import initialize, initialize_model_and_tokenizer
-
 import torch.distributed as dist
-import time
-
 import gradio as gr
 
-
-def add_generation_specific_args(parser):
-    parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy", help="Type of sampling strategy.")
-    parser.add_argument("--min-gen-length", type=int, default=0, help="The minimum length each blank should generate.")
-    parser.add_argument(
-        "--print-all-beams", action="store_true", help="Print all output generated by beam search strategy."
-    )
-
-
-def isEnglish(s):
-    try:
-        s.encode(encoding="utf-8").decode("ascii")
-    except UnicodeDecodeError:
-        return False
-    else:
-        return True
-
-
-def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
-    context_length = seq.shape[1]
-    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)
-    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
-    attention_mask.tril_()
-    attention_mask[..., : context_length - 1] = 1
-    attention_mask.unsqueeze_(1)
-    attention_mask = (attention_mask < 0.5).bool()
-
-    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
-    if not gmask:
-        position_ids[context_length - 1 :] = mask_position
-
-    position_ids = position_ids.unsqueeze(0)
-
-    return tokens, attention_mask, position_ids
-
-
-def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
-    # add MASK
-    generation_mask = "[gMASK]"
-    if "[MASK]" in raw_text:
-        generation_mask = "[MASK]"
-    elif "[sMASK]" in raw_text:
-        generation_mask = "[sMASK]"
-    use_gmask = "[MASK]" not in raw_text and "[sMASK]" not in raw_text
-
-    mask_pattern = r"\[[sg]?MASK\]"
-    text_list = re.split(mask_pattern, raw_text)
-    pattern_list = re.compile(mask_pattern).findall(raw_text)
-    seq = []
-    for i in range(len(pattern_list)):
-        pattern = pattern_list[i]
-        sub_text = text_list[i]
-        seq.extend(tokenizer.tokenize(sub_text))
-        seq.append(tokenizer.get_command(pattern))
-
-    seq.extend(tokenizer.tokenize(text_list[-1]))
-
-    if "MASK]" not in raw_text:
-        seq += [tokenizer.get_command(generation_mask)]
-        raw_text += " " + generation_mask
-    if not raw_text.endswith("MASK]"):
-        seq = seq + [tokenizer.get_command("eos")]
-    if mpu.get_model_parallel_rank() == 0:
-        print("\nInput: {}\n".format(raw_text))
-    if len(seq) > args.max_sequence_length:
-        raise ValueError("text too long.")
-
-    # generation
-    is_english = isEnglish(raw_text)
-    output_list = [seq]
-    num_output = args.num_beams if args.sampling_strategy == "BeamSearchStrategy" else 1
-    last_pos, answers, answers_with_style, blanks = (
-        [0] * num_output,
-        ["" for _ in range(num_output)],
-        ["" for _ in range(num_output)],
-        [[] for _ in range(num_output)],
-    )
-
-    # continually detect the first mark position
-    while True:
-        seq = output_list[0]
-        # detect mask position
-        mask_token = tokenizer.get_command(generation_mask)
-        if mask_token not in seq:
-            break
-        mask_position = seq.index(mask_token)
-
-        output_list = []
-
-        input_seq = torch.cuda.LongTensor(
-            [seq + [tokenizer.get_command("sop")]],
-            device=args.device,
-        )
-        output, _ = batch_filling_sequence(
-            model,
-            input_seq,
-            torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
-            strategy=strategy,
-            get_masks_and_position_ids=partial(
-                get_masks_and_position_ids,
-                mask_position=mask_position,
-                max_gen_length=args.out_seq_length,
-                gmask=use_gmask,
-            ),
-        )
-        if isinstance(output, torch.Tensor):  # different strategies
-            output = output.tolist()
-        output = output[0]  # batch_size = 1
-        output_list.extend(output)
-
-        # clip -1s and fill back generated things into seq
-        for i in range(len(output_list)):
-            output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
-            try:
-                unfinished = output.index(-1)
-            except ValueError:
-                unfinished = len(output)
-            if output[unfinished - 1] in strategy.end_tokens:
-                unfinished -= 1
-            bog = output.index(tokenizer.get_command("sop"))
-
-            prefix = tokenizer.detokenize(output[last_pos[i] : mask_position])
-            blank = tokenizer.detokenize(output[bog + 1 : unfinished])
-            answers_with_style[i] += (
-                prefix
-                + (" " if is_english else "")
-                + ("\033[4m" if use_gmask else "\x1b[0;32m\033[4m")
-                + blank
-                + ("\033[0m" if use_gmask else "\033[0m\x1b[0m")
-                + (" " if is_english else "")
-            )
-            blanks[i].append(blank)
-            last_pos[i] = mask_position + unfinished - (bog + 1)
-            output_list[i] = output[:mask_position] + output[bog + 1 : unfinished] + output[mask_position + 1 : bog]
-
-    for i, output in enumerate(output_list):
-        if output[-1] == tokenizer.get_command("eos"):
-            output = output[:-1]
-        answers_with_style[i] += tokenizer.detokenize(output[last_pos[i] :])
-        answers[i] = tokenizer.detokenize(output)
-
-    return answers, answers_with_style, blanks
+from generation import BeamSearchStrategy, BaseStrategy
+from initialize import initialize, initialize_model_and_tokenizer
+from generate import add_generation_specific_args, fill_blanks
 
 
 def generate_continually(func, raw_text):

+ 8 - 7
tasks/cot/task.py

@@ -3,6 +3,7 @@ import json
 import re
 from typing import Union, List, Dict, Callable
 from datetime import datetime
+from evaluation.model import ModelForEvaluation
 from evaluation.tasks import GenerationTask, GenerationTaskDataset, GenerationTaskConfig
 from evaluation.utils import print_rank_0
 from dataclasses import dataclass
@@ -116,14 +117,14 @@ def extract_answer(prediction, task_name, chain_of_thought=True):
 class ChainOfThoughtDataset(GenerationTaskDataset):
     config: ChainOfThoughtConfig
 
-    def __init__(self, path: Union[str, List[str]], config: ChainOfThoughtConfig):
+    def __init__(self, path: Union[str, List[str]], model: ModelForEvaluation, config: ChainOfThoughtConfig):
         self.labeled_examples = read_examples(config.prompt_path)
         self.labeled_prompt = build_prompt(
             self.labeled_examples, config.name, chain_of_thought=config.chain_of_thought, prompt_type=config.prompt_type
         )
         # print_rank_0(self.labeled_prompt)
         self.printed_count = 0
-        super().__init__(path, config)
+        super().__init__(path, model, config)
         # print_rank_0(len(self.tokenizer.tokenize(self.labeled_prompt)))
 
     def process_single_item(self, item, **kwargs):
@@ -209,15 +210,15 @@ class ChainOfThoughtTask(GenerationTask):
 
     def build_dataset(self, relative_path):
         if self.config.name.startswith("gsm8k"):
-            return GSM8KDataset(os.path.join(self.config.path, relative_path), self.config)
+            return GSM8KDataset(os.path.join(self.config.path, relative_path), self.model, self.config)
         elif self.config.name.startswith("sports"):
-            return SportsDataset(os.path.join(self.config.path, relative_path), self.config)
+            return SportsDataset(os.path.join(self.config.path, relative_path), self.model, self.config)
         elif self.config.name.startswith("lastletter"):
-            return LastLetterDataset(os.path.join(self.config.path, relative_path), self.config)
+            return LastLetterDataset(os.path.join(self.config.path, relative_path), self.model, self.config)
         elif self.config.name.startswith("coinflip") or self.config.name.startswith("reverse"):
-            return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.config)
+            return ChainOfThoughtDataset(os.path.join(self.config.path, relative_path), self.model, self.config)
         elif self.config.name.startswith("date"):
-            return DateDataset(os.path.join(self.config.path, relative_path), self.config)
+            return DateDataset(os.path.join(self.config.path, relative_path), self.model, self.config)
         else:
             raise NotImplementedError
 

+ 3 - 3
tasks/ethnic/crows-pair/tasks.py

@@ -19,7 +19,7 @@ class CrowsPairTask(MultiChoiceTask, ABC):
     config: MultiChoiceTaskConfig
 
     def build_dataset(self, relative_path):
-        return CrowsPairDataset(join(self.config.path, relative_path), self.config)
+        return CrowsPairDataset(join(self.config.path, relative_path), self.model, self.config)
 
     def predict_single_batch(self, batch) -> List[int]:
         log_probs = self.model.cond_log_prob(batch)
@@ -69,10 +69,10 @@ class CrowsPairDataset(MultiChoiceTaskDataset):
 
     config: MultiChoiceTaskConfig
 
-    def __init__(self, path, config: MultiChoiceTaskConfig):
+    def __init__(self, path, model, config: MultiChoiceTaskConfig):
         self.is_single_token = True  # set to False later in process_single_item func
         self.eval_data = []
-        super().__init__(path, config)
+        super().__init__(path, model, config)
 
     def process_single_item(self, item):
         text, choices, label = (

+ 3 - 3
tasks/ethnic/stereoset/tasks.py

@@ -20,7 +20,7 @@ class StereoSetTask(MultiChoiceTask, ABC):
     config: MultiChoiceTaskConfig
 
     def build_dataset(self, relative_path):
-        return StereoSetDataset(join(self.config.path, relative_path), self.config)
+        return StereoSetDataset(join(self.config.path, relative_path), self.model, self.config)
 
     def predict_single_batch(self, batch) -> List[int]:
         log_probs = self.model.cond_log_prob(batch)
@@ -84,10 +84,10 @@ class StereoSetTask(MultiChoiceTask, ABC):
 class StereoSetDataset(MultiChoiceTaskDataset):
     config: MultiChoiceTaskConfig
 
-    def __init__(self, path, config: MultiChoiceTaskConfig):
+    def __init__(self, path, model, config: MultiChoiceTaskConfig):
         self.is_single_token = True  # set to False later in process_single_item func
         self.eval_data = []
-        super().__init__(path, config)
+        super().__init__(path, model, config)
 
     def process_single_item(self, item):
         text, choices, label = (

+ 1 - 1
tasks/language-modeling/pile.py

@@ -33,7 +33,7 @@ class Pile(LanguageModelTask):
         return {"BPB": calculate_bpb_score}
 
     def build_dataset(self, relative_path):
-        return PileDataset(join(self.config.path, relative_path), self.config)
+        return PileDataset(join(self.config.path, relative_path), self.model, self.config)
 
     def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
         pass