layers.py

import torch
from torch.nn.parameter import Parameter

# Tensor-model-parallel communication primitives and base layers.
from SwissArmyTransformer.mpu import copy_to_model_parallel_region
from SwissArmyTransformer.mpu import gather_from_model_parallel_region
from SwissArmyTransformer.mpu import reduce_from_model_parallel_region
from SwissArmyTransformer.mpu import scatter_to_model_parallel_region
from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear

# Quantized matmul autograd function and the int4 weight-packing kernel.
from .functional import W8A16Linear
from kernels import compress_int4_weight
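
# These layers mirror ColumnParallelLinear and RowParallelLinear, but store
# the weight as int8 (or as int4, packed two values per byte) together with a
# per-output-channel scale, and dequantize on the fly via W8A16Linear in
# forward().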


class QuantizedColumnParallelLinear(ColumnParallelLinear):
    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
        super(QuantizedColumnParallelLinear, self).__init__(*args, **kwargs)
        self.weight_bit_width = weight_bit_width

        # Replace the full-precision weight allocated by the parent class with
        # quantized storage of the same logical shape.
        shape = self.weight.shape
        del self.weight

        if weight is None:
            # No weight given: allocate empty int8 storage to be filled later,
            # e.g. when loading a quantized checkpoint. int4 weights occupy
            # half the columns, hence `shape[1] * weight_bit_width // 8`.
            self.weight = torch.empty(
                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
            )
            self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
        else:
            # Symmetric per-row quantization: divide each output channel by its
            # absmax so values round into [-(2^(b-1) - 1), 2^(b-1) - 1].
            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
            self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
            if weight_bit_width == 4:
                self.weight = compress_int4_weight(self.weight)

        # Quantized weights are inference-only, so freeze them.
        self.weight = Parameter(self.weight, requires_grad=False)
        self.weight_scale = Parameter(self.weight_scale, requires_grad=False)
    def forward(self, input_):
        # Set up backprop all-reduce.
        input_parallel = copy_to_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
        if self.bias is not None:
            output_parallel = output_parallel + self.bias
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        return output


class QuantizedRowParallelLinear(RowParallelLinear):
    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
        super(QuantizedRowParallelLinear, self).__init__(*args, **kwargs)
        self.weight_bit_width = weight_bit_width

        # Same quantized storage scheme as QuantizedColumnParallelLinear.
        shape = self.weight.shape
        del self.weight

        if weight is None:
            self.weight = torch.empty(
                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
            )
            self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
        else:
            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
            self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
            if weight_bit_width == 4:
                self.weight = compress_int4_weight(self.weight)

        self.weight = Parameter(self.weight, requires_grad=False)
        self.weight_scale = Parameter(self.weight_scale, requires_grad=False)
    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
        # All-reduce across all the partitions.
        output_ = reduce_from_model_parallel_region(output_parallel)
        if self.bias is not None:
            output = output_ + self.bias
        else:
            output = output_
        return output
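

# The block below is a minimal, CPU-only sketch of the quantization scheme
# used above, added for illustration. It assumes W8A16Linear computes the
# equivalent of `input @ (weight * weight_scale[:, None]).t()`, i.e. that the
# int8 weight is dequantized with its per-row scale before the matmul; the
# fused kernel is the authoritative implementation. `pack_int4_reference` is
# a hypothetical stand-in for `compress_int4_weight`, with an assumed nibble
# layout.
if __name__ == "__main__":
    torch.manual_seed(0)
    weight_bit_width = 8
    weight = torch.randn(4, 8)
    x = torch.randn(2, 8)

    # Per-output-channel symmetric quantization, as in __init__ above: each
    # row is scaled by its absmax so values land in [-127, 127] for int8.
    weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)
    qweight = torch.round(weight / weight_scale[:, None]).to(torch.int8)

    # Dequantize and compare; per-element round-trip error is at most scale / 2.
    dequant = qweight.to(weight.dtype) * weight_scale[:, None]
    print("max weight error:", (weight - dequant).abs().max().item())
    print("max output error:", (x @ weight.t() - x @ dequant.t()).abs().max().item())

    def pack_int4_reference(w):
        # Hypothetical packing of two signed int4 values per int8 byte
        # (high nibble = even column, low nibble = odd column; an assumed
        # layout, not necessarily what the real kernel produces).
        high = (w[:, 0::2].to(torch.int32) & 0xF) << 4
        low = w[:, 1::2].to(torch.int32) & 0xF
        return (high | low).to(torch.int8)

    q4_scale = weight.abs().max(dim=-1).values / 7  # 2 ** (4 - 1) - 1
    q4 = torch.round(weight / q4_scale[:, None]).to(torch.int8)
    print("int4 packed shape:", tuple(pack_int4_reference(q4).shape))  # columns halved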