
Add 8-bit quantization

Sengxian, 3 years ago
Commit e10b098020

+ 1 - 1
configs/model_glm_130b.sh

@@ -1,5 +1,5 @@
 MODEL_TYPE="glm-130b"
-CHECKPOINT_PATH="/thudm/workspace/hanyu/SwissArmyTransformer/data/ckpt/iter_0049300"
+CHECKPOINT_PATH="<your checkpoint path>"
 MP_SIZE=8
 MODEL_ARGS="--model-parallel-size ${MP_SIZE} \
             --num-layers 70 \

+ 16 - 0
configs/model_glm_130b_int4.sh

@@ -0,0 +1,16 @@
+MODEL_TYPE="glm-130b"
+CHECKPOINT_PATH="<your checkpoint path>"
+MP_SIZE=4
+MODEL_ARGS="--model-parallel-size ${MP_SIZE} \
+            --num-layers 70 \
+            --hidden-size 12288 \
+            --inner-hidden-size 32768 \
+            --vocab-size 150528 \
+            --num-attention-heads 96 \
+            --max-sequence-length 2048 \
+            --tokenizer-type icetk-glm-130B \
+            --layernorm-order post \
+            --quantization-bit-width 4 \
+            --load ${CHECKPOINT_PATH} \
+            --skip-init \
+            --fp16"

+ 16 - 0
configs/model_glm_130b_int8.sh

@@ -0,0 +1,16 @@
+MODEL_TYPE="glm-130b"
+CHECKPOINT_PATH="<your checkpoint path>"
+MP_SIZE=8
+MODEL_ARGS="--model-parallel-size ${MP_SIZE} \
+            --num-layers 70 \
+            --hidden-size 12288 \
+            --inner-hidden-size 32768 \
+            --vocab-size 150528 \
+            --num-attention-heads 96 \
+            --max-sequence-length 2048 \
+            --tokenizer-type icetk-glm-130B \
+            --layernorm-order post \
+            --quantization-bit-width 8 \
+            --load ${CHECKPOINT_PATH} \
+            --skip-init \
+            --fp16"
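
The two new configs mirror configs/model_glm_130b.sh and differ only in MP_SIZE and --quantization-bit-width: INT8 keeps 8-way model parallelism, while INT4 drops to 4 GPUs. A rough back-of-the-envelope sketch of why that works (illustrative only; it counts just the four quantized linear layers per block from the MODEL_ARGS above and ignores embeddings, biases, layernorms, and activations):

    # Per-GPU weight memory of the quantized linears, using the MODEL_ARGS above.
    layers, hidden, inner = 70, 12288, 32768
    params_per_layer = (
        hidden * 3 * hidden   # attention.query_key_value
        + hidden * hidden     # attention.dense
        + hidden * inner      # mlp.dense_h_to_4h
        + inner * hidden      # mlp.dense_4h_to_h
    )
    total = layers * params_per_layer  # roughly 98.7B parameters in these linears

    for name, bits, mp_size in [("fp16", 16, 8), ("int8", 8, 8), ("int4", 4, 4)]:
        gib = total * bits / 8 / mp_size / 1024**3
        print(f"{name}, MP={mp_size}: ~{gib:.0f} GiB of linear weights per GPU")

By this estimate the INT4 config fits roughly the same per-GPU weight budget as INT8 while using half the GPUs.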

+ 23 - 6
initialize.py

@@ -2,6 +2,8 @@ import argparse
 import torch
 import time
 
+from quantization import quantize
+
 from SwissArmyTransformer import get_args, get_tokenizer
 from SwissArmyTransformer.arguments import initialize_distributed
 from SwissArmyTransformer.training import load_checkpoint
@@ -17,9 +19,16 @@ def add_bminf_args(parser):
     return parser
 
 
+def add_quantization_args(parser):
+    group = parser.add_argument_group("Quantization")
+
+    group.add_argument(
+        "--quantization-bit-width", type=int, default=None, help="Quantize linear layer weights to this bit width (4 or 8)"
+    )
+
+
 def initialize(extra_args_provider):
     parser = argparse.ArgumentParser(add_help=False)
     add_bminf_args(parser)
+    add_quantization_args(parser)
     GLM130B.add_model_specific_args(parser)
     extra_args_provider(parser)
     known, args_list = parser.parse_known_args()
@@ -33,21 +42,29 @@ def initialize(extra_args_provider):
 def initialize_model_and_tokenizer(args):
     tokenizer = get_tokenizer(args)
 
+    # Initialize model
     model = GLM130B(args).half()
+
+    # Load checkpoint
+    torch.distributed.barrier()
+    start = time.time()
+    load_checkpoint(model, args)
+    torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print(f"> Checkpoint loaded in {time.time() - start:.1f}s")
+
     if args.bminf:
         import bminf
 
         with torch.cuda.device(args.device):
             model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
     else:
+        if args.quantization_bit_width is not None:
+            # Quantize model before moving to GPU
+            model = quantize(model, args.quantization_bit_width)
         model = model.to(args.device)
 
-    torch.distributed.barrier()
-    start = time.time()
-    load_checkpoint(model, args)
-    torch.distributed.barrier()
-    if torch.distributed.get_rank() == 0:
-        print(f"> Checkpoint loaded in {time.time() - start:.1f}s")
+    torch.cuda.empty_cache()
     model.eval()
 
     # generate rotary embedding cache
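
The reordering in initialize_model_and_tokenizer matters: the checkpoint is loaded into the fp16 model first, the linear layers are then replaced by their quantized counterparts, and only afterwards is the model moved to args.device, so the full-precision copy never has to fit in GPU memory all at once. A minimal sketch of the same pattern with a generic torch.nn module (load_then_quantize and quantize_fn are illustrative names, not part of this commit, and a plain state-dict checkpoint is assumed):

    import torch
    import torch.nn as nn

    def load_then_quantize(model: nn.Module, ckpt_path: str, quantize_fn=None, device="cuda"):
        """Sketch of the new flow: load fp16 weights, quantize, then move to the device."""
        model = model.half()
        state = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state)        # checkpoint fills the fp16 weights
        if quantize_fn is not None:
            model = quantize_fn(model)      # e.g. quantization.quantize(model, 8)
        model = model.to(device)            # only the shrunken model reaches the GPU
        torch.cuda.empty_cache()
        return model.eval()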

+ 60 - 0
quantization/__init__.py

@@ -0,0 +1,60 @@
+import torch
+
+from .layers import QuantizedColumnParallelLinear
+from .layers import QuantizedRowParallelLinear
+
+
+def quantize(model, bit_width):
+    """Replace fp16 linear with quantized linear"""
+
+    for layer in model.transformer.layers:
+        layer.attention.query_key_value = QuantizedColumnParallelLinear(
+            bit_width=bit_width,
+            weight=layer.attention.query_key_value.weight.to(torch.cuda.current_device()),
+            input_size=layer.attention.query_key_value.input_size,
+            output_size=layer.attention.query_key_value.output_size,
+            bias=True,
+            gather_output=False,
+            params_dtype=torch.half,
+            name="query_key_value",
+            skip_init=True,
+            device=layer.attention.query_key_value.weight.device,
+        )
+        layer.attention.dense = QuantizedRowParallelLinear(
+            bit_width=bit_width,
+            weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
+            input_size=layer.attention.dense.input_size,
+            output_size=layer.attention.dense.output_size,
+            bias=True,
+            input_is_parallel=True,
+            params_dtype=torch.half,
+            name="dense",
+            skip_init=True,
+            device=layer.attention.dense.weight.device,
+        )
+        layer.mlp.dense_h_to_4h = QuantizedColumnParallelLinear(
+            bit_width=bit_width,
+            weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
+            input_size=layer.mlp.dense_h_to_4h.input_size,
+            output_size=layer.mlp.dense_h_to_4h.output_size,
+            bias=True,
+            gather_output=False,
+            params_dtype=torch.half,
+            name="dense_h_to_4h",
+            skip_init=True,
+            device=layer.mlp.dense_h_to_4h.weight.device,
+        )
+        layer.mlp.dense_4h_to_h = QuantizedRowParallelLinear(
+            bit_width=bit_width,
+            weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
+            input_size=layer.mlp.dense_4h_to_h.input_size,
+            output_size=layer.mlp.dense_4h_to_h.output_size,
+            bias=True,
+            input_is_parallel=True,
+            params_dtype=torch.half,
+            name="dense_4h_to_h",
+            skip_init=True,
+            device=layer.mlp.dense_4h_to_h.weight.device,
+        )
+
+    return model
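
quantize() swaps out only the four linear projections in every transformer block (query_key_value, dense, dense_h_to_4h, dense_4h_to_h); embeddings, layernorms, and biases stay in fp16. A hedged usage sketch, matching how initialize.py calls it after the checkpoint has been loaded:

    from quantization import quantize

    # After load_checkpoint(model, args) has populated the fp16 weights:
    model = quantize(model, 8)        # or 4 with configs/model_glm_130b_int4.sh
    model = model.to(args.device)     # then move the model to the GPU, as in initialize.py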

+ 23 - 0
quantization/functional.py

@@ -0,0 +1,23 @@
+import torch
+
+
+class W8A16Linear(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor):
+        ctx.inp_shape = inp.size()
+        ctx.weight_shape = quant_w.size()
+        out_features = quant_w.size(0)
+        inp = inp.contiguous().view(-1, quant_w.size(1))
+        weight = quant_w.to(torch.half) * scale_w[:, None]
+        output = inp.mm(weight.t())
+        ctx.save_for_backward(inp, quant_w, scale_w)
+        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        inp, quant_w, scale_w = ctx.saved_tensors
+        weight = quant_w.to(torch.half) * scale_w[:, None]
+        grad_output = grad_output.contiguous().view(-1, weight.size(0))
+        grad_input = grad_output.mm(weight)
+        grad_weight = grad_output.t().mm(inp)
+        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
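
W8A16Linear stores an int8 weight plus a per-output-row fp16 scale, dequantizes on the fly, and runs the matmul in half precision; the backward pass recomputes the dequantized weight instead of saving it. A small self-contained check of the quantize/dequantize round trip behind that forward path (shapes are illustrative, and float32 is used here so the snippet also runs on CPU):

    import torch

    # Quick numerical check of the per-row absmax scheme W8A16Linear relies on.
    torch.manual_seed(0)
    w = torch.randn(64, 128)                             # [out_features, in_features]
    scale = w.abs().max(dim=-1).values / 127             # per-row scale for the int8 range
    qw = torch.round(w / scale[:, None]).to(torch.int8)  # what the quantized layers store

    x = torch.randn(4, 128)
    reference = x @ w.t()
    dequantized = qw.float() * scale[:, None]            # forward() dequantizes, then matmuls
    output = x @ dequantized.t()

    print((output - reference).abs().max())              # rounding error stays well below 1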

+ 78 - 0
quantization/layers.py

@@ -0,0 +1,78 @@
+import torch
+from torch.nn.parameter import Parameter
+
+from SwissArmyTransformer.mpu import copy_to_model_parallel_region
+from SwissArmyTransformer.mpu import gather_from_model_parallel_region
+from SwissArmyTransformer.mpu import reduce_from_model_parallel_region
+from SwissArmyTransformer.mpu import scatter_to_model_parallel_region
+from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
+
+from .functional import W8A16Linear
+
+
+class QuantizedColumnParallelLinear(ColumnParallelLinear):
+    def __init__(self, bit_width=8, weight=None, *args, **kwargs):
+        super(QuantizedColumnParallelLinear, self).__init__(*args, **kwargs)
+        self.bit_width = bit_width
+
+        shape = self.weight.shape
+        del self.weight
+
+        if weight is None:
+            self.weight = torch.empty(shape[0] * bit_width // 8, shape[1], dtype=torch.int8, device=kwargs["device"])
+            self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
+        else:
+            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (bit_width - 1)) - 1)).half()
+            self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
+
+        self.weight = Parameter(self.weight, requires_grad=False)
+        self.weight_scale = Parameter(self.weight_scale, requires_grad=False)
+
+    def forward(self, input_):
+        # Set up backprop all-reduce.
+        input_parallel = copy_to_model_parallel_region(input_)
+        # Matrix multiply.
+        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale)
+        if self.bias is not None:
+            output_parallel = output_parallel + self.bias
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = gather_from_model_parallel_region(output_parallel)
+        else:
+            output = output_parallel
+        return output
+
+
+class QuantizedRowParallelLinear(RowParallelLinear):
+    def __init__(self, bit_width=8, weight=None, *args, **kwargs):
+        super(QuantizedRowParallelLinear, self).__init__(*args, **kwargs)
+        self.bit_width = bit_width
+
+        shape = self.weight.shape
+        del self.weight
+
+        if weight is None:
+            self.weight = torch.empty(shape[0] * bit_width // 8, shape[1], dtype=torch.int8, device=kwargs["device"])
+            self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
+        else:
+            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (bit_width - 1)) - 1)).half()
+            self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
+
+        self.weight = Parameter(self.weight, requires_grad=False)
+        self.weight_scale = Parameter(self.weight_scale, requires_grad=False)
+
+    def forward(self, input_):
+        # Set up backprop all-reduce.
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            input_parallel = scatter_to_model_parallel_region(input_)
+        # Matrix multiply.
+        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale)
+        # All-reduce across all the partitions.
+        output_ = reduce_from_model_parallel_region(output_parallel)
+        if self.bias is not None:
+            output = output_ + self.bias
+        else:
+            output = output_
+        return output
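
Both layer classes use the same per-row symmetric scheme: each output row is scaled by its absolute maximum divided by 2 ** (bit_width - 1) - 1 (127 for int8, 7 for int4) and rounded to an integer. With bit_width=4 the values are still stored one per int8 element here; the shape[0] * bit_width // 8 allocation in the weight-is-None branch suggests a packed layout when loading pre-quantized checkpoints. A small parameterized sketch of just that __init__ math (shapes are illustrative):

    import torch

    def quantize_rows(weight: torch.Tensor, bit_width: int):
        """Per-row symmetric quantization as done in both Quantized*ParallelLinear.__init__."""
        qmax = 2 ** (bit_width - 1) - 1                  # 127 for int8, 7 for int4
        scale = weight.abs().max(dim=-1).values / qmax
        quantized = torch.round(weight / scale[:, None]).to(torch.int8)
        return quantized, scale

    w = torch.randn(8, 16)
    for bits in (8, 4):
        q, s = quantize_rows(w, bits)
        print(bits, int(q.min()), int(q.max()))          # values stay inside [-qmax, qmax]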