```python
import torch
from torch.nn.parameter import Parameter

from SwissArmyTransformer.mpu import copy_to_model_parallel_region
from SwissArmyTransformer.mpu import gather_from_model_parallel_region
from SwissArmyTransformer.mpu import reduce_from_model_parallel_region
from SwissArmyTransformer.mpu import scatter_to_model_parallel_region
from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear

from .functional import W8A16Linear
from kernels import compress_int4_weight
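# Judging by its name and how it is called below, W8A16Linear performs the
# matmul with quantized (8/4-bit) weights and fp16 activations, and
# compress_int4_weight packs two int4 values into each int8 byte (note the
# shape[1] * weight_bit_width // 8 storage width).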
 

class QuantizedColumnParallelLinear(ColumnParallelLinear):
    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
        super(QuantizedColumnParallelLinear, self).__init__(*args, **kwargs)
        self.weight_bit_width = weight_bit_width

        # Replace the full-precision weight allocated by the parent class with
        # int8 storage plus a per-row scale.
        shape = self.weight.shape
        del self.weight

        if weight is None:
            # No source weight given: allocate empty quantized buffers, e.g. to
            # be filled from an already-quantized checkpoint.
            self.weight = torch.empty(
                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
            )
            self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
        else:
            # Symmetric absmax quantization per output row: pick a scale so the
            # largest-magnitude entry maps to 2 ** (weight_bit_width - 1) - 1.
            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
            self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
            if weight_bit_width == 4:
                self.weight = compress_int4_weight(self.weight)

        self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
        self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
 
    def forward(self, input_):
        # Set up backprop all-reduce.
        input_parallel = copy_to_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
        if self.bias is not None:
            output_parallel = output_parallel + self.bias
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        return output
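
# The row-parallel variant below quantizes its weight in exactly the same way;
# only the communication pattern in forward differs (scatter + all-reduce
# instead of copy + all-gather).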
 
class QuantizedRowParallelLinear(RowParallelLinear):
    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
        super(QuantizedRowParallelLinear, self).__init__(*args, **kwargs)
        self.weight_bit_width = weight_bit_width

        shape = self.weight.shape
        del self.weight

        if weight is None:
            self.weight = torch.empty(
                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
            )
            self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
        else:
            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
            self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
            if weight_bit_width == 4:
                self.weight = compress_int4_weight(self.weight)

        self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
        self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
 
    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
        # All-reduce across all the partitions.
        output_ = reduce_from_model_parallel_region(output_parallel)
        if self.bias is not None:
            output = output_ + self.bias
        else:
            output = output_
        return output
```
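
For intuition, here is a minimal, self-contained sketch of the per-row absmax quantization both constructors apply. It uses plain PyTorch with no model parallelism or custom kernels, the tensor shape and bit width are made up for the example, and the explicit dequantization is only illustrative (in the layers above it happens inside W8A16Linear):

```python
import torch

# Toy stand-in for a weight matrix: [out_features, in_features].
weight = torch.randn(8, 16)
weight_bit_width = 8

# Per-row scale: the largest magnitude in each row maps to 2**(b-1) - 1 = 127.
weight_scale = weight.abs().max(dim=-1).values / (2 ** (weight_bit_width - 1) - 1)
quantized = torch.round(weight / weight_scale[:, None]).to(torch.int8)

# Dequantize by re-applying the scale; the per-element error is at most half a
# quantization step, i.e. weight_scale / 2.
dequantized = quantized.float() * weight_scale[:, None]
print((weight - dequantized).abs().max().item())  # small reconstruction error
```

In the real layers the scale is additionally cast to fp16 (`.half()`), and both buffers are wrapped in `Parameter(..., requires_grad=False)`, marking the quantized weights as inference-only.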