Add 4-bit quantization and CUDA kernels

Sengxian, 3 years ago
Commit 96eac9f33b

+ 22 - 0
cuda/Makefile

@@ -0,0 +1,22 @@
+NVCC=nvcc
+OPTIONS=-gencode arch=compute_61,code=sm_61 \
+		-gencode arch=compute_62,code=sm_62 \
+		-gencode arch=compute_70,code=sm_70 \
+		-gencode arch=compute_72,code=sm_72 \
+		-gencode arch=compute_75,code=sm_75 \
+		-gencode arch=compute_80,code=sm_80 \
+		-gencode arch=compute_86,code=sm_86
+
+TARGETS=$(patsubst %.cu, %.fatbin, $(wildcard *.cu))
+
+all: $(TARGETS)
+
+%.fatbin: %.cu
+	$(NVCC) -fatbin $^ $(OPTIONS) -o $@
+
+.PHONY: clean copy
+clean:
+	rm $(TARGETS)
+
+copy:
+	cp $(TARGETS) ../kernels/

+ 81 - 0
cuda/quantization.cu

@@ -0,0 +1,81 @@
+#include <cuda_fp16.h>
+
+template<typename T>
+__device__ void
+int4WeightExtractionDevice(const int8_t* weight,
+                                const T* scale_list,
+                                T* output,
+                                const int n,
+                                const int k)
+{
+    for(int i = blockIdx.x * k + threadIdx.x; i < blockIdx.x * k + k; i += blockDim.x){
+        int8_t original = weight[i];
+        int8_t high = original >> 4;                 // arithmetic shift keeps the sign of the high nibble
+        int8_t low = original << 4; low = low >> 4;  // shift up, then back down to sign-extend the low nibble
+        output[i * 2] = T(high) * scale_list[blockIdx.x];
+        output[i * 2 + 1] = T(low) * scale_list[blockIdx.x];
+    }
+}
+
+__device__ void
+int4WeightCompressionDevice(const int8_t* input,
+                                int8_t* output,
+                                const int n,
+                                const int k)
+{
+    for(int i = blockIdx.x * k + threadIdx.x; i < blockIdx.x * k + k; i += blockDim.x){
+        output[i] = (input[i * 2] << 4) | (input[i * 2 + 1] & 0b00001111);  // pack two signed 4-bit values into one byte, high nibble first
+    }
+}
+
+template<typename T>
+__device__ void
+int8WeightExtractionDevice(const int8_t* weight,
+                                const T* scale_list,
+                                T* output,
+                                const int n,
+                                const int k)
+{
+    for(int i = blockIdx.x * k + threadIdx.x; i < blockIdx.x * k + k; i += blockDim.x){
+        output[i] = T(weight[i]) * scale_list[blockIdx.x];
+    }
+}
+
+extern "C" __global__ void int4WeightExtractionHalf(const int8_t* weight,
+                                const half* scale_list,
+                                half* output,
+                                const int n,
+                                const int k){
+                                    int4WeightExtractionDevice<half>(weight, scale_list, output, n, k);
+                                }
+
+extern "C" __global__ void int4WeightExtractionFloat(const int8_t* weight,
+                                const float* scale_list,
+                                float* output,
+                                const int n,
+                                const int k){
+                                    int4WeightExtractionDevice<float>(weight, scale_list, output, n, k);
+                                }
+
+extern "C" __global__ void int8WeightExtractionHalf(const int8_t* weight,
+                                const half* scale_list,
+                                half* output,
+                                const int n,
+                                const int k){
+                                    int8WeightExtractionDevice<half>(weight, scale_list, output, n, k);
+                                }
+
+extern "C" __global__ void int8WeightExtractionFloat(const int8_t* weight,
+                                const float* scale_list,
+                                float* output,
+                                const int n,
+                                const int k){
+                                    int8WeightExtractionDevice<float>(weight, scale_list, output, n, k);
+                                }
+
+extern "C" __global__ void int4WeightCompression(const int8_t* input,
+                                int8_t* output,
+                                const int n,
+                                const int k){
+                                    int4WeightCompressionDevice(input, output, n, k);
+                                }

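Note: the nibble handling above can be mirrored on the CPU for reference. The following is a minimal sketch, not part of the commit; it assumes PyTorch's shift operators on signed int8 tensors behave arithmetically, matching what the CUDA code relies on for int8_t.

import torch

def extract_int4_reference(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # weight: (n, k) int8, each byte packing two signed 4-bit values, high nibble first
    # scale:  (n,) per-row dequantization factors
    high = weight >> 4                       # arithmetic shift recovers the signed high nibble
    low = (weight << 4) >> 4                 # shift up, then back down to sign-extend the low nibble
    unpacked = torch.stack((high, low), dim=-1).view(weight.size(0), -1)  # (n, 2k)
    return unpacked.half() * scale[:, None].half()
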
+ 99 - 0
kernels/__init__.py

@@ -0,0 +1,99 @@
+import pkg_resources
+import torch
+import ctypes
+
+from typing import List
+from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
+
+RESOURCE_PACKAGE_NAME = __name__
+
+
+class Kernel:
+    def __init__(self, filename: str, function_names: List[str]):
+        filename = filename + ".fatbin"
+        if not pkg_resources.resource_exists(RESOURCE_PACKAGE_NAME, filename):
+            raise RuntimeError("File `%s` not found in `%s`" % (filename, RESOURCE_PACKAGE_NAME))
+        self.filename = filename
+        self.code = pkg_resources.resource_string(RESOURCE_PACKAGE_NAME, filename)
+        self._function_names = function_names
+        self._cmodule = LazyKernelCModule(self.code)
+
+        for name in self._function_names:
+            setattr(self, name, KernelFunction(self._cmodule, name))
+
+
+kernels = Kernel(
+    "quantization",
+    [
+        "int4WeightCompression",
+        "int4WeightExtractionFloat",
+        "int4WeightExtractionHalf",
+        "int8WeightExtractionFloat",
+        "int8WeightExtractionHalf",
+    ],
+)
+
+
+def compress_int4_weight(weight: torch.Tensor):  # (n, m)
+    with torch.cuda.device(weight.device):
+        n, m = weight.size(0), weight.size(1)
+        assert m % 2 == 0
+        m = m // 2
+        out = torch.empty(n, m, dtype=torch.int8, device="cuda")
+        stream = torch.cuda.current_stream()
+
+        gridDim = (n, 1, 1)
+        blockDim = (min(round_up(m, 32), 1024), 1, 1)
+
+        kernels.int4WeightCompression(
+            gridDim,
+            blockDim,
+            0,
+            stream,
+            [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
+        )
+        return out
+
+
+def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
+    if source_bit_width == 8:
+        func = kernels.int8WeightExtractionHalf
+    elif source_bit_width == 4:
+        func = kernels.int4WeightExtractionHalf
+    else:
+        assert False, "Unsupported bit-width"
+
+    with torch.cuda.device(weight.device):
+        n, m = weight.size(0), weight.size(1)
+        out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda")
+        stream = torch.cuda.current_stream()
+
+        gridDim = (n, 1, 1)
+        blockDim = (min(round_up(m, 32), 1024), 1, 1)
+
+        func(
+            gridDim,
+            blockDim,
+            0,
+            stream,
+            [
+                ctypes.c_void_p(weight.data_ptr()),
+                ctypes.c_void_p(scale_list.data_ptr()),
+                ctypes.c_void_p(out.data_ptr()),
+                ctypes.c_int32(n),
+                ctypes.c_int32(m),
+            ],
+        )
+        return out
+
+
+if __name__ == "__main__":
+    weight = torch.randn(4, 32).to(torch.int8).cuda()
+    scale = torch.ones(weight.size(0)).to(torch.half).cuda()
+
+    print(weight)
+    b = compress_int4_weight(weight)
+    print(b)
+
+    a = extract_weight_to_half(b, scale, source_bit_width=4)
+    print(a)

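Note: together with the per-row scale used in quantization/layers.py, the wrappers above can be exercised end to end. An illustrative round-trip sketch, not part of the commit:

import torch
from kernels import compress_int4_weight, extract_weight_to_half

weight_bit_width = 4
w = torch.randn(8, 64, dtype=torch.half, device="cuda")

# symmetric per-row scale, as computed in QuantizedColumnParallelLinear / QuantizedRowParallelLinear
scale = (w.abs().max(dim=-1).values / (2 ** (weight_bit_width - 1) - 1)).half()
w_int8 = torch.round(w / scale[:, None]).to(torch.int8)            # values in [-7, 7]

w_packed = compress_int4_weight(w_int8)                            # (8, 32) int8, two values per byte
w_restored = extract_weight_to_half(w_packed, scale, source_bit_width=weight_bit_width)
print((w - w_restored).abs().max())                                # only quantization error remains
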
BIN
kernels/quantization.fatbin


+ 5 - 5
quantization/__init__.py

@@ -4,12 +4,12 @@ from .layers import QuantizedColumnParallelLinear
 from .layers import QuantizedRowParallelLinear
 
 
-def quantize(model, bit_width):
+def quantize(model, weight_bit_width):
     """Replace fp16 linear with quantized linear"""
 
     for layer in model.transformer.layers:
         layer.attention.query_key_value = QuantizedColumnParallelLinear(
-            bit_width=bit_width,
+            weight_bit_width=weight_bit_width,
             weight=layer.attention.query_key_value.weight.to(torch.cuda.current_device()),
             input_size=layer.attention.query_key_value.input_size,
             output_size=layer.attention.query_key_value.output_size,
@@ -21,7 +21,7 @@ def quantize(model, bit_width):
             device=layer.attention.query_key_value.weight.device,
         )
         layer.attention.dense = QuantizedRowParallelLinear(
-            bit_width=bit_width,
+            weight_bit_width=weight_bit_width,
             weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
             input_size=layer.attention.dense.input_size,
             output_size=layer.attention.dense.output_size,
@@ -33,7 +33,7 @@ def quantize(model, bit_width):
             device=layer.attention.dense.weight.device,
         )
         layer.mlp.dense_h_to_4h = QuantizedColumnParallelLinear(
-            bit_width=bit_width,
+            weight_bit_width=weight_bit_width,
             weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
             input_size=layer.mlp.dense_h_to_4h.input_size,
             output_size=layer.mlp.dense_h_to_4h.output_size,
@@ -45,7 +45,7 @@ def quantize(model, bit_width):
             device=layer.mlp.dense_h_to_4h.weight.device,
         )
         layer.mlp.dense_4h_to_h = QuantizedRowParallelLinear(
-            bit_width=bit_width,
+            weight_bit_width=weight_bit_width,
             weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
             input_size=layer.mlp.dense_4h_to_h.input_size,
             output_size=layer.mlp.dense_4h_to_h.output_size,

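Note: a typical call site after the rename might look as follows. This is a sketch that assumes `model` is a GLM-style transformer already loaded in fp16 on the current CUDA device:

from quantization import quantize

# replaces query_key_value, dense, dense_h_to_4h and dense_4h_to_h of every block in place
quantize(model, weight_bit_width=4)   # or weight_bit_width=8
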
+ 7 - 4
quantization/functional.py

@@ -1,14 +1,17 @@
 import torch
 
+from kernels import extract_weight_to_half
+
 
 class W8A16Linear(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor):
+    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
         ctx.inp_shape = inp.size()
         ctx.weight_shape = quant_w.size()
+        ctx.weight_bit_width = weight_bit_width
         out_features = quant_w.size(0)
-        inp = inp.contiguous().view(-1, quant_w.size(1))
-        weight = quant_w.to(torch.half) * scale_w[:, None]
+        inp = inp.contiguous().view(-1, inp.size(-1))
+        weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
         output = inp.mm(weight.t())
         ctx.save_for_backward(inp, quant_w, scale_w)
         return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
@@ -16,7 +19,7 @@ class W8A16Linear(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
         inp, quant_w, scale_w = ctx.saved_tensors
-        weight = quant_w.to(torch.half) * scale_w[:, None]
+        weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
         grad_output = grad_output.contiguous().view(-1, weight.size(0))
         grad_input = grad_output.mm(weight)
         grad_weight = grad_output.t().mm(inp)

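Note: the kernel-based dequantization replaces quant_w.to(torch.half) * scale_w[:, None], which only handles unpacked int8 weights. For the 8-bit case the two paths should agree; a small sanity sketch, not part of the commit:

import torch
from kernels import extract_weight_to_half

quant_w = torch.randint(-127, 128, (16, 32), dtype=torch.int8, device="cuda")
scale_w = torch.rand(16, dtype=torch.half, device="cuda")

kernel_path = extract_weight_to_half(quant_w, scale_w, source_bit_width=8)
naive_path = quant_w.to(torch.half) * scale_w[:, None]   # the expression this commit replaces
assert torch.allclose(kernel_path, naive_path)
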
+ 19 - 10
quantization/layers.py

@@ -8,22 +8,27 @@ from SwissArmyTransformer.mpu import scatter_to_model_parallel_region
 from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
 
 from .functional import W8A16Linear
+from kernels import compress_int4_weight
 
 
 class QuantizedColumnParallelLinear(ColumnParallelLinear):
-    def __init__(self, bit_width=8, weight=None, *args, **kwargs):
+    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
         super(QuantizedColumnParallelLinear, self).__init__(*args, **kwargs)
-        self.bit_width = bit_width
+        self.weight_bit_width = weight_bit_width
 
         shape = self.weight.shape
         del self.weight
 
         if weight is None:
-            self.weight = torch.empty(shape[0] * bit_width // 8, shape[1], dtype=torch.int8, device=kwargs["device"])
+            self.weight = torch.empty(
+                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
+            )
             self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
         else:
-            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (bit_width - 1)) - 1)).half()
+            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
             self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
+            if weight_bit_width == 4:
+                self.weight = compress_int4_weight(self.weight)
 
         self.weight = Parameter(self.weight, requires_grad=False)
         self.weight_scale = Parameter(self.weight_scale, requires_grad=False)
@@ -32,7 +37,7 @@ class QuantizedColumnParallelLinear(ColumnParallelLinear):
         # Set up backprop all-reduce.
         input_parallel = copy_to_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale)
+        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
         if self.bias is not None:
             output_parallel = output_parallel + self.bias
         if self.gather_output:
@@ -44,19 +49,23 @@ class QuantizedColumnParallelLinear(ColumnParallelLinear):
 
 
 class QuantizedRowParallelLinear(RowParallelLinear):
-    def __init__(self, bit_width=8, weight=None, *args, **kwargs):
+    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
         super(QuantizedRowParallelLinear, self).__init__(*args, **kwargs)
-        self.bit_width = bit_width
+        self.weight_bit_width = weight_bit_width
 
         shape = self.weight.shape
         del self.weight
 
         if weight is None:
-            self.weight = torch.empty(shape[0] * bit_width // 8, shape[1], dtype=torch.int8, device=kwargs["device"])
+            self.weight = torch.empty(
+                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
+            )
             self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
         else:
-            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (bit_width - 1)) - 1)).half()
+            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
             self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
+            if weight_bit_width == 4:
+                self.weight = compress_int4_weight(self.weight)
 
         self.weight = Parameter(self.weight, requires_grad=False)
         self.weight_scale = Parameter(self.weight_scale, requires_grad=False)
@@ -68,7 +77,7 @@ class QuantizedRowParallelLinear(RowParallelLinear):
         else:
             input_parallel = scatter_to_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale)
+        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
         # All-reduce across all the partitions.
         output_ = reduce_from_model_parallel_region(output_parallel)
         if self.bias is not None:

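Note: with the reshaped empty-weight layout, an int4 layer stores (out_features, in_features * weight_bit_width // 8) int8 bytes plus one fp16 scale per output row. Back-of-the-envelope numbers for a single layer, with illustrative dimensions only:

out_features, in_features = 4096, 4096
fp16_bytes  = out_features * in_features * 2          # 33,554,432 bytes (32 MiB) for the original fp16 weight
int4_bytes  = out_features * (in_features * 4 // 8)   # 8,388,608 bytes (8 MiB) in the packed int8 layout
scale_bytes = out_features * 2                        # 8 KiB of per-row fp16 scales
print(fp16_bytes, int4_bytes, scale_bytes)
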
+ 2 - 1
requirements.txt

@@ -2,4 +2,5 @@ SwissArmyTransformer>=0.2.11
 icetk
 apex
 scipy
-dataclass_wizard
+dataclass_wizard
+cpm_kernels