initialize.py
import argparse
import torch
import time

from quantization import quantize

from SwissArmyTransformer import get_args, get_tokenizer
from SwissArmyTransformer.arguments import initialize_distributed
from SwissArmyTransformer.training import load_checkpoint
from SwissArmyTransformer.model import GLM130B


def add_bminf_args(parser):
    """Arguments for BMInf"""
    group = parser.add_argument_group("BMInf")
    group.add_argument("--bminf", action="store_true", help="Use BMInf to support low resource evaluation")
    group.add_argument("--bminf-memory-limit", type=int, default=20, help="Max memory for model per GPU (in GB)")
    return parser


def add_quantization_args(parser):
    """Arguments for weight quantization"""
    group = parser.add_argument_group("Quantization")
    group.add_argument("--quantization-bit-width", type=int, default=None, help="Bit width for weight quantization")
    group.add_argument("--from-quantized-checkpoint", action="store_true", help="Loading from a quantized checkpoint")


def initialize(extra_args_provider):
    parser = argparse.ArgumentParser(add_help=False)
    add_bminf_args(parser)
    add_quantization_args(parser)
    GLM130B.add_model_specific_args(parser)
    extra_args_provider(parser)
    known, args_list = parser.parse_known_args()
    args = get_args(args_list)
    # Merge the locally parsed arguments into the SwissArmyTransformer namespace
    args = argparse.Namespace(**vars(args), **vars(known))
    args.do_train = False
    initialize_distributed(args)
    return args


def initialize_model_and_tokenizer(args):
    tokenizer = get_tokenizer(args)

    # Initialize model
    model = GLM130B(args).half()

    if args.from_quantized_checkpoint:
        assert args.quantization_bit_width is not None
        # Quantize model before moving to GPU
        model = quantize(model, args.quantization_bit_width)

    # Load checkpoint
    torch.distributed.barrier()
    start = time.time()
    load_checkpoint(model, args)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(f"> Checkpoint loaded in {time.time() - start:.1f}s")

    if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
        # Quantize model before moving to GPU
        model = quantize(model, args.quantization_bit_width)

    if args.bminf:
        import bminf

        if torch.distributed.get_rank() == 0:
            print(f"> BMInf activated, memory limit: {args.bminf_memory_limit} GB")
        with torch.cuda.device(args.device):
            model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
    else:
        model = model.to(args.device)

    torch.cuda.empty_cache()
    model.eval()

    # generate rotary embedding cache
    with torch.no_grad():
        _, *_ = model(
            torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64),
            torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64) * args.max_sequence_length,
            torch.ones(1, 1, 1, 1, device=torch.cuda.current_device(), dtype=torch.bool),
        )

    torch.distributed.barrier()
    return model, tokenizer
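

# A minimal usage sketch (not part of the original file): an evaluation or
# generation entry point would pass its own argument hook into `initialize`,
# then build the model and tokenizer. `add_example_args` and `--sample-text`
# are hypothetical names used only for illustration; the model-configuration
# flags consumed by SwissArmyTransformer's `get_args` are expected to come
# from the launch script.
def add_example_args(parser):
    group = parser.add_argument_group("Example")
    group.add_argument("--sample-text", type=str, default="Who is the greatest artist?")
    return parser


if __name__ == "__main__":
    args = initialize(extra_args_provider=add_example_args)
    model, tokenizer = initialize_model_and_tokenizer(args)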