__init__.py

import torch

from .layers import QuantizedColumnParallelLinear
from .layers import QuantizedRowParallelLinear


def quantize(model, weight_bit_width):
    """Replace fp16 linear layers with quantized linear layers."""
    if torch.distributed.get_rank() == 0:
        print(f"> Quantizing model weight to {weight_bit_width} bits")

    # Swap each attention and MLP projection for its quantized counterpart,
    # reusing the existing fp16 weights as the source of the quantized values.
    for layer in model.transformer.layers:
        layer.attention.query_key_value = QuantizedColumnParallelLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.attention.query_key_value.weight.to(torch.cuda.current_device()),
            input_size=layer.attention.query_key_value.input_size,
            output_size=layer.attention.query_key_value.output_size,
            bias=True,
            gather_output=False,
            params_dtype=torch.half,
            name="query_key_value",
            skip_init=True,
            device=layer.attention.query_key_value.weight.device,
        )
        layer.attention.dense = QuantizedRowParallelLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
            input_size=layer.attention.dense.input_size,
            output_size=layer.attention.dense.output_size,
            bias=True,
            input_is_parallel=True,
            params_dtype=torch.half,
            name="dense",
            skip_init=True,
            device=layer.attention.dense.weight.device,
        )
        layer.mlp.dense_h_to_4h = QuantizedColumnParallelLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
            input_size=layer.mlp.dense_h_to_4h.input_size,
            output_size=layer.mlp.dense_h_to_4h.output_size,
            bias=True,
            gather_output=False,
            params_dtype=torch.half,
            name="dense_h_to_4h",
            skip_init=True,
            device=layer.mlp.dense_h_to_4h.weight.device,
        )
        layer.mlp.dense_4h_to_h = QuantizedRowParallelLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
            input_size=layer.mlp.dense_4h_to_h.input_size,
            output_size=layer.mlp.dense_4h_to_h.output_size,
            bias=True,
            input_is_parallel=True,
            params_dtype=torch.half,
            name="dense_4h_to_h",
            skip_init=True,
            device=layer.mlp.dense_4h_to_h.weight.device,
        )

    return model
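
A minimal usage sketch follows. The import path `quantization`, the `load_fp16_model` loader, and the single-rank process-group setup are illustrative assumptions, not part of this module; `quantize` only requires that `torch.distributed` is initialized and that the model exposes GLM-style `model.transformer.layers` with the attention and MLP projections referenced above.

import torch

# Illustrative only: assumes this package is importable as `quantization`.
from quantization import quantize

# `quantize` calls torch.distributed.get_rank(), so a process group must exist.
# A single-rank group is shown here; multi-GPU runs would typically use torchrun.
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group(
        backend="nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )

model = load_fp16_model()  # hypothetical loader returning the fp16 GLM-style model
model = quantize(model, weight_bit_width=8)  # replace fp16 linears with 8-bit versions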