@@ -23,6 +23,9 @@ def add_quantization_args(parser):
     group = parser.add_argument_group("Quantization")
 
     group.add_argument("--quantization-bit-width", type=int, default=None)
+    group.add_argument(
+        "--load-from-quantized-checkpoint", action="store_true", help="Loading from a quantized checkpoint"
+    )
 
 
 def initialize(extra_args_provider):
@@ -45,6 +48,11 @@ def initialize_model_and_tokenizer(args):
     # Initialize model
     model = GLM130B(args).half()
 
+    if args.load_from_quantized_checkpoint:
+        assert not args.bminf and args.quantization_bit_width is not None
+        # Quantize model before moving to GPU
+        model = quantize(model, args.quantization_bit_width)
+
     # Load checkpoint
     torch.distributed.barrier()
     start = time.time()
@@ -59,7 +67,7 @@ def initialize_model_and_tokenizer(args):
         with torch.cuda.device(args.device):
             model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
     else:
-        if args.quantization_bit_width is not None:
+        if args.quantization_bit_width is not None and not args.load_from_quantized_checkpoint:
             # Quantize model before moving to GPU
             model = quantize(model, args.quantization_bit_width)
         model = model.to(args.device)
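
For context, a minimal runnable sketch of how the two flags above fit together; this is not the repository's own code, and the example command-line values are hypothetical. In the diff, --load-from-quantized-checkpoint is meant to be used alongside --quantization-bit-width: with the flag set, the model is quantized before the checkpoint is loaded, whereas without it quantization happens after loading, just before the move to the GPU.

import argparse

# Reproduce the "Quantization" argument group from add_quantization_args above.
parser = argparse.ArgumentParser()
group = parser.add_argument_group("Quantization")
group.add_argument("--quantization-bit-width", type=int, default=None)
group.add_argument(
    "--load-from-quantized-checkpoint", action="store_true", help="Loading from a quantized checkpoint"
)

# Hypothetical invocation: load a checkpoint that was already quantized to 4 bits.
args = parser.parse_args(["--quantization-bit-width", "4", "--load-from-quantized-checkpoint"])

# Mirrors the assertion added in the diff: loading a quantized checkpoint
# requires an explicit bit width.
assert args.load_from_quantized_checkpoint and args.quantization_bit_width == 4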