convert-hf-to-ggml.py

# Convert HF models to ggml format
#

import struct
import torch
import numpy as np
import re
import os
import argparse

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
    """
    Returns a mapping of utf-8 bytes to corresponding unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    This also avoids mapping to whitespace/control characters that the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
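
# Example of the mapping (informal): printable bytes map to themselves, e.g.
# bytes_to_unicode()[ord("A")] == "A", while bytes outside the printable
# ranges are shifted past U+00FF, e.g. byte 0x00 maps to chr(0x100).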

parser = argparse.ArgumentParser(description='Convert starcoder HF model to GGML')
parser.add_argument('model_name_or_path', type=str, help='Name of model on HF hub, or local model folder')
parser.add_argument('--outfile', type=str, default='ggml-model.bin', help='Path of GGML file to write.')
parser.add_argument('--use_f32', action="store_true", help='Save GGML file in fp32')
args = parser.parse_args()
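
# Example invocation (the model name is illustrative; any GPT-BigCode-style
# checkpoint on the HF hub or on disk should work):
#   python convert-hf-to-ggml.py bigcode/gpt_bigcode-santacoder --outfile models/ggml-santacoder.bin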

# use 16-bit or 32-bit floats
use_f16 = not args.use_f32

fname_out = args.outfile
fname_dir = os.path.dirname(fname_out)
if fname_dir:
    os.makedirs(fname_dir, exist_ok=True)

print("Loading model: ", args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)
hparams = config.to_dict()
model = AutoModelForCausalLM.from_pretrained(
    args.model_name_or_path,
    config=config,
    torch_dtype=torch.float16 if use_f16 else torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    offload_state_dict=True,
)
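# (low_cpu_mem_usage and offload_state_dict above keep peak host RAM down
# while the checkpoint shards are materialized; torch_dtype matches the
# output precision so weights are never held in fp32 when f16 was requested.)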
print("Model loaded: ", args.model_name_or_path)

list_vars = model.state_dict()

encoder = tokenizer.vocab
# Add added_tokens (special tokens) to the encoder
encoder.update(tokenizer.get_added_vocab())

print(hparams)

print("Saving ggml model to: ", fname_out)
fout = open(fname_out, "wb")

fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
vocab_size = hparams["vocab_size"]
fout.write(struct.pack("i", vocab_size))
# fout.write(struct.pack("i", len(encoder)))
fout.write(struct.pack("i", hparams["n_positions"]))
fout.write(struct.pack("i", hparams["n_embd"]))
fout.write(struct.pack("i", hparams["n_head"]))
fout.write(struct.pack("i", hparams["n_layer"]))
fout.write(struct.pack("i", use_f16))

byte_encoder = bytes_to_unicode()
byte_decoder = {v:k for k, v in byte_encoder.items()}

fout.write(struct.pack("i", vocab_size))
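
# The file now holds eight int32 fields (native byte order, as produced by
# struct.pack("i")): magic, vocab_size, n_positions, n_embd, n_head, n_layer,
# the f16 flag, and vocab_size again as the count of token records to follow.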
counter = 0
# sort by value
for key in sorted(encoder, key=encoder.get):
    text = bytearray([byte_decoder[c] for c in key])
    fout.write(struct.pack("i", len(text)))
    fout.write(text)
    counter += 1

# TODO: Repeat last token until vocab_size
while counter < vocab_size:
    fout.write(struct.pack("i", len(text)))
    fout.write(text)
    counter += 1
# assert counter == config.vocab_size
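
# Each token record above is an int32 byte length followed by the raw token
# bytes. The tensors are written next, one after another: a 3-int32 header
# (n_dims, name length, per-tensor ftype), the dims in reverse order, the
# utf-8 name bytes, then the raw tensor data.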

for name in list_vars.keys():
    data = list_vars[name].squeeze().numpy()
    print("Processing variable: " + name + " with shape: ", data.shape)

    # rename headers to keep compatibility
    if name == "transformer.ln_f.weight":
        name = "model/ln_f/g"
    elif name == "transformer.ln_f.bias":
        name = "model/ln_f/b"
    elif name == "transformer.wte.weight":
        name = "model/wte"
    elif name == "transformer.wpe.weight":
        name = "model/wpe"
    elif name == "lm_head.weight":
        name = "model/lm_head"
    elif re.match(r"transformer\.h\.\d+\.ln_1\.weight", name):
        i = re.findall(r"\d+", name)[0]
        name = f"model/h{i}/ln_1/g"
    elif re.match(r"transformer\.h\.\d+\.ln_1\.bias", name):
        i = re.findall(r"\d+", name)[0]
        name = f"model/h{i}/ln_1/b"
    elif re.match(r"transformer\.h\.\d+\.attn\.c_attn\.weight", name):
        i = re.findall(r"\d+", name)[0]
        name = f"model/h{i}/attn/c_attn/w"
    elif re.match(r"transformer\.h\.\d+\.attn\.c_attn\.bias", name):
        i = re.findall(r"\d+", name)[0]
        name = f"model/h{i}/attn/c_attn/b"
    elif re.match(r"transformer\.h\.\d+\.attn\.c_proj\.weight", name):
        i = re.findall(r"\d+", name)[0]
        name = f"model/h{i}/attn/c_proj/w"
    elif re.match(r"transformer\.h\.\d+\.attn\.c_proj\.bias", name):
        i = re.findall(r"\d+", name)[0]
        name = f"model/h{i}/attn/c_proj/b"
    elif re.match(r"transformer\.h\.\d+\.ln_2\.weight", name):
        i = re.findall(r"\d+", name)[0]
        name = f"model/h{i}/ln_2/g"
    elif re.match(r"transformer\.h\.\d+\.ln_2\.bias", name):
        i = re.findall(r"\d+", name)[0]
        name = f"model/h{i}/ln_2/b"
    elif re.match(r"transformer\.h\.\d+\.mlp\.c_fc\.weight", name):
        i = re.findall(r"\d+", name)[0]
        name = f"model/h{i}/mlp/c_fc/w"
    elif re.match(r"transformer\.h\.\d+\.mlp\.c_fc\.bias", name):
        i = re.findall(r"\d+", name)[0]
        name = f"model/h{i}/mlp/c_fc/b"
    elif re.match(r"transformer\.h\.\d+\.mlp\.c_proj\.weight", name):
        i = re.findall(r"\d+", name)[0]
        name = f"model/h{i}/mlp/c_proj/w"
    elif re.match(r"transformer\.h\.\d+\.mlp\.c_proj\.bias", name):
        i = re.findall(r"\d+", name)[0]
        name = f"model/h{i}/mlp/c_proj/b"
    else:
        print("Unrecognized variable name: " + name)

    # we don't need these
    if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"):
        print("  Skipping variable: " + name)
        continue

    n_dims = len(data.shape)

    # ftype == 0 -> float32, ftype == 1 -> float16
    ftype = 0
    if use_f16:
        if (name == "model/wte" or name == "model/lm_head" or name[-2:] == "/g" or name[-2:] == "/w") and n_dims == 2:
            print("  Converting to float16")
            data = data.astype(np.float16)
            ftype = 1
        else:
            print("  Converting to float32")
            data = data.astype(np.float32)
            ftype = 0
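
    # Note: in f16 mode only the matched 2-D tensors become float16; biases,
    # layernorm params and other 1-D tensors stay float32. The per-tensor
    # ftype written below tells the loader which precision each tensor uses.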
  151. "model/h.*/attn/c_attn/w"
  152. "model/h.*/attn/c_proj/w"
  153. "model/h.*/mlp/c_fc/w"
  154. "model/h.*/mlp/c_proj/w"
    if name[-14:] == "/attn/c_attn/w" or name[-14:] == "/attn/c_attn/b":
        print("  Duplicate K,V heads to use MHA instead of MQA")
        embed_dim = hparams["n_embd"]
        head_dim = embed_dim // hparams["n_head"]
        # ((n_heads + 2) * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
        q, k, v = np.split(data, (hparams["n_head"] * head_dim, (hparams["n_head"] + 1) * head_dim), axis=0)
        # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
        if len(k.shape) == 2:
            k = np.tile(k, (hparams["n_head"], 1))
            v = np.tile(v, (hparams["n_head"], 1))
        elif len(k.shape) == 1:
            k = np.tile(k, (hparams["n_head"]))
            v = np.tile(v, (hparams["n_head"]))
        # concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
        data = np.concatenate((q, k, v), axis=0)

    # header
    name_bytes = name.encode('utf-8')
    fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype))
    for i in range(n_dims):
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(name_bytes)

    # data
    data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")
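
# A minimal sanity check of the output (a sketch, not part of the converter;
# it re-reads just the fixed-size header written above):
#
#   with open(fname_out, "rb") as f:
#       magic, n_vocab, n_ctx, n_embd, n_head, n_layer, f16 = struct.unpack("7i", f.read(28))
#       assert magic == 0x67676d6c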