import torch

from .layers import QuantizedColumnParallelLinear
from .layers import QuantizedRowParallelLinear


def quantize(model, bit_width):
    """Replace the fp16 linear layers in each transformer block with quantized equivalents."""
    for layer in model.transformer.layers:
        # Attention QKV projection: column-parallel, so the output stays sharded.
        layer.attention.query_key_value = QuantizedColumnParallelLinear(
            bit_width=bit_width,
            weight=layer.attention.query_key_value.weight.to(torch.cuda.current_device()),
            input_size=layer.attention.query_key_value.input_size,
            output_size=layer.attention.query_key_value.output_size,
            bias=True,
            gather_output=False,
            params_dtype=torch.half,
            name="query_key_value",
            skip_init=True,
            device=layer.attention.query_key_value.weight.device,
        )
        # Attention output projection: row-parallel, input arrives already sharded.
        layer.attention.dense = QuantizedRowParallelLinear(
            bit_width=bit_width,
            weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
            input_size=layer.attention.dense.input_size,
            output_size=layer.attention.dense.output_size,
            bias=True,
            input_is_parallel=True,
            params_dtype=torch.half,
            name="dense",
            skip_init=True,
            device=layer.attention.dense.weight.device,
        )
        # MLP up-projection (h -> 4h): column-parallel.
        layer.mlp.dense_h_to_4h = QuantizedColumnParallelLinear(
            bit_width=bit_width,
            weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
            input_size=layer.mlp.dense_h_to_4h.input_size,
            output_size=layer.mlp.dense_h_to_4h.output_size,
            bias=True,
            gather_output=False,
            params_dtype=torch.half,
            name="dense_h_to_4h",
            skip_init=True,
            device=layer.mlp.dense_h_to_4h.weight.device,
        )
        # MLP down-projection (4h -> h): row-parallel.
        layer.mlp.dense_4h_to_h = QuantizedRowParallelLinear(
            bit_width=bit_width,
            weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
            input_size=layer.mlp.dense_4h_to_h.input_size,
            output_size=layer.mlp.dense_4h_to_h.output_size,
            bias=True,
            input_is_parallel=True,
            params_dtype=torch.half,
            name="dense_4h_to_h",
            skip_init=True,
            device=layer.mlp.dense_4h_to_h.weight.device,
        )
    return model
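

# Example usage (a sketch, not part of this module): `load_model` is a
# hypothetical loader that returns an fp16 model exposing .transformer.layers.
#
#     model = load_model()        # fp16 checkpoint on GPU
#     model = quantize(model, 8)  # swap in 8-bit weight-quantized linear layers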
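

# A minimal, self-contained sketch of the per-output-channel symmetric absmax
# scheme that weight-only quantized layers of this kind commonly use. This is
# an illustrative assumption, not necessarily what the classes in .layers
# implement.
def absmax_quantize(weight, bit_width):
    """Quantize a 2-D fp16 weight to signed integers with one scale per row."""
    qmax = 2 ** (bit_width - 1) - 1  # e.g. 127 for 8-bit
    scale = weight.abs().max(dim=-1).values.clamp(min=1e-5) / qmax
    quantized = torch.round(weight / scale[:, None]).to(torch.int8)
    return quantized, scale


def absmax_dequantize(quantized, scale):
    """Recover an approximation of the original weight in the scale's dtype."""
    return quantized.to(scale.dtype) * scale[:, None]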