@@ -8,22 +8,27 @@ from SwissArmyTransformer.mpu import scatter_to_model_parallel_region
 from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear

 from .functional import W8A16Linear
+from kernels import compress_int4_weight


 class QuantizedColumnParallelLinear(ColumnParallelLinear):
-    def __init__(self, bit_width=8, weight=None, *args, **kwargs):
+    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
         super(QuantizedColumnParallelLinear, self).__init__(*args, **kwargs)
-        self.bit_width = bit_width
+        self.weight_bit_width = weight_bit_width

         shape = self.weight.shape
         del self.weight

         if weight is None:
-            self.weight = torch.empty(shape[0] * bit_width // 8, shape[1], dtype=torch.int8, device=kwargs["device"])
+            self.weight = torch.empty(
+                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
+            )
             self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
         else:
-            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (bit_width - 1)) - 1)).half()
+            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
             self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
+            if weight_bit_width == 4:
+                self.weight = compress_int4_weight(self.weight)

         self.weight = Parameter(self.weight, requires_grad=False)
         self.weight_scale = Parameter(self.weight_scale, requires_grad=False)
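In the else branch above, an existing fp16 weight is quantized symmetrically per output row: the scale is the row's absolute maximum divided by the largest representable value (127 for int8, 7 for int4), and the stored weight is the rounded ratio. A minimal standalone sketch of that scheme, with illustrative names that are not part of the repository:

import torch

def quantize_rowwise(weight, weight_bit_width):
    # Symmetric per-row absmax quantization, mirroring the else branch above.
    # weight: [out_features, in_features] in fp16/fp32.
    qmax = 2 ** (weight_bit_width - 1) - 1                    # 127 for int8, 7 for int4
    weight_scale = (weight.abs().max(dim=-1).values / qmax).half()
    qweight = torch.round(weight / weight_scale[:, None]).to(torch.int8)
    return qweight, weight_scale

def dequantize_rowwise(qweight, weight_scale):
    # Approximate reconstruction of the original fp16 weight.
    return qweight.half() * weight_scale[:, None]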
@@ -32,7 +37,7 @@ class QuantizedColumnParallelLinear(ColumnParallelLinear):
         # Set up backprop all-reduce.
         input_parallel = copy_to_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale)
+        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
         if self.bias is not None:
             output_parallel = output_parallel + self.bias
         if self.gather_output:
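W8A16Linear lives in .functional and is not shown in this diff; the only visible change is that forward now also passes self.weight_bit_width, so the kernel knows whether the stored weight is plain int8 or int4-packed. As a rough mental model only (not the project's fused kernel), its forward is equivalent to dequantizing the weight back to the activation dtype and calling an ordinary linear:

import torch
import torch.nn.functional as F

def w8a16_linear_reference(input_, qweight, weight_scale):
    # Reference semantics only: dequantize the int8 rows, then a plain matmul.
    # The real autograd Function fuses this on the GPU and, for
    # weight_bit_width == 4, presumably unpacks the compressed weight first.
    weight = qweight.to(input_.dtype) * weight_scale[:, None]
    return F.linear(input_, weight)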
@@ -44,19 +49,23 @@ class QuantizedColumnParallelLinear(ColumnParallelLinear):


 class QuantizedRowParallelLinear(RowParallelLinear):
-    def __init__(self, bit_width=8, weight=None, *args, **kwargs):
+    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
         super(QuantizedRowParallelLinear, self).__init__(*args, **kwargs)
-        self.bit_width = bit_width
+        self.weight_bit_width = weight_bit_width

         shape = self.weight.shape
         del self.weight

         if weight is None:
-            self.weight = torch.empty(shape[0] * bit_width // 8, shape[1], dtype=torch.int8, device=kwargs["device"])
+            self.weight = torch.empty(
+                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
+            )
             self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
         else:
-            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (bit_width - 1)) - 1)).half()
+            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
             self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
+            if weight_bit_width == 4:
+                self.weight = compress_int4_weight(self.weight)

         self.weight = Parameter(self.weight, requires_grad=False)
         self.weight_scale = Parameter(self.weight_scale, requires_grad=False)
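compress_int4_weight is provided by the repository's kernels module and is not part of this diff. From its use above it must take int8 values that fit in 4 bits and pack two of them per byte, which is why the 4-bit buffer allocated in the weight-is-None branch has shape[1] * weight_bit_width // 8 columns. A plain-PyTorch sketch of such packing (an assumption about the kernel's layout, with the even column in the high nibble; the real CUDA kernel may order nibbles differently):

import torch

def compress_int4_weight_reference(weight):
    # weight: [n, m] int8 with values in [-8, 7]; returns [n, m // 2] int8.
    assert weight.dtype == torch.int8 and weight.shape[-1] % 2 == 0
    nibbles = weight.to(torch.int32) & 0x0F                    # two's-complement low nibble
    packed = (nibbles[:, 0::2] << 4) | nibbles[:, 1::2]        # values in 0..255
    packed = torch.where(packed < 128, packed, packed - 256)   # back into the int8 range
    return packed.to(torch.int8)

def decompress_int4_weight_reference(packed):
    # Inverse: split each byte into two nibbles and sign-extend them.
    byte = packed.to(torch.int32) & 0xFF
    nibbles = torch.stack(((byte >> 4) & 0x0F, byte & 0x0F), dim=-1).reshape(packed.shape[0], -1)
    return torch.where(nibbles < 8, nibbles, nibbles - 16).to(torch.int8)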
@@ -68,7 +77,7 @@ class QuantizedRowParallelLinear(RowParallelLinear):
         else:
             input_parallel = scatter_to_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale)
+        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
         # All-reduce across all the partitions.
         output_ = reduce_from_model_parallel_region(output_parallel)
         if self.bias is not None:
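Taken together, each quantized parallel linear layer now stores an int8 (or int4-packed) weight plus a per-row fp16 scale, and dequantizes on the fly inside forward. A quick self-contained numerics check of that scheme (done in fp32 here so it runs on CPU; the layer itself keeps the scale in fp16):

import torch

torch.manual_seed(0)
weight = torch.randn(16, 32)   # [out_features, in_features]
x = torch.randn(4, 32)

# Quantize exactly as __init__ does for weight_bit_width == 8.
scale = weight.abs().max(dim=-1).values / 127
qweight = torch.round(weight / scale[:, None]).to(torch.int8)

# Reference forward pass: dequantize, then an ordinary matmul.
y_full = x @ weight.t()
y_int8 = x @ (qweight.float() * scale[:, None]).t()

print((y_full - y_int8).abs().max())   # small quantization error expected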