
Add sequential initialization

Sengxian 2 years ago
parent
commit
373fb174df
1 changed file with 42 additions and 24 deletions

+ 42 - 24
initialize.py

@@ -8,6 +8,7 @@ from SwissArmyTransformer import get_args, get_tokenizer
 from SwissArmyTransformer.arguments import initialize_distributed
 from SwissArmyTransformer.training import load_checkpoint
 from SwissArmyTransformer.model import GLM130B
+from SwissArmyTransformer.mpu import get_model_parallel_world_size, get_model_parallel_rank, get_model_parallel_group
 
 
 def add_bminf_args(parser):
@@ -26,10 +27,21 @@ def add_quantization_args(parser):
     group.add_argument("--from-quantized-checkpoint", action="store_true", help="Loading from a quantized checkpoint")
 
 
+def add_initialization_args(parser):
+    group = parser.add_argument_group("Initialization")
+
+    group.add_argument(
+        "--sequential-initialization",
+        action="store_true",
+        help="Initialize sequentially in tensor parallel group (reduce CPU RAM for initialization)",
+    )
+
+
 def initialize(extra_args_provider):
     parser = argparse.ArgumentParser(add_help=False)
     add_bminf_args(parser)
     add_quantization_args(parser)
+    add_initialization_args(parser)
     GLM130B.add_model_specific_args(parser)
     extra_args_provider(parser)
     known, args_list = parser.parse_known_args()
@@ -43,35 +55,41 @@ def initialize(extra_args_provider):
 def initialize_model_and_tokenizer(args):
     tokenizer = get_tokenizer(args)
 
-    # Initialize model
-    model = GLM130B(args).half()
-
-    if args.from_quantized_checkpoint:
-        assert args.quantization_bit_width is not None
-        # Quantize model before moving to GPU
-        model = quantize(model, args.quantization_bit_width)
-
-    # Load checkpoint
     torch.distributed.barrier()
     start = time.time()
-    load_checkpoint(model, args)
-    torch.distributed.barrier()
-    if torch.distributed.get_rank() == 0:
-        print(f"> Checkpoint loaded in {time.time() - start:.1f}s")
 
-    if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
-        # Quantize model before moving to GPU
-        model = quantize(model, args.quantization_bit_width)
+    for i in range(get_model_parallel_world_size()):
+        if get_model_parallel_rank() == i:
+            # Initialize model
+            model = GLM130B(args).half()
 
-    if args.bminf:
-        import bminf
+            if args.from_quantized_checkpoint:
+                assert args.quantization_bit_width is not None
+                # Quantize model before moving to GPU
+                model = quantize(model, args.quantization_bit_width)
 
-        if torch.distributed.get_rank() == 0:
-            print(f"> BMInf activated, memory limit: {args.bminf_memory_limit} GB")
-        with torch.cuda.device(args.device):
-            model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
-    else:
-        model = model.to(args.device)
+            # Load checkpoint
+            load_checkpoint(model, args)
+
+            if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
+                # Quantize model before moving to GPU
+                model = quantize(model, args.quantization_bit_width)
+
+            if args.bminf:
+                import bminf
+
+                if torch.distributed.get_rank() == 0:
+                    print(f"> BMInf activated, memory limit: {args.bminf_memory_limit} GB")
+                with torch.cuda.device(args.device):
+                    model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
+            else:
+                model = model.to(args.device)
+        if args.sequential_initialization:
+            torch.distributed.barrier(group=get_model_parallel_group())
+
+    torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print(f"> Model initialized in {time.time() - start:.1f}s")
 
     torch.cuda.empty_cache()
     model.eval()
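
For reference, the staggered-initialization pattern this commit introduces can be sketched in isolation. This is a minimal sketch under stated assumptions, not GLM-130B code: staggered_init and its build_and_load callback are hypothetical stand-ins for the loop body in the diff (construct GLM130B, optionally quantize, load the checkpoint, move to GPU), and the group handle corresponds to what SwissArmyTransformer.mpu.get_model_parallel_group() returns in the real code.

# Minimal sketch of staggered initialization, assuming only torch.distributed
# is initialized. `build_and_load` is a hypothetical stand-in for the
# CPU-heavy work each rank does in the diff above.
import torch.distributed as dist


def staggered_init(group, build_and_load):
    # Ranks in `group` take turns running build_and_load(), so at most one
    # rank per group materializes its CPU-side weights at a time instead of
    # all ranks doing so simultaneously.
    rank = dist.get_rank(group=group)
    model = None
    for i in range(dist.get_world_size(group=group)):
        if rank == i:
            model = build_and_load()
        # The barrier makes rank i+1 wait until rank i has finished; without
        # it, all ranks build concurrently, which is the default behavior
        # when --sequential-initialization is not set.
        dist.barrier(group=group)
    return model

The trade-off is that initialization time grows roughly linearly with the model-parallel world size, in exchange for lower peak host RAM, since ranks that share a node no longer hold their freshly built CPU weights at the same time.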