
initialize.py

import argparse
import torch
import time

from SwissArmyTransformer import get_args, get_tokenizer
from SwissArmyTransformer.arguments import initialize_distributed
from SwissArmyTransformer.training import load_checkpoint
from SwissArmyTransformer.model import GLM130B


def add_bminf_args(parser):
    """Arguments for BMInf"""
    group = parser.add_argument_group("BMInf")

    group.add_argument("--bminf", action="store_true", help="Use BMInf to support low resource evaluation")
    group.add_argument("--bminf-memory-limit", type=int, default=20, help="Max memory for model per GPU (in GB)")
    return parser


def initialize(extra_args_provider):
    # Collect BMInf, GLM-130B and caller-specific arguments, then let
    # SwissArmyTransformer parse the remaining command-line options.
    parser = argparse.ArgumentParser(add_help=False)
    add_bminf_args(parser)
    GLM130B.add_model_specific_args(parser)
    extra_args_provider(parser)
    known, args_list = parser.parse_known_args()
    args = get_args(args_list)
    # Merge the locally parsed arguments into the SwissArmyTransformer namespace.
    args = argparse.Namespace(**vars(args), **vars(known))
    args.do_train = False
    initialize_distributed(args)
    return args


def initialize_model_and_tokenizer(args):
    tokenizer = get_tokenizer(args)

    # Build the model in FP16; optionally wrap it with BMInf to support
    # low-resource evaluation.
    model = GLM130B(args).half()
    if args.bminf:
        import bminf

        with torch.cuda.device(args.device):
            model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
    else:
        model = model.to(args.device)

    # Load the checkpoint on all ranks and report the elapsed time on rank 0.
    torch.distributed.barrier()
    start = time.time()
    load_checkpoint(model, args)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(f"> Checkpoint loaded in {time.time() - start:.1f}s")

    model.eval()

    # generate rotary embedding cache with a dummy forward pass
    with torch.no_grad():
        _, *_ = model(
            torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64),
            torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64) * args.max_sequence_length,
            torch.ones(1, 1, 1, 1, device=torch.cuda.current_device(), dtype=torch.bool),
        )
    torch.distributed.barrier()

    return model, tokenizer