
Fix quantization argument bug

Sengxian, 3 years ago
parent commit c64d6ea33c
2 changed files with 6 additions and 5 deletions
  1. +3 -5  initialize.py
  2. +3 -0  quantization/__init__.py

+ 3 - 5
initialize.py

@@ -23,9 +23,7 @@ def add_quantization_args(parser):
     group = parser.add_argument_group("Quantization")
 
     group.add_argument("--quantization-bit-width", type=int, default=None)
-    group.add_argument(
-        "--load-from-quantized-checkpoint", action="store_true", help="Loading from a quantized checkpoint"
-    )
+    group.add_argument("--from-quantized-checkpoint", action="store_true", help="Loading from a quantized checkpoint")
 
 
 def initialize(extra_args_provider):
@@ -48,7 +46,7 @@ def initialize_model_and_tokenizer(args):
     # Initialize model
     model = GLM130B(args).half()
 
-    if args.load_from_quantized_checkpoint:
+    if args.from_quantized_checkpoint:
         assert not args.bminf and args.quantization_bit_width is not None
         # Quantize model before moving to GPU
         model = quantize(model, args.quantization_bit_width)
@@ -67,7 +65,7 @@ def initialize_model_and_tokenizer(args):
         with torch.cuda.device(args.device):
             model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
     else:
-        if args.quantization_bit_width is not None and not args.load_from_quantized_checkpoint:
+        if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
             # Quantize model before moving to GPU
             model = quantize(model, args.quantization_bit_width)
         model = model.to(args.device)
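
The change above renames --load-from-quantized-checkpoint to --from-quantized-checkpoint and updates every place that reads the flag, so the argparse attribute and the checks in initialize_model_and_tokenizer stay in sync. A minimal standalone sketch of the renamed argument, assuming plain argparse with all other GLM-130B argument groups omitted; the parse_args input is illustrative only:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group("Quantization")
    group.add_argument("--quantization-bit-width", type=int, default=None)
    group.add_argument("--from-quantized-checkpoint", action="store_true",
                       help="Loading from a quantized checkpoint")

    # argparse maps dashes to underscores: args.from_quantized_checkpoint
    args = parser.parse_args(["--quantization-bit-width", "4",
                              "--from-quantized-checkpoint"])
    assert args.from_quantized_checkpoint and args.quantization_bit_width == 4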

+ 3 - 0
quantization/__init__.py

@@ -7,6 +7,9 @@ from .layers import QuantizedRowParallelLinear
 def quantize(model, weight_bit_width):
     """Replace fp16 linear with quantized linear"""
 
+    if torch.distributed.get_rank() == 0:
+        print(f"> Quantizing model weight to {weight_bit_width} bits")
+
     for layer in model.transformer.layers:
         layer.attention.query_key_value = QuantizedColumnParallelLinear(
             weight_bit_width=weight_bit_width,
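
The added lines gate the print on torch.distributed.get_rank() == 0 so that only one process logs the quantization message during multi-GPU runs. A minimal sketch of that pattern in isolation, assuming the process group is initialized by the launcher; the is_initialized fallback is added here (not in the commit) so the snippet also runs single-process:

    import torch
    import torch.distributed

    def log_rank0(message):
        # Print from rank 0 only, to avoid emitting one line per worker;
        # fall back to a plain print when no process group exists.
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            print(message)

    log_rank0("> Quantizing model weight to 4 bits")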