// quantization.cu
  1. #include <cuda_fp16.h>
  /**
   * Unpacks signed 4-bit weights (two per byte) and scales them into type T.
   *
   * Work split: block `blockIdx.x` handles the k packed bytes
   * weight[blockIdx.x * k, blockIdx.x * k + k); the threads of the block walk
   * that range with a step of blockDim.x. Each packed byte i yields two values:
   *   output[2 * i]     = sign-extended high nibble * scale_list[blockIdx.x]
   *   output[2 * i + 1] = sign-extended low nibble  * scale_list[blockIdx.x]
   *
   * @param weight      packed input; each int8_t carries two 4-bit values
   * @param scale_list  scale factors, indexed by blockIdx.x
   * @param output      destination; written at indices 2*i and 2*i+1
   * @param n           not read by this function (presumably the number of
   *                    rows -- TODO confirm against the launcher)
   * @param k           number of packed bytes processed by each block
   */
  template<typename T>
  __device__ void
  int4WeightExtractionDevice(const int8_t* weight,
                             const T* scale_list,
                             T* output,
                             const int n,
                             const int k)
  {
      if (k <= 0) {
          return;  // empty range: nothing to unpack
      }

      // Index in 64 bits: blockIdx.x * k and 2 * i can exceed the 32-bit range
      // for large weight buffers and would otherwise wrap silently.
      const size_t row_begin = static_cast<size_t>(blockIdx.x) * static_cast<size_t>(k);
      const size_t row_end = row_begin + static_cast<size_t>(k);

      // The scale depends only on the block index, so load it once.
      const T scale = scale_list[blockIdx.x];

      for (size_t i = row_begin + threadIdx.x; i < row_end; i += blockDim.x) {
          const int8_t packed = weight[i];

          // High nibble: a right shift of the signed byte keeps its sign
          // (relies on the compiler's arithmetic shift for negative values).
          const int8_t high = packed >> 4;

          // Low nibble: move it into the top four bits using an unsigned shift
          // (avoids left-shifting a negative value), then sign-extend it back.
          const int8_t low_in_top = static_cast<int8_t>(static_cast<unsigned int>(packed) << 4);
          const int8_t low = low_in_top >> 4;

          output[i * 2] = T(high) * scale;
          output[i * 2 + 1] = T(low) * scale;
      }
  }
  /**
   * Packs pairs of 4-bit values into single bytes.
   *
   * Work split: block `blockIdx.x` produces the k output bytes
   * output[blockIdx.x * k, blockIdx.x * k + k); the threads of the block walk
   * that range with a step of blockDim.x. For each output byte i:
   *   bits 7..4 = low four bits of input[2 * i]
   *   bits 3..0 = low four bits of input[2 * i + 1]
   *
   * @param input   source values; two consecutive entries feed one output byte
   * @param output  destination for the packed bytes
   * @param n       not read by this function (presumably the number of rows --
   *                TODO confirm against the launcher)
   * @param k       number of packed bytes produced by each block
   */
  __device__ void
  int4WeightCompressionDevice(const int8_t* input,
                              int8_t* output,
                              const int n,
                              const int k)
  {
      if (k <= 0) {
          return;  // empty range: nothing to pack
      }

      // Index in 64 bits: blockIdx.x * k and 2 * i can exceed the 32-bit range
      // for large buffers and would otherwise wrap silently.
      const size_t row_begin = static_cast<size_t>(blockIdx.x) * static_cast<size_t>(k);
      const size_t row_end = row_begin + static_cast<size_t>(k);

      for (size_t i = row_begin + threadIdx.x; i < row_end; i += blockDim.x) {
          // Work on the raw bit patterns as unsigned values so that the left
          // shift never operates on a negative number.
          const unsigned int high = static_cast<unsigned char>(input[i * 2]);
          const unsigned int low = static_cast<unsigned char>(input[i * 2 + 1]);

          // Keep only the low eight bits, then store them back as a signed byte.
          const unsigned int packed = ((high << 4) | (low & 0x0Fu)) & 0xFFu;
          output[i] = static_cast<int8_t>(packed);
      }
  }
  /**
   * Converts signed 8-bit weights to type T and applies a per-block scale.
   *
   * Work split: block `blockIdx.x` handles the k elements
   * weight[blockIdx.x * k, blockIdx.x * k + k); the threads of the block walk
   * that range with a step of blockDim.x. For each element i:
   *   output[i] = T(weight[i]) * scale_list[blockIdx.x]
   *
   * @param weight      quantized input values
   * @param scale_list  scale factors, indexed by blockIdx.x
   * @param output      destination, same indexing as `weight`
   * @param n           not read by this function (presumably the number of
   *                    rows -- TODO confirm against the launcher)
   * @param k           number of elements processed by each block
   */
  template<typename T>
  __device__ void
  int8WeightExtractionDevice(const int8_t* weight,
                             const T* scale_list,
                             T* output,
                             const int n,
                             const int k)
  {
      if (k <= 0) {
          return;  // empty range: nothing to convert
      }

      // Index in 64 bits: blockIdx.x * k can exceed the 32-bit range for large
      // weight buffers and would otherwise wrap silently.
      const size_t row_begin = static_cast<size_t>(blockIdx.x) * static_cast<size_t>(k);
      const size_t row_end = row_begin + static_cast<size_t>(k);

      // The scale depends only on the block index, so load it once.
      const T scale = scale_list[blockIdx.x];

      for (size_t i = row_begin + threadIdx.x; i < row_end; i += blockDim.x) {
          output[i] = T(weight[i]) * scale;
      }
  }
  /**
   * Kernel entry point (C linkage) for unpacking 4-bit weights into `half`.
   *
   * Passes every argument straight through to
   * int4WeightExtractionDevice<half>, which indexes its work by blockIdx.x
   * and steps through it by blockDim.x (x dimension only).
   */
  extern "C" __global__ void int4WeightExtractionHalf(
      const int8_t* weight,
      const half* scale_list,
      half* output,
      const int n,
      const int k)
  {
      int4WeightExtractionDevice<half>(weight, scale_list, output, n, k);
  }
  /**
   * Kernel entry point (C linkage) for unpacking 4-bit weights into `float`.
   *
   * Passes every argument straight through to
   * int4WeightExtractionDevice<float>, which indexes its work by blockIdx.x
   * and steps through it by blockDim.x (x dimension only).
   */
  extern "C" __global__ void int4WeightExtractionFloat(
      const int8_t* weight,
      const float* scale_list,
      float* output,
      const int n,
      const int k)
  {
      int4WeightExtractionDevice<float>(weight, scale_list, output, n, k);
  }
  /**
   * Kernel entry point (C linkage) for converting 8-bit weights into `half`.
   *
   * Passes every argument straight through to
   * int8WeightExtractionDevice<half>, which indexes its work by blockIdx.x
   * and steps through it by blockDim.x (x dimension only).
   */
  extern "C" __global__ void int8WeightExtractionHalf(
      const int8_t* weight,
      const half* scale_list,
      half* output,
      const int n,
      const int k)
  {
      int8WeightExtractionDevice<half>(weight, scale_list, output, n, k);
  }
  /**
   * Kernel entry point (C linkage) for converting 8-bit weights into `float`.
   *
   * Passes every argument straight through to
   * int8WeightExtractionDevice<float>, which indexes its work by blockIdx.x
   * and steps through it by blockDim.x (x dimension only).
   */
  extern "C" __global__ void int8WeightExtractionFloat(
      const int8_t* weight,
      const float* scale_list,
      float* output,
      const int n,
      const int k)
  {
      int8WeightExtractionDevice<float>(weight, scale_list, output, n, k);
  }
  /**
   * Kernel entry point (C linkage) for packing pairs of 4-bit values into bytes.
   *
   * Passes every argument straight through to int4WeightCompressionDevice,
   * which indexes its work by blockIdx.x and steps through it by blockDim.x
   * (x dimension only).
   */
  extern "C" __global__ void int4WeightCompression(
      const int8_t* input,
      int8_t* output,
      const int n,
      const int k)
  {
      int4WeightCompressionDevice(input, output, n, k);
  }