functional.py

import torch


class W8A16Linear(torch.autograd.Function):
    """Linear layer with int8 weights and fp16 activations (W8A16)."""

    @staticmethod
    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor):
        ctx.inp_shape = inp.size()
        ctx.weight_shape = quant_w.size()
        out_features = quant_w.size(0)
        # Flatten all leading dimensions into a single batch dimension.
        inp = inp.contiguous().view(-1, quant_w.size(1))
        # Dequantize: int8 weights scaled per output channel back to fp16.
        weight = quant_w.to(torch.half) * scale_w[:, None]
        output = inp.mm(weight.t())
        ctx.save_for_backward(inp, quant_w, scale_w)
        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        inp, quant_w, scale_w = ctx.saved_tensors
        # Re-dequantize instead of saving the fp16 weight copy,
        # trading a little compute for activation memory.
        weight = quant_w.to(torch.half) * scale_w[:, None]
        grad_output = grad_output.contiguous().view(-1, weight.size(0))
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(inp)
        # No gradient for scale_w.
        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
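

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original file).
# It quantizes a random fp16 weight matrix to int8 with symmetric
# per-output-channel scales, then runs it through W8A16Linear, checking
# that gradients flow back to the input. A CUDA device is assumed, since
# fp16 matmul is only reliably supported there.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    in_features, out_features = 64, 32
    fp_weight = torch.randn(out_features, in_features,
                            device="cuda", dtype=torch.half)

    # Symmetric per-row quantization: one fp16 scale per output channel,
    # chosen so the largest magnitude in each row maps to 127.
    scale_w = fp_weight.abs().max(dim=-1).values / 127.0
    quant_w = torch.round(fp_weight / scale_w[:, None]).to(torch.int8)

    inp = torch.randn(4, in_features, device="cuda",
                      dtype=torch.half, requires_grad=True)
    out = W8A16Linear.apply(inp, quant_w, scale_w)
    out.sum().backward()
    print(out.shape, inp.grad.shape)  # torch.Size([4, 32]) torch.Size([4, 64])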