initialize.py

import argparse
import torch
import time

from quantization import quantize

from SwissArmyTransformer import get_args, get_tokenizer
from SwissArmyTransformer.arguments import initialize_distributed
from SwissArmyTransformer.training import load_checkpoint
from SwissArmyTransformer.model import GLM130B
from SwissArmyTransformer.mpu import get_model_parallel_world_size, get_model_parallel_rank, get_model_parallel_group


def add_bminf_args(parser):
    """Arguments for BMInf"""
    group = parser.add_argument_group("BMInf")
    group.add_argument("--bminf", action="store_true", help="Use BMInf to support low resource evaluation")
    group.add_argument("--bminf-memory-limit", type=int, default=20, help="Max memory for model per GPU (in GB)")
    return parser


def add_quantization_args(parser):
    """Arguments for quantization"""
    group = parser.add_argument_group("Quantization")
    group.add_argument("--quantization-bit-width", type=int, default=None)
    group.add_argument("--from-quantized-checkpoint", action="store_true", help="Load from a quantized checkpoint")


def add_initialization_args(parser):
    """Arguments for initialization"""
    group = parser.add_argument_group("Initialization")
    group.add_argument(
        "--sequential-initialization",
        action="store_true",
        help="Initialize sequentially in tensor parallel group (reduce CPU RAM for initialization)",
    )


def initialize(extra_args_provider):
    parser = argparse.ArgumentParser(add_help=False)
    add_bminf_args(parser)
    add_quantization_args(parser)
    add_initialization_args(parser)
    GLM130B.add_model_specific_args(parser)
    extra_args_provider(parser)
    known, args_list = parser.parse_known_args()
    args = get_args(args_list)
    # merge SwissArmyTransformer's parsed arguments with the locally defined flags
    args = argparse.Namespace(**vars(args), **vars(known))
    args.do_train = False
    initialize_distributed(args)
    return args


def initialize_model_and_tokenizer(args):
    tokenizer = get_tokenizer(args)

    torch.distributed.barrier()
    start = time.time()
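
    # Each tensor-parallel rank builds and loads only its own shard; with
    # --sequential-initialization the barrier at the end of the loop makes the
    # ranks take turns, keeping peak CPU RAM near a single shard.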
    for i in range(get_model_parallel_world_size()):
        if get_model_parallel_rank() == i:
            # Initialize model
            model = GLM130B(args).half()

            if args.from_quantized_checkpoint:
                assert args.quantization_bit_width is not None
                # Quantize model before moving to GPU
                model = quantize(model, args.quantization_bit_width)

            # Load checkpoint
            load_checkpoint(model, args)

            if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
                # Quantize model before moving to GPU
                model = quantize(model, args.quantization_bit_width)

            if args.bminf:
                import bminf

                if torch.distributed.get_rank() == 0:
                    print(f"> BMInf activated, memory limit: {args.bminf_memory_limit} GB")
                with torch.cuda.device(args.device):
                    model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
            else:
                model = model.to(args.device)

        if args.sequential_initialization:
            torch.distributed.barrier(group=get_model_parallel_group())

    torch.distributed.barrier()

    if torch.distributed.get_rank() == 0:
        print(f"> Model initialized in {time.time() - start:.1f}s")

    torch.cuda.empty_cache()
    model.eval()

    # Generate the rotary embedding cache with one dummy forward pass over the
    # full sequence length; the random boolean tensor is only a throwaway
    # attention mask for this warm-up call.
    original_parallel_output = model.transformer.parallel_output
    model.transformer.parallel_output = True
    with torch.no_grad():
        _, *_ = model(
            torch.ones(1, args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64),
            torch.arange(args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64).view(1, -1),
            torch.randn(
                1,
                1,
                args.max_sequence_length,
                args.max_sequence_length,
                device=torch.cuda.current_device(),
            )
            < 0.5,
        )
    model.transformer.parallel_output = original_parallel_output

    torch.distributed.barrier()

    return model, tokenizer
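
For reference, a minimal sketch of how these two helpers are typically wired together from an entry-point script launched under torch.distributed (e.g. with torchrun). The caller script, its add_text_generate_args provider, and the --out-seq-length flag below are illustrative assumptions, not part of this file:

# caller_sketch.py -- hypothetical entry point, not part of the repository
from initialize import initialize, initialize_model_and_tokenizer


def add_text_generate_args(parser):
    # illustrative extra arguments contributed by the calling script
    group = parser.add_argument_group("Text generation")
    group.add_argument("--out-seq-length", type=int, default=256)
    return parser


if __name__ == "__main__":
    # parse arguments and set up the distributed environment
    args = initialize(extra_args_provider=add_text_generate_args)
    # build the sharded model, load its checkpoint, and move it to args.device
    model, tokenizer = initialize_model_and_tokenizer(args)
    # model is now in eval mode and ready for inference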