# layers.py — int8-quantized model-parallel linear layers (~3.3 KB).
# (Header reconstructed: the original lines here were web-scrape residue —
# vote counts, file size, and a fused run of display line numbers.)
  1. import torch
  2. from torch.nn.parameter import Parameter
  3. from SwissArmyTransformer.mpu import copy_to_model_parallel_region
  4. from SwissArmyTransformer.mpu import gather_from_model_parallel_region
  5. from SwissArmyTransformer.mpu import reduce_from_model_parallel_region
  6. from SwissArmyTransformer.mpu import scatter_to_model_parallel_region
  7. from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
  8. from .functional import W8A16Linear
  9. class QuantizedColumnParallelLinear(ColumnParallelLinear):
  10. def __init__(self, bit_width=8, weight=None, *args, **kwargs):
  11. super(QuantizedColumnParallelLinear, self).__init__(*args, **kwargs)
  12. self.bit_width = bit_width
  13. shape = self.weight.shape
  14. del self.weight
  15. if weight is None:
  16. self.weight = torch.empty(shape[0] * bit_width // 8, shape[1], dtype=torch.int8, device=kwargs["device"])
  17. self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
  18. else:
  19. self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (bit_width - 1)) - 1)).half()
  20. self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
  21. self.weight = Parameter(self.weight, requires_grad=False)
  22. self.weight_scale = Parameter(self.weight_scale, requires_grad=False)
  23. def forward(self, input_):
  24. # Set up backprop all-reduce.
  25. input_parallel = copy_to_model_parallel_region(input_)
  26. # Matrix multiply.
  27. output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale)
  28. if self.bias is not None:
  29. output_parallel = output_parallel + self.bias
  30. if self.gather_output:
  31. # All-gather across the partitions.
  32. output = gather_from_model_parallel_region(output_parallel)
  33. else:
  34. output = output_parallel
  35. return output
  36. class QuantizedRowParallelLinear(RowParallelLinear):
  37. def __init__(self, bit_width=8, weight=None, *args, **kwargs):
  38. super(QuantizedRowParallelLinear, self).__init__(*args, **kwargs)
  39. self.bit_width = bit_width
  40. shape = self.weight.shape
  41. del self.weight
  42. if weight is None:
  43. self.weight = torch.empty(shape[0] * bit_width // 8, shape[1], dtype=torch.int8, device=kwargs["device"])
  44. self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
  45. else:
  46. self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (bit_width - 1)) - 1)).half()
  47. self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
  48. self.weight = Parameter(self.weight, requires_grad=False)
  49. self.weight_scale = Parameter(self.weight_scale, requires_grad=False)
  50. def forward(self, input_):
  51. # Set up backprop all-reduce.
  52. if self.input_is_parallel:
  53. input_parallel = input_
  54. else:
  55. input_parallel = scatter_to_model_parallel_region(input_)
  56. # Matrix multiply.
  57. output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale)
  58. # All-reduce across all the partitions.
  59. output_ = reduce_from_model_parallel_region(output_parallel)
  60. if self.bias is not None:
  61. output = output_ + self.bias
  62. else:
  63. output = output_
  64. return output