@@ -7,7 +7,8 @@ from quantization import quantize
 from SwissArmyTransformer import get_args, get_tokenizer
 from SwissArmyTransformer.arguments import initialize_distributed
 from SwissArmyTransformer.training import load_checkpoint
-from SwissArmyTransformer.model import GLM130B
+from SwissArmyTransformer.model import GLM130B, GLMModel
+from SwissArmyTransformer.model.mixins import CachedAutoregressiveMixin


 def add_bminf_args(parser):
@@ -31,6 +32,7 @@ def initialize(extra_args_provider):
     add_bminf_args(parser)
     add_quantization_args(parser)
     GLM130B.add_model_specific_args(parser)
+    GLMModel.add_model_specific_args(parser)
     extra_args_provider(parser)
     known, args_list = parser.parse_known_args()
     args = get_args(args_list)
@@ -40,11 +42,32 @@ def initialize(extra_args_provider):
     return args


-def initialize_model_and_tokenizer(args):
-    tokenizer = get_tokenizer(args)
+class SmallTokenizer:
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+
+    def tokenize(self, text):
+        return self.tokenizer.EncodeAsIds(text).tokenization
+
+    def detokenize(self, ids):
+        return self.tokenizer.DecodeIds(ids)
+
+    def get_command(self, name):
+        map = {"[MASK]": "MASK", "[gMASK]": "gMASK", "[sMASK]": "sMASK"}
+        if name in map:
+            name = map[name]
+        return self.tokenizer.get_command(name).Id
+
+def initialize_model_and_tokenizer(args):
+    if args.tokenizer_type.startswith("glm_"):
+        tokenizer = SmallTokenizer(get_tokenizer(args))
+        tokenizer = get_tokenizer(args, outer_tokenizer=tokenizer)
+    else:
+        tokenizer = get_tokenizer(args)

     # Initialize model
-    model = GLM130B(args).half()
+    model = GLMModel(args).half()
+    model.add_mixin('cached-autoregressive', CachedAutoregressiveMixin())

     if args.from_quantized_checkpoint:
         assert args.quantization_bit_width is not None
@@ -77,12 +100,12 @@ def initialize_model_and_tokenizer(args):
     model.eval()

     # generate rotary embedding cache
-    with torch.no_grad():
-        _, *_ = model(
-            torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64),
-            torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64) * args.max_sequence_length,
-            torch.ones(1, 1, 1, 1, device=torch.cuda.current_device(), dtype=torch.bool),
-        )
+    # with torch.no_grad():
+    #     _, *_ = model(
+    #         torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64),
+    #         torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64) * args.max_sequence_length,
+    #         torch.ones(1, 1, 1, 1, device=torch.cuda.current_device(), dtype=torch.bool),
+    #     )

     torch.distributed.barrier()
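The SmallTokenizer added above is an adapter: it flattens the inner GLM tokenizer's interface (EncodeAsIds(...).tokenization, DecodeIds(...), get_command(...).Id) into the tokenize/detokenize/get_command calls the surrounding code expects. The sketch below exercises that contract with a made-up stand-in tokenizer; _FakeInnerTokenizer and its token ids are illustrative assumptions, not part of the patch, and SmallTokenizer is copied from the hunk above only so the snippet runs on its own.

# SmallTokenizer copied verbatim from the patch so this sketch is self-contained.
class SmallTokenizer:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def tokenize(self, text):
        return self.tokenizer.EncodeAsIds(text).tokenization

    def detokenize(self, ids):
        return self.tokenizer.DecodeIds(ids)

    def get_command(self, name):
        map = {"[MASK]": "MASK", "[gMASK]": "gMASK", "[sMASK]": "sMASK"}
        if name in map:
            name = map[name]
        return self.tokenizer.get_command(name).Id


# Everything below is a hypothetical stand-in for the tokenizer returned by
# SwissArmyTransformer's get_tokenizer(args); names and ids are arbitrary.
class _Encoded:
    def __init__(self, ids):
        self.tokenization = ids  # EncodeAsIds(...) yields an object with .tokenization


class _Command:
    def __init__(self, token_id):
        self.Id = token_id  # get_command(...) yields an object with .Id


class _FakeInnerTokenizer:
    _special = {"MASK": 150000, "gMASK": 150001, "sMASK": 150002}  # made-up ids

    def EncodeAsIds(self, text):
        return _Encoded([ord(c) for c in text])  # toy per-character "encoding"

    def DecodeIds(self, ids):
        return "".join(chr(i) for i in ids)

    def get_command(self, name):
        return _Command(self._special[name])


if __name__ == "__main__":
    tok = SmallTokenizer(_FakeInnerTokenizer())
    ids = tok.tokenize("hello")
    assert tok.detokenize(ids) == "hello"
    assert tok.get_command("[gMASK]") == 150001  # bracketed aliases map to command names
    print("adapter contract OK:", ids)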