@@ -8,6 +8,7 @@ from SwissArmyTransformer import get_args, get_tokenizer
 from SwissArmyTransformer.arguments import initialize_distributed
 from SwissArmyTransformer.training import load_checkpoint
 from SwissArmyTransformer.model import GLM130B
+from SwissArmyTransformer.mpu import get_model_parallel_world_size, get_model_parallel_rank, get_model_parallel_group
 
 
 def add_bminf_args(parser):
@@ -26,10 +27,21 @@ def add_quantization_args(parser):
     group.add_argument("--from-quantized-checkpoint", action="store_true", help="Loading from a quantized checkpoint")
 
 
+def add_initialization_args(parser):
+    group = parser.add_argument_group("Initialization")
+
+    group.add_argument(
+        "--sequential-initialization",
+        action="store_true",
+        help="Initialize sequentially in tensor parallel group (reduce CPU RAM for initialization)",
+    )
+
+
 def initialize(extra_args_provider):
     parser = argparse.ArgumentParser(add_help=False)
     add_bminf_args(parser)
     add_quantization_args(parser)
+    add_initialization_args(parser)
     GLM130B.add_model_specific_args(parser)
     extra_args_provider(parser)
     known, args_list = parser.parse_known_args()
@@ -43,35 +55,41 @@ def initialize(extra_args_provider):
 def initialize_model_and_tokenizer(args):
     tokenizer = get_tokenizer(args)
 
-    # Initialize model
-    model = GLM130B(args).half()
-
-    if args.from_quantized_checkpoint:
-        assert args.quantization_bit_width is not None
-        # Quantize model before moving to GPU
-        model = quantize(model, args.quantization_bit_width)
-
-    # Load checkpoint
     torch.distributed.barrier()
     start = time.time()
-    load_checkpoint(model, args)
-    torch.distributed.barrier()
-    if torch.distributed.get_rank() == 0:
-        print(f"> Checkpoint loaded in {time.time() - start:.1f}s")
 
-    if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
-        # Quantize model before moving to GPU
-        model = quantize(model, args.quantization_bit_width)
+    for i in range(get_model_parallel_world_size()):
+        if get_model_parallel_rank() == i:
+            # Initialize model
+            model = GLM130B(args).half()
 
-    if args.bminf:
-        import bminf
+            if args.from_quantized_checkpoint:
+                assert args.quantization_bit_width is not None
+                # Quantize model before moving to GPU
+                model = quantize(model, args.quantization_bit_width)
 
-        if torch.distributed.get_rank() == 0:
-            print(f"> BMInf activated, memory limit: {args.bminf_memory_limit} GB")
-        with torch.cuda.device(args.device):
-            model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
-    else:
-        model = model.to(args.device)
+            # Load checkpoint
+            load_checkpoint(model, args)
+
+            if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
+                # Quantize model before moving to GPU
+                model = quantize(model, args.quantization_bit_width)
+
+            if args.bminf:
+                import bminf
+
+                if torch.distributed.get_rank() == 0:
+                    print(f"> BMInf activated, memory limit: {args.bminf_memory_limit} GB")
+                with torch.cuda.device(args.device):
+                    model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
+            else:
+                model = model.to(args.device)
+        if args.sequential_initialization:
+            torch.distributed.barrier(group=get_model_parallel_group())
+
+    torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print(f"> Model initialized in {time.time() - start:.1f}s")
 
     torch.cuda.empty_cache()
     model.eval()
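
Note on the change above: the loop staggers model construction so that, when --sequential-initialization is set, only one rank at a time holds a CPU-side copy of its shard, capping peak host RAM during startup. Below is a minimal, self-contained sketch of that pattern using plain torch.distributed rather than SwissArmyTransformer's mpu helpers; for simplicity it staggers over the whole world group instead of each tensor-parallel group, and build_model() is a hypothetical stand-in for GLM130B(args).half(), not part of the patch.

# Minimal sketch of the staggered-initialization pattern (assumptions noted above).
import os
import torch
import torch.distributed as dist

def build_model():
    # Hypothetical stand-in for building one model shard on the CPU.
    return torch.nn.Linear(4096, 4096).half()

def staggered_init(sequential=True):
    rank = dist.get_rank()
    model = None
    for i in range(dist.get_world_size()):
        if rank == i:
            # Only rank i builds its shard on this iteration, so at most one
            # CPU-side copy per node exists at a time when `sequential` is set.
            model = build_model().cuda()
        if sequential:
            # Without this barrier all ranks build at once: faster, but peak
            # host RAM grows with the number of ranks per node.
            dist.barrier()
    return model

if __name__ == "__main__":
    # Launch with e.g.: torchrun --nproc_per_node=8 sketch.py
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    model = staggered_init(sequential=True)
    dist.barrier()
    if dist.get_rank() == 0:
        print("> all ranks initialized")

In the actual patch the barrier uses group=get_model_parallel_group(), so ranks belonging to different tensor-parallel groups still initialize concurrently; the trade-off is purely init latency versus CPU RAM within each group. Also note that args.bminf_memory_limit << 30 simply converts the GB limit to bytes for bminf.wrapper.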