
Implement generation task for GLM-10B

duzx16 2 years ago
parent commit b030789021

+ 1 - 1
evaluation/configs.py

@@ -42,7 +42,7 @@ class MultiChoiceTaskConfig(BaseConfig):
 @dataclass
 class GenerationTaskConfig(BaseConfig):
     module = "evaluation.GenerationTask"
-    metrics: List[str] = field(default_factory=lambda: ["EM", "F1"])
+    metrics: List[str] = field(default_factory=lambda: [])
     sampling_strategy: str = "BaseStrategy"
     num_beams: int = 4
     length_penalty: float = 1.0
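
Note: the default metric list is now empty, so EM/F1 become opt-in per task (presumably falling back to the `DEFAULT_METRICS` that tasks.py imports below). A standalone sketch of the dataclass semantics of the new default, not the repo's actual config class:

```python
from dataclasses import dataclass, field
from typing import List

# Illustration only: default_factory hands each instance its own fresh
# empty list, so metrics are opt-in instead of always ["EM", "F1"].
@dataclass
class DemoConfig:
    metrics: List[str] = field(default_factory=lambda: [])

assert DemoConfig().metrics == []
```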

+ 58 - 3
evaluation/dataset.py

@@ -18,14 +18,15 @@ from .utils import get_tokenized_input
 
 
 def pad_batch(tokens, position_ids, attention_mask, max_seq_length):
+    pad_length = max_seq_length - len(tokens)
     attention_mask = np.pad(
         attention_mask,
-        pad_width=((0, max_seq_length - len(tokens)),),
+        pad_width=((0, pad_length),),
         mode="constant",
         constant_values=0,
     )
-    tokens = np.concatenate((tokens, np.zeros(max_seq_length - len(tokens), dtype=np.int64)))
-    position_ids = np.concatenate((position_ids, np.zeros(max_seq_length - len(position_ids), dtype=np.int64)))
+    tokens = np.concatenate((tokens, np.zeros(pad_length, dtype=np.int64)))
+    position_ids = np.concatenate((position_ids, position_ids[..., -1:].repeat(pad_length, -1)), axis=-1)
     return tokens, position_ids, attention_mask
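
Zero-padding the position ids would produce invalid (position, block position) pairs under GLM's 2-row position encoding; the padded columns now repeat the last real column instead. A toy numpy sketch of the new behaviour (values are illustrative, not from the repo):

```python
import numpy as np

# 2-row GLM position ids for a 5-token sample (toy values):
# row 0 = absolute position, frozen at the mask once generation starts;
# row 1 = block position inside the generated span.
position_ids = np.array([[0, 1, 2, 2, 2],
                         [0, 0, 0, 1, 2]])
pad_length = 3
padded = np.concatenate(
    (position_ids, position_ids[..., -1:].repeat(pad_length, -1)), axis=-1
)
# padded[0] -> [0 1 2 2 2 2 2 2]
# padded[1] -> [0 0 0 1 2 2 2 2]
```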
 
 
@@ -166,6 +167,60 @@ class GenerationTaskDataset(EvaluationDataset):
         )
 
 
+class SmallGenerationTaskDataset(GenerationTaskDataset):
+    config: GenerationTaskConfig
+
+    @staticmethod
+    def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
+        tokenizer = get_tokenizer()
+
+        sop_id = tokenizer.get_command("sop")
+        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
+        cls_id = tokenizer.get_command("ENC")
+        eos_id = tokenizer.get_command("eos")
+
+        token = np.array(text, dtype=np.int64)
+
+        blank_filling = mask_id in text
+        if blank_filling:
+            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
+            assert not use_task_mask, "Unidirectional attention doesn't support task mask"
+            mask_position = text.index(mask_id) + 1
+            context_length = len(token) + 2
+            token = np.concatenate(([cls_id], token, [eos_id, sop_id]))
+        else:
+            if unidirectional:
+                mask_position = 1
+                context_length = 3
+                token = np.concatenate(([cls_id, mask_id, eos_id, sop_id], token))
+            else:
+                mask_position = len(token) + 1
+                context_length = len(token) + 3
+                token = np.concatenate(([cls_id], token, [mask_id, eos_id, sop_id]))
+        prefix_length = len(token) - context_length
+
+        position_id = [list(range(context_length)) + [mask_position] * prefix_length,
+                       [0] * context_length + list(range(1, prefix_length + 1))]
+        position_id = np.array(position_id, dtype=np.int64)
+
+        target_position_id = [[mask_position] * max_gen_length,
+                              list(range(prefix_length + 1, prefix_length + max_gen_length + 1))]
+        target_position_id = np.array(target_position_id, dtype=np.int64)
+
+        attention_mask = np.tril(np.ones((len(token), len(token)), dtype=np.int64))
+        if not unidirectional:
+            attention_mask[: len(token) - 1, : len(token) - 1] = 1
+
+        item = {
+            "token": token,
+            "position_id": position_id,
+            "target_position_id": target_position_id,
+            "attention_mask": attention_mask,
+            "context_length": context_length,
+        }
+        return item
+
+
 class MultiChoiceTaskDataset(EvaluationDataset):
     config: MultiChoiceTaskConfig
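
For orientation, the layouts that `build_generation_sample` produces, worked through by hand (token ids are hypothetical placeholders):

```python
# Token layouts for the three branches above (t1..t3 = text tokens):
#   blank filling (mask inside text):  [cls, t1, mask, t3, eos, sop]
#   bidirectional (mask appended):     [cls, t1, t2, t3, mask, eos, sop]
#   unidirectional (stub context):     [cls, mask, eos, sop, t1, t2, t3]
#
# Unidirectional case with 3 text tokens: context_length=3, mask_position=1,
# prefix_length=4 (sop plus the text forms the already-seen prefix):
#   position_id[0] = [0, 1, 2, 1, 1, 1, 1]   # absolute, pinned to the mask
#   position_id[1] = [0, 0, 0, 1, 2, 3, 4]   # counts through the span
# target_position_id then continues the rows with [1, 1, ...] and [5, 6, ...].
```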
 

+ 2 - 2
evaluation/metrics.py

@@ -72,8 +72,8 @@ def qa_evaluate(predictions, examples, metric):
 
     score = 0.0
     for example, prediction in zip(examples, predictions):
-        ground_truths = [tokenizer.tokenizer.decode(target) for target in example["targets"]]
-        prediction = tokenizer.tokenizer.decode(prediction)
+        ground_truths = [tokenizer.detokenize(target) for target in example["targets"]]
+        prediction = tokenizer.detokenize(prediction)
         if ground_truths:
             score += metric_max_over_ground_truths(metric, prediction, ground_truths)
     score = 100.0 * score / len(predictions)
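
Decoding now goes through the tokenizer's public `detokenize` rather than the icetk-specific `tokenizer.tokenizer.decode`, so the same metrics code serves both the 130B `_IceTokenizer` and the new `SmallTokenizer` wrapper. A minimal sketch, assuming `tokenizer` came from `initialize_model_and_tokenizer`:

```python
# Hypothetical ids; both tokenizer flavours expose the same call.
prediction = tokenizer.detokenize([372, 891, 17])  # -> decoded string
```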

+ 2 - 1
evaluation/model.py

@@ -66,7 +66,8 @@ def batch_filling_sequence(
         tokens, mems = strategy.forward(logits, tokens, mems)
         if len(tokens.shape) == 3 and num_beams == 1:
             num_beams = tokens.shape[1]
-            position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, -1).reshape(batch_size * num_beams, -1)
+            tail_size = position_ids.shape[1:]
+            position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, *tail_size).reshape(batch_size * num_beams, *tail_size)
             attention_mask_shape = attention_mask.shape[-3:]
             attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
                 batch_size * num_beams, *attention_mask_shape)
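
The old expansion assumed flat `[batch, seq]` position ids; GLM's 2D scheme makes them `[batch, 2, seq]`, so the trailing dimensions are now carried generically. A shape sketch with assumed sizes:

```python
import torch

batch_size, num_beams, seq = 2, 4, 16
position_ids = torch.zeros(batch_size, 2, seq, dtype=torch.long)
tail_size = position_ids.shape[1:]                         # torch.Size([2, 16])
expanded = (position_ids.unsqueeze(1)                      # [2, 1, 2, 16]
            .expand(batch_size, num_beams, *tail_size)     # [2, 4, 2, 16]
            .reshape(batch_size * num_beams, *tail_size))  # [8, 2, 16]
```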

+ 2 - 2
evaluation/tasks.py

@@ -14,7 +14,7 @@ from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceTokenizer
 from generation import BaseStrategy, BeamSearchStrategy
 from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig, LanguageModelTaskConfig
 from .model import ModelForEvaluation
-from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset
+from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset, SmallGenerationTaskDataset
 from .utils import build_data_loader, gather_result, print_rank_0
 from .metrics import DEFAULT_METRICS
 
@@ -163,7 +163,7 @@ class GenerationTask(BaseTask, ABC):
         return GenerationTaskConfig
 
     def build_dataset(self, relative_path):
-        return GenerationTaskDataset(join(self.config.path, relative_path), self.config)
+        return SmallGenerationTaskDataset(join(self.config.path, relative_path), self.config)
 
     def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: GenerationTaskConfig):
         super(GenerationTask, self).__init__(model, tokenizer, config)

+ 33 - 10
initialize.py

@@ -7,7 +7,8 @@ from quantization import quantize
 from SwissArmyTransformer import get_args, get_tokenizer
 from SwissArmyTransformer.arguments import initialize_distributed
 from SwissArmyTransformer.training import load_checkpoint
-from SwissArmyTransformer.model import GLM130B
+from SwissArmyTransformer.model import GLM130B, GLMModel
+from SwissArmyTransformer.model.mixins import CachedAutoregressiveMixin
 
 
 def add_bminf_args(parser):
@@ -31,6 +32,7 @@ def initialize(extra_args_provider):
     add_bminf_args(parser)
     add_quantization_args(parser)
     GLM130B.add_model_specific_args(parser)
+    GLMModel.add_model_specific_args(parser)
     extra_args_provider(parser)
     known, args_list = parser.parse_known_args()
     args = get_args(args_list)
@@ -40,11 +42,32 @@ def initialize(extra_args_provider):
     return args
 
 
-def initialize_model_and_tokenizer(args):
-    tokenizer = get_tokenizer(args)
+class SmallTokenizer:
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+
+    def tokenize(self, text):
+        return self.tokenizer.EncodeAsIds(text).tokenization
+
+    def detokenize(self, ids):
+        return self.tokenizer.DecodeIds(ids)
+
+    def get_command(self, name):
+        name_map = {"[MASK]": "MASK", "[gMASK]": "gMASK", "[sMASK]": "sMASK"}
+        if name in name_map:
+            name = name_map[name]
+        return self.tokenizer.get_command(name).Id
 
+
+def initialize_model_and_tokenizer(args):
+    if args.tokenizer_type.startswith("glm_"):
+        tokenizer = SmallTokenizer(get_tokenizer(args))
+        tokenizer = get_tokenizer(args, outer_tokenizer=tokenizer)
+    else:
+        tokenizer = get_tokenizer(args)
     # Initialize model
-    model = GLM130B(args).half()
+    model = GLMModel(args).half()
+    model.add_mixin('cached-autoregressive', CachedAutoregressiveMixin())
 
     if args.from_quantized_checkpoint:
         assert args.quantization_bit_width is not None
@@ -77,12 +100,12 @@ def initialize_model_and_tokenizer(args):
     model.eval()
 
     # generate rotary embedding cache
-    with torch.no_grad():
-        _, *_ = model(
-            torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64),
-            torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64) * args.max_sequence_length,
-            torch.ones(1, 1, 1, 1, device=torch.cuda.current_device(), dtype=torch.bool),
-        )
+    # with torch.no_grad():
+    #     _, *_ = model(
+    #         torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64),
+    #         torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64) * args.max_sequence_length,
+    #         torch.ones(1, 1, 1, 1, device=torch.cuda.current_device(), dtype=torch.bool),
+    #     )
 
     torch.distributed.barrier()
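
A hedged usage sketch for the wrapper; `raw_tokenizer` stands in for the GLM-10B tokenizer that the inner `get_tokenizer(args)` call returns:

```python
small = SmallTokenizer(raw_tokenizer)
ids = small.tokenize("Who wrote Hamlet?")  # EncodeAsIds(...).tokenization
text = small.detokenize(ids)               # DecodeIds(ids)
gmask = small.get_command("[gMASK]")       # mapped to "gMASK"; returns the .Id int
```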