import torch

from kernels import extract_weight_to_half


class W8A16Linear(torch.autograd.Function):
    """Custom autograd function for a bias-free linear layer with a quantized weight.

    The weight is passed as ``quant_w`` plus ``scale_w`` and is expanded with
    ``extract_weight_to_half`` on every forward and backward call; the expanded
    weight itself is never saved on ``ctx``.

    NOTE(review): the name suggests 8-bit weights with 16-bit activations, and
    ``extract_weight_to_half`` presumably dequantizes to half precision --
    confirm against the ``kernels`` module.
    """

    @staticmethod
    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
        """Compute ``inp @ weight.T`` where ``weight`` is expanded from ``quant_w``.

        Args:
            ctx: Autograd context; receives the input shape, the quantized
                weight shape, and the bit width for use in ``backward``.
            inp: Input tensor; all leading dimensions are flattened into one
                before the matrix multiply.
            quant_w: Quantized weight; ``quant_w.size(0)`` is used as the
                number of output features.
            scale_w: Scale tensor forwarded to ``extract_weight_to_half``.
            weight_bit_width: Bit width forwarded to ``extract_weight_to_half``.

        Returns:
            Tensor with the leading dimensions of ``inp`` and a last dimension
            of ``quant_w.size(0)``.
        """
        ctx.inp_shape = inp.size()
        ctx.weight_shape = quant_w.size()
        ctx.weight_bit_width = weight_bit_width
        out_features = quant_w.size(0)
        # Collapse every leading dimension so a plain 2-D matmul can be used.
        inp = inp.contiguous().view(-1, inp.size(-1))
        weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
        output = inp.mm(weight.t())
        # Save the flattened input and the quantized weight pieces; the
        # expanded weight is rebuilt in backward instead of being stored.
        ctx.save_for_backward(inp, quant_w, scale_w)
        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return gradients for each of the four inputs of ``forward``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient with respect to the output of ``forward``.

        Returns:
            A 4-tuple ``(grad_input, grad_weight, None, None)`` aligned with
            ``(inp, quant_w, scale_w, weight_bit_width)``. No gradient is
            produced for ``scale_w`` or ``weight_bit_width``.
        """
        inp, quant_w, scale_w = ctx.saved_tensors
        weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
        grad_output = grad_output.contiguous().view(-1, weight.size(0))
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(inp)
        # NOTE(review): the view below only succeeds when the expanded weight
        # has the same number of elements as quant_w -- confirm for packed
        # (sub-8-bit) weight layouts.
        # Autograd requires exactly one return value per forward input
        # (excluding ctx), so both scale_w and weight_bit_width get a None.
        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None