Pārlūkot izejas kodu

Add load from quantized checkpoint

Sengxian 3 gadi atpakaļ
vecāks
revīzija
21cadf7677
1 mainītis faili ar 9 papildinājumiem un 1 dzēšanām
  1. 9 1
      initialize.py

+ 9 - 1
initialize.py

@@ -23,6 +23,9 @@ def add_quantization_args(parser):
     group = parser.add_argument_group("Quantization")
 
     group.add_argument("--quantization-bit-width", type=int, default=None)
+    group.add_argument(
+        "--load-from-quantized-checkpoint", action="store_true", help="Loading from a quantized checkpoint"
+    )
 
 
 def initialize(extra_args_provider):
@@ -45,6 +48,11 @@ def initialize_model_and_tokenizer(args):
     # Initialize model
     model = GLM130B(args).half()
 
+    if args.load_from_quantized_checkpoint:
+        assert not args.bminf and args.quantization_bit_width is not None
+        # Quantize model before moving to GPU
+        model = quantize(model, args.quantization_bit_width)
+
     # Load checkpoint
     torch.distributed.barrier()
     start = time.time()
@@ -59,7 +67,7 @@ def initialize_model_and_tokenizer(args):
         with torch.cuda.device(args.device):
             model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
     else:
-        if args.quantization_bit_width is not None:
+        if args.quantization_bit_width is not None and not args.load_from_quantized_checkpoint:
             # Quantize model before moving to GPU
             model = quantize(model, args.quantization_bit_width)
         model = model.to(args.device)