```python
import torch

from kernels import extract_weight_to_half  # custom CUDA dequantization kernel


class W8A16Linear(torch.autograd.Function):
    """Linear matmul with fp16 activations and quantized (int8/int4) weights."""

    @staticmethod
    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
        ctx.inp_shape = inp.size()
        ctx.weight_bit_width = weight_bit_width
        out_features = quant_w.size(0)
        # Flatten all leading dimensions so the matmul is a plain 2-D GEMM.
        inp = inp.contiguous().view(-1, inp.size(-1))
        # Dequantize the packed weights to half precision for this matmul only.
        weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
        # Record the *dequantized* shape: with 4-bit weights packed two per
        # byte, it differs from quant_w.size(), and backward needs this one.
        ctx.weight_shape = weight.size()
        output = inp.mm(weight.t())
        ctx.save_for_backward(inp, quant_w, scale_w)
        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        inp, quant_w, scale_w = ctx.saved_tensors
        # Re-dequantize instead of caching the fp16 weight: trades a little
        # compute for keeping only the compressed copy resident in memory.
        weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
        grad_output = grad_output.contiguous().view(-1, weight.size(0))
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(inp)
        # forward took four inputs, so backward must return four gradients;
        # scale_w and weight_bit_width are not differentiated.
        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
```
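
The imported `extract_weight_to_half` is a CUDA kernel that expands the packed integer weights back to `float16`. As a mental model only, here is a pure-PyTorch sketch of its 8-bit behavior; the function name `extract_weight_to_half_ref` and the restriction to 8 bits are my assumptions, not part of the source (the real kernel also unpacks 4-bit weights stored two per byte):

```python
import torch

def extract_weight_to_half_ref(quant_w: torch.Tensor, scale_w: torch.Tensor,
                               weight_bit_width: int) -> torch.Tensor:
    # Reference for the 8-bit case only: each int8 row is rescaled back to
    # float16 by its per-output-channel scale.
    assert weight_bit_width == 8, "4-bit weights need unpacking first"
    return quant_w.half() * scale_w.half().unsqueeze(-1)
```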
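
Because `W8A16Linear` is an `autograd.Function`, callers invoke it through `.apply`. A minimal sketch of a module wrapping it, assuming symmetric per-output-channel quantization (the `QuantizedLinear` class below is illustrative, not code from the source):

```python
import torch

class QuantizedLinear(torch.nn.Module):
    # Illustrative wrapper: quantizes an fp16 weight matrix symmetrically per
    # output channel, then dispatches matmuls to W8A16Linear.
    def __init__(self, weight: torch.Tensor, weight_bit_width: int = 8):
        super().__init__()
        # Largest representable magnitude for a signed b-bit integer.
        qmax = 2 ** (weight_bit_width - 1) - 1
        scale = weight.abs().max(dim=-1).values / qmax
        self.weight_bit_width = weight_bit_width
        self.register_buffer("weight", torch.round(weight / scale.unsqueeze(-1)).to(torch.int8))
        self.register_buffer("weight_scale", scale.half())

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        return W8A16Linear.apply(inp, self.weight, self.weight_scale, self.weight_bit_width)
```

The design point is that the fp16 weight exists only transiently inside each forward and backward call; only the int8 copy and its per-channel scales stay resident on the GPU.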