
functional.py

import torch

from kernels import extract_weight_to_half


class W8A16Linear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
        # Stash shapes and bit width so backward can restore and reuse them.
        ctx.inp_shape = inp.size()
        ctx.weight_shape = quant_w.size()
        ctx.weight_bit_width = weight_bit_width
        out_features = quant_w.size(0)
        # Flatten all leading dimensions into a single batch dimension.
        inp = inp.contiguous().view(-1, inp.size(-1))
        # Dequantize the packed weight to fp16 for the matmul.
        weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
        output = inp.mm(weight.t())
        # Save the quantized weight (not the fp16 copy) to keep memory low.
        ctx.save_for_backward(inp, quant_w, scale_w)
        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        inp, quant_w, scale_w = ctx.saved_tensors
        # Re-dequantize on the fly rather than keeping the fp16 weight around.
        weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
        grad_output = grad_output.contiguous().view(-1, weight.size(0))
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(inp)
        # backward must return one gradient per forward input
        # (inp, quant_w, scale_w, weight_bit_width); the last two get None.
        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
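
For reference, a minimal usage sketch: an autograd.Function is invoked through its apply method, so a quantized linear layer would call W8A16Linear.apply with the packed weight and per-channel scales. The shapes and dtypes below (8-bit weights stored as int8, one fp16 scale per output row) are assumptions inferred from how quant_w and scale_w are used above; the actual packing is defined by the CUDA kernel behind extract_weight_to_half, not by this file.

# A minimal sketch, assuming int8 storage for 8-bit weights and fp16
# per-output-channel scales (hypothetical shapes, not fixed by this file).
inp = torch.randn(4, 16, 1024, dtype=torch.half, device="cuda")    # (batch, seq, in_features)
quant_w = torch.randint(-128, 128, (4096, 1024), dtype=torch.int8, device="cuda")
scale_w = torch.rand(4096, dtype=torch.half, device="cuda")        # one scale per output row
out = W8A16Linear.apply(inp, quant_w, scale_w, 8)                  # -> (4, 16, 4096)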