123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- import pkg_resources
- import torch
- import ctypes
- from typing import List
- from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
- RESOURCE_PACKAGE_NAME = __name__
class Kernel:
    """Wrapper around a packaged fatbin: loads the binary resource lazily and
    exposes each named kernel function as an attribute of this object."""

    def __init__(self, filename: str, function_names: List[str]):
        resource_name = filename + ".fatbin"
        if not pkg_resources.resource_exists(RESOURCE_PACKAGE_NAME, resource_name):
            raise RuntimeError("File `%s` not found in `%s`" % (resource_name, RESOURCE_PACKAGE_NAME))
        self.filename = resource_name
        self.code = pkg_resources.resource_string(RESOURCE_PACKAGE_NAME, resource_name)
        self._function_names = function_names
        # Compilation/loading is deferred until a kernel is first launched.
        self._cmodule = LazyKernelCModule(self.code)
        # Attach one callable per kernel, e.g. self.int4WeightCompression(...).
        for fn_name in self._function_names:
            setattr(self, fn_name, KernelFunction(self._cmodule, fn_name))
# Module-level handle to the "quantization.fatbin" resource. Each listed
# function name becomes a launchable attribute on `kernels`
# (e.g. kernels.int4WeightCompression), used by the helpers below.
kernels = Kernel(
    "quantization",
    [
        "int4WeightCompression",
        "int4WeightExtractionFloat",
        "int4WeightExtractionHalf",
        "int8WeightExtractionFloat",
        "int8WeightExtractionHalf",
    ],
)
def compress_int4_weight(weight: torch.Tensor):  # (n, m)
    """Pack int4 values stored one-per-int8 into half-width int8 storage.

    Args:
        weight: 2-D int8 CUDA tensor of shape (n, m); m must be even so that
            two int4 values fit into each output byte.

    Returns:
        int8 tensor of shape (n, m // 2) on the current CUDA device.

    Raises:
        ValueError: if the second dimension of ``weight`` is odd.
    """
    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        # Was a bare `assert`, which is silently stripped under `python -O`;
        # validate explicitly instead.
        if m % 2 != 0:
            raise ValueError("compress_int4_weight requires an even second dimension, got %d" % m)
        m = m // 2
        # "cuda" resolves to weight.device inside the `with` context above.
        out = torch.empty(n, m, dtype=torch.int8, device="cuda")
        stream = torch.cuda.current_stream()
        gridDim = (n, 1, 1)  # one block per row
        # Warp-aligned thread count, capped at CUDA's 1024-threads-per-block limit.
        blockDim = (min(round_up(m, 32), 1024), 1, 1)
        kernels.int4WeightCompression(
            gridDim,
            blockDim,
            0,
            stream,
            [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
        )
        return out
def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
    """Dequantize packed int4/int8 weights into a half-precision tensor.

    Args:
        weight: 2-D int8 CUDA tensor of shape (n, m) holding the quantized
            values (for 4-bit, two values packed per byte).
        scale_list: per-row scale factors; presumably shape (n,) in half
            precision — TODO(review): confirm against the kernel source.
        source_bit_width: 4 or 8.

    Returns:
        torch.half tensor of shape (n, m * (8 // source_bit_width)) on the
        current CUDA device.

    Raises:
        ValueError: if ``source_bit_width`` is neither 4 nor 8.
    """
    if source_bit_width == 8:
        func = kernels.int8WeightExtractionHalf
    elif source_bit_width == 4:
        func = kernels.int4WeightExtractionHalf
    else:
        # Was `assert False`, which is stripped under `python -O` and would
        # leave `func` unbound; raise an explicit error instead.
        raise ValueError("Unsupported bit-width")
    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        # "cuda" resolves to weight.device inside the `with` context above.
        out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda")
        stream = torch.cuda.current_stream()
        gridDim = (n, 1, 1)  # one block per row
        # Warp-aligned thread count, capped at CUDA's 1024-threads-per-block limit.
        blockDim = (min(round_up(m, 32), 1024), 1, 1)
        func(
            gridDim,
            blockDim,
            0,
            stream,
            [
                ctypes.c_void_p(weight.data_ptr()),
                ctypes.c_void_p(scale_list.data_ptr()),
                ctypes.c_void_p(out.data_ptr()),
                ctypes.c_int32(n),
                ctypes.c_int32(m),
            ],
        )
        return out
if __name__ == "__main__":
    # Manual smoke test: requires a CUDA device and the packaged fatbin.
    # Round-trips a small random matrix through int4 packing and extraction.
    demo_weight = torch.randn(4, 32).to(torch.int8).cuda()
    demo_scale = torch.ones(demo_weight.size(0)).to(torch.half).cuda()
    print(demo_weight)
    packed = compress_int4_weight(demo_weight)
    print(packed)
    restored = extract_weight_to_half(packed, demo_scale, source_bit_width=4)
    print(restored)
|