initialize.py

import argparse
import torch
import time

from quantization import quantize
from SwissArmyTransformer import get_args, get_tokenizer
from SwissArmyTransformer.arguments import initialize_distributed
from SwissArmyTransformer.training import load_checkpoint
from SwissArmyTransformer.model import GLM130B, GLMModel
from SwissArmyTransformer.model.mixins import CachedAutoregressiveMixin


def add_bminf_args(parser):
    """Arguments for BMInf"""
    group = parser.add_argument_group("BMInf")
    group.add_argument("--bminf", action="store_true", help="Use BMInf to support low resource evaluation")
    group.add_argument("--bminf-memory-limit", type=int, default=20, help="Max memory for model per GPU (in GB)")
    return parser


def add_quantization_args(parser):
    group = parser.add_argument_group("Quantization")
    group.add_argument("--quantization-bit-width", type=int, default=None)
    group.add_argument("--from-quantized-checkpoint", action="store_true", help="Loading from a quantized checkpoint")


def initialize(extra_args_provider):
    parser = argparse.ArgumentParser(add_help=False)
    add_bminf_args(parser)
    add_quantization_args(parser)
    GLM130B.add_model_specific_args(parser)
    GLMModel.add_model_specific_args(parser)
    extra_args_provider(parser)
    known, args_list = parser.parse_known_args()
    args = get_args(args_list)
    # Merge the locally parsed flags into the SwissArmyTransformer argument namespace
    args = argparse.Namespace(**vars(args), **vars(known))
    args.do_train = False
    initialize_distributed(args)
    return args


class SmallTokenizer:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def tokenize(self, text):
        return self.tokenizer.EncodeAsIds(text).tokenization

    def detokenize(self, ids):
        return self.tokenizer.DecodeIds(ids)

    def get_command(self, name):
        # Map bracketed mask tokens to the underlying tokenizer's command names
        map = {"[MASK]": "MASK", "[gMASK]": "gMASK", "[sMASK]": "sMASK"}
        if name in map:
            name = map[name]
        return self.tokenizer.get_command(name).Id


def initialize_model_and_tokenizer(args):
    if args.tokenizer_type.startswith("glm_"):
        tokenizer = SmallTokenizer(get_tokenizer(args))
        tokenizer = get_tokenizer(args, outer_tokenizer=tokenizer)
    else:
        tokenizer = get_tokenizer(args)

    # Initialize model
    model = GLMModel(args).half()
    model.add_mixin('cached-autoregressive', CachedAutoregressiveMixin())

    if args.from_quantized_checkpoint:
        assert args.quantization_bit_width is not None
        # Quantize model before moving to GPU
        model = quantize(model, args.quantization_bit_width)

    # Load checkpoint
    torch.distributed.barrier()
    start = time.time()
    if args.load:
        load_checkpoint(model, args)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(f"> Checkpoint loaded in {time.time() - start:.1f}s")

    if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
        # Quantize model before moving to GPU
        model = quantize(model, args.quantization_bit_width)

    if args.bminf:
        import bminf

        if torch.distributed.get_rank() == 0:
            print(f"> BMInf activated, memory limit: {args.bminf_memory_limit} GB")
        with torch.cuda.device(args.device):
            model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
    else:
        model = model.to(args.device)

    torch.cuda.empty_cache()
    model.eval()

    # generate rotary embedding cache
    # with torch.no_grad():
    #     _, *_ = model(
    #         torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64),
    #         torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64) * args.max_sequence_length,
    #         torch.ones(1, 1, 1, 1, device=torch.cuda.current_device(), dtype=torch.bool),
    #     )

    torch.distributed.barrier()
    return model, tokenizer
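

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original script): a
# caller passes an extra_args_provider that registers its own flags, then
# builds the model and tokenizer from the merged arguments. The
# `_example_args_provider` helper and its `--print-model` flag are purely
# hypothetical and only show how the hooks above fit together.
def _example_args_provider(parser):
    group = parser.add_argument_group("example")
    group.add_argument("--print-model", action="store_true", help="Print the model structure after loading")


if __name__ == "__main__":
    example_args = initialize(extra_args_provider=_example_args_provider)
    example_model, example_tokenizer = initialize_model_and_tokenizer(example_args)
    if example_args.print_model and torch.distributed.get_rank() == 0:
        print(example_model)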