__init__.py
import pkg_resources
import torch
import ctypes

from typing import List
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up

RESOURCE_PACKAGE_NAME = __name__


class Kernel:
    """Loads a compiled CUDA fatbin shipped with this package and exposes
    its kernel entry points as callable attributes."""

    def __init__(self, filename: str, function_names: List[str]):
        filename = filename + ".fatbin"
        if not pkg_resources.resource_exists(RESOURCE_PACKAGE_NAME, filename):
            raise RuntimeError("File `%s` not found in `%s`" % (filename, RESOURCE_PACKAGE_NAME))
        self.filename = filename
        self.code = pkg_resources.resource_string(RESOURCE_PACKAGE_NAME, filename)
        self._function_names = function_names
        self._cmodule = LazyKernelCModule(self.code)

        for name in self._function_names:
            setattr(self, name, KernelFunction(self._cmodule, name))

# Quantization kernels bundled with the package as `quantization.fatbin`.
kernels = Kernel(
    "quantization",
    [
        "int4WeightCompression",
        "int4WeightExtractionFloat",
        "int4WeightExtractionHalf",
        "int8WeightExtractionFloat",
        "int8WeightExtractionHalf",
    ],
)

def compress_int4_weight(weight: torch.Tensor):  # (n, m)
    """Pack an (n, m) int8 tensor holding one int4 value per byte into an
    (n, m // 2) int8 tensor holding two int4 values per byte."""
    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        assert m % 2 == 0
        m = m // 2
        # "cuda" resolves to weight.device inside the torch.cuda.device context.
        out = torch.empty(n, m, dtype=torch.int8, device="cuda")
        stream = torch.cuda.current_stream()

        # One thread block per row, up to 1024 threads per block.
        gridDim = (n, 1, 1)
        blockDim = (min(round_up(m, 32), 1024), 1, 1)

        kernels.int4WeightCompression(
            gridDim,
            blockDim,
            0,  # dynamic shared memory size in bytes
            stream,
            [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
        )
        return out
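
# For reference, a pure-PyTorch sketch of the packing the fatbin kernel is
# expected to perform. The nibble layout (even-indexed value in the high
# nibble, odd-indexed value in the low nibble) is an assumption about the
# kernel, not something this file specifies, and `_compress_int4_reference`
# is a hypothetical helper for CPU-side testing only.
def _compress_int4_reference(weight: torch.Tensor) -> torch.Tensor:
    assert weight.size(1) % 2 == 0
    w = weight.to(torch.int8)
    high = w[:, 0::2]  # even columns -> high nibble
    low = w[:, 1::2]   # odd columns -> low nibble
    return (high << 4) | (low & 0x0F)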

def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
    """Dequantize an (n, m) int8 tensor of packed int8/int4 weights to fp16,
    scaling each row by the matching entry of `scale_list`."""
    if source_bit_width == 8:
        func = kernels.int8WeightExtractionHalf
    elif source_bit_width == 4:
        func = kernels.int4WeightExtractionHalf
    else:
        assert False, "Unsupported bit-width"

    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        # Each int8 byte expands to 8 // source_bit_width fp16 values.
        out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda")
        stream = torch.cuda.current_stream()

        gridDim = (n, 1, 1)
        blockDim = (min(round_up(m, 32), 1024), 1, 1)

        func(
            gridDim,
            blockDim,
            0,
            stream,
            [
                ctypes.c_void_p(weight.data_ptr()),
                ctypes.c_void_p(scale_list.data_ptr()),
                ctypes.c_void_p(out.data_ptr()),
                ctypes.c_int32(n),
                ctypes.c_int32(m),
            ],
        )
        return out
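
# Likewise, a pure-PyTorch sketch of int4 extraction under the same assumed
# nibble layout: arithmetic shifts sign-extend each nibble, and each row is
# multiplied by its scale. `_extract_int4_reference` is a hypothetical helper,
# not part of the kernel API.
def _extract_int4_reference(weight: torch.Tensor, scale_list: torch.Tensor) -> torch.Tensor:
    w = weight.to(torch.int8)
    high = w >> 4          # arithmetic shift sign-extends the high nibble
    low = (w << 4) >> 4    # shift up and back down to sign-extend the low nibble
    out = torch.empty(w.size(0), w.size(1) * 2, dtype=torch.half, device=weight.device)
    out[:, 0::2] = high.to(torch.half)
    out[:, 1::2] = low.to(torch.half)
    return out * scale_list.to(torch.half).unsqueeze(-1)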

if __name__ == "__main__":
    # Smoke test: pack a small int8 matrix into int4, then extract it back
    # to fp16 with unit scales.
    weight = torch.randn(4, 32).to(torch.int8).cuda()
    scale = torch.ones(weight.size(0)).to(torch.half).cuda()

    print(weight)
    b = compress_int4_weight(weight)
    print(b)

    a = extract_weight_to_half(b, scale, source_bit_width=4)
    print(a)
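
    # Hedged sanity check: if the nibble-layout assumption above holds, the
    # CUDA extraction should agree with the pure-PyTorch reference.
    ref = _extract_int4_reference(b, scale)
    print(torch.equal(a, ref))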