convert-h5-to-ggml.py

import os
import struct
import sys

import torch
from transformers import AutoConfig, AutoTokenizer
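
# note: requires the `torch` and `transformers` packages,
# e.g. `pip install torch transformers`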


# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
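
# for example, byte 0x20 (space) maps to "Ġ" (chr(0x120)): whitespace and
# control bytes get printable stand-ins that round-trip via byte_decoder below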


def count_model_parts(dir_model: str) -> int:
    """Returns the number of model parts in the model directory."""
    num_parts = 0
    for filename in os.listdir(dir_model):
        if filename.startswith("pytorch_model-"):
            num_parts += 1

    if num_parts > 0:
        print(f"Found {num_parts} model parts in {dir_model}")
    return num_parts
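
# sharded Hugging Face checkpoints are named pytorch_model-00001-of-00002.bin
# and so on; a single-file checkpoint is just pytorch_model.bin (see below)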


if len(sys.argv) < 2:
    print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n")
    print("  ftype == 0 -> float32")
    print("  ftype == 1 -> float16")
    sys.exit(1)

# output in the same directory as the model
dir_model = sys.argv[1]

# get number of model parts
num_parts = count_model_parts(dir_model)

# possible data types
#   ftype == 0 -> float32
#   ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if len(sys.argv) > 2:
    ftype = int(sys.argv[2])
    if ftype < 0 or ftype > 1:
        print("Invalid ftype: " + str(ftype))
        sys.exit(1)

fname_out = dir_model + "/ggml-model-" + ftype_str[ftype] + ".bin"
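# e.g. for ftype == 1 this writes <dir-model>/ggml-model-f16.bin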

tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
config = AutoConfig.from_pretrained(dir_model, trust_remote_code=True)
hparams = config.to_dict()

fout = open(fname_out, "wb")

fout.write(struct.pack("i", 0x67676D6C))  # magic: ggml in hex
fout.write(struct.pack("i", hparams["d_model"]))
fout.write(struct.pack("i", hparams["max_seq_len"]))
fout.write(struct.pack("i", hparams["n_heads"]))
fout.write(struct.pack("i", hparams["n_layers"]))
fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("f", hparams["attn_config"]["alibi_bias_max"]))
fout.write(struct.pack("f", hparams["attn_config"]["clip_qkv"] or 0.0))
fout.write(struct.pack("i", ftype))
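
# note: struct.pack("i") / ("f") use native byte order; ggml expects a
# little-endian file, so run the conversion on a little-endian host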

vocab_size = hparams["vocab_size"]

encoder = tokenizer.vocab
# Add added_tokens (special tokens) to the encoder
encoder.update(tokenizer.get_added_vocab())

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

counter = 0
# sort by token id so tokens are written in vocabulary order
for key in sorted(encoder, key=encoder.get):
    # map the byte-encoded characters back to raw bytes; added tokens may
    # contain characters that are not in byte_decoder, so fall back to
    # their utf-8 encoding for those
    text = bytearray()
    for c in key:
        if c not in byte_decoder:
            text += c.encode("utf-8")
        else:
            text.append(byte_decoder[c])
    fout.write(struct.pack("i", len(text)))
    fout.write(text)
    counter += 1

# repeat the last token until vocab_size entries have been written
while counter < vocab_size:
    fout.write(struct.pack("i", len(text)))
    fout.write(text)
    counter += 1
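
# each vocab entry on disk is an int32 byte length followed by the raw token
# bytes; entries past the real vocabulary just pad the table to vocab_size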

if num_parts == 0:
    part_names = ("pytorch_model.bin",)
else:
    part_names = (
        f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
    )

for part_name in part_names:
    print(f"\n* Loading part: {part_name}")
    model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

    for name in model_part.keys():
        data = model_part[name].squeeze()
        n_dims = len(data.shape)

        # ftype == 0 -> float32, ftype == 1 -> float16
        # default type is fp32
        ftype_cur = 0
        if ftype == 1 and name[-7:] == ".weight" and n_dims > 1:
            ftype_cur = 1
        data = data.to(dtype=torch.float16 if ftype_cur == 1 else torch.float32).numpy()
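
        # note: the n_dims > 1 check above keeps 1-D tensors (biases,
        # layer norms) in f32 even when ftype == 1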

        print(
            "Processing variable: " + name + " with shape: ",
            data.shape,
            "->",
            data.dtype,
        )

        # header: number of dims, length of the name, tensor type, then
        # the dims (innermost first) and the name itself
        str_name = name.encode("utf-8")
        fout.write(struct.pack("iii", n_dims, len(str_name), ftype_cur))
        for i in range(n_dims):
            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
        fout.write(str_name)

        # data
        data.tofile(fout)

    # release memory
    del model_part

fout.close()

print("Done. Output file: " + fname_out)
print("")
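
# example invocation (assuming an MPT-style checkpoint directory ./mpt-7b):
#   python convert-h5-to-ggml.py ./mpt-7b 1
# which writes ./mpt-7b/ggml-model-f16.bin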