__init__.py

import torch

from .layers import QuantizedColumnParallelLinear
from .layers import QuantizedRowParallelLinear


def quantize(model, weight_bit_width):
    """Replace the fp16 linear layers in each transformer block with quantized ones."""
    for layer in model.transformer.layers:
        # Attention: column-parallel QKV projection, row-parallel output projection.
        layer.attention.query_key_value = QuantizedColumnParallelLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.attention.query_key_value.weight.to(torch.cuda.current_device()),
            input_size=layer.attention.query_key_value.input_size,
            output_size=layer.attention.query_key_value.output_size,
            bias=True,
            gather_output=False,
            params_dtype=torch.half,
            name="query_key_value",
            skip_init=True,
            device=layer.attention.query_key_value.weight.device,
        )
        layer.attention.dense = QuantizedRowParallelLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
            input_size=layer.attention.dense.input_size,
            output_size=layer.attention.dense.output_size,
            bias=True,
            input_is_parallel=True,
            params_dtype=torch.half,
            name="dense",
            skip_init=True,
            device=layer.attention.dense.weight.device,
        )
        # MLP: column-parallel h -> 4h projection, row-parallel 4h -> h projection.
        layer.mlp.dense_h_to_4h = QuantizedColumnParallelLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
            input_size=layer.mlp.dense_h_to_4h.input_size,
            output_size=layer.mlp.dense_h_to_4h.output_size,
            bias=True,
            gather_output=False,
            params_dtype=torch.half,
            name="dense_h_to_4h",
            skip_init=True,
            device=layer.mlp.dense_h_to_4h.weight.device,
        )
        layer.mlp.dense_4h_to_h = QuantizedRowParallelLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
            input_size=layer.mlp.dense_4h_to_h.input_size,
            output_size=layer.mlp.dense_4h_to_h.output_size,
            bias=True,
            input_is_parallel=True,
            params_dtype=torch.half,
            name="dense_4h_to_h",
            skip_init=True,
            device=layer.mlp.dense_4h_to_h.weight.device,
        )

    return model
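
For reference, a minimal usage sketch. Only quantize(model, weight_bit_width) comes from this file; the package import path, the loader, and the choice of 4-bit weights are assumptions for illustration.

# Usage sketch (assumptions: the package is importable as `quantization`,
# `load_fp16_model()` stands in for however the fp16 model is actually built,
# and 4 is one plausible weight_bit_width).
from quantization import quantize

model = load_fp16_model()                    # hypothetical fp16 transformer exposing model.transformer.layers
model = quantize(model, weight_bit_width=4)  # swap its linear layers for quantized ones in place
model.eval()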