convert-h5-to-ggml.py

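# Convert a GPT-NeoX-style Hugging Face checkpoint (config.json plus weights
# loadable via AutoModelForCausalLM) into a single ggml binary. The hyper-parameter
# and tensor names used below follow the GPT-NeoX config/state_dict layout.
# Illustrative invocation (the model directory is an example path, not a fixed name):
#
#   python convert-h5-to-ggml.py ./my-gpt-neox-model 1
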
import sys
import struct
import json
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

if len(sys.argv) < 2:
    print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n")
    print("  ftype == 0 -> float32")
    print("  ftype == 1 -> float16")
    sys.exit(1)
# output in the same directory as the model
dir_model = sys.argv[1]
fname_out = sys.argv[1] + "/ggml-model.bin"

with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
    hparams = json.load(f)

# possible data types
#   ftype == 0 -> float32
#   ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if len(sys.argv) > 2:
    ftype = int(sys.argv[2])
    if ftype < 0 or ftype > 1:
        print("Invalid ftype: " + str(ftype))
        sys.exit(1)
    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"

tokenizer = AutoTokenizer.from_pretrained(dir_model)
model = AutoModelForCausalLM.from_pretrained(dir_model, low_cpu_mem_usage=True)

list_vars = model.state_dict()
for name in list_vars.keys():
    print(name, list_vars[name].shape, list_vars[name].dtype)

fout = open(fname_out, "wb")

print(hparams)
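
# File header: magic number followed by the model hyper-parameters, each packed
# as a native-endian 32-bit integer (struct format "i"), ending with the ftype.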
fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("i", hparams["max_position_embeddings"]))
fout.write(struct.pack("i", hparams["hidden_size"]))
fout.write(struct.pack("i", hparams["num_attention_heads"]))
fout.write(struct.pack("i", hparams["num_hidden_layers"]))
fout.write(struct.pack("i", int(hparams["rotary_pct"]*(hparams["hidden_size"]//hparams["num_attention_heads"]))))
fout.write(struct.pack("i", hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True))
fout.write(struct.pack("i", ftype))

# TODO: temporary hack to not deal with implementing the tokenizer
for i in range(hparams["vocab_size"]):
    text = tokenizer.decode([i]).encode('utf-8')
    fout.write(struct.pack("i", len(text)))
    fout.write(text)
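
# Each remaining tensor is written as a self-describing record: n_dims, name
# length and per-tensor ftype (three int32), then the dimensions in reverse
# order, the utf-8 encoded tensor name, and finally the raw tensor data.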
for name in list_vars.keys():
    data = list_vars[name].squeeze().numpy()
    print("Processing variable: " + name + " with shape: ", data.shape)

    # we don't need these
    if name.endswith(".attention.masked_bias") or \
       name.endswith(".attention.bias") or \
       name.endswith(".attention.rotary_emb.inv_freq"):
        print("  Skipping variable: " + name)
        continue

    n_dims = len(data.shape)

    # ftype == 0 -> float32, ftype == 1 -> float16
    ftype_cur = 0
    if ftype != 0:
        if name[-7:] == ".weight" and n_dims == 2:
            print("  Converting to float16")
            data = data.astype(np.float16)
            ftype_cur = 1
        else:
            print("  Converting to float32")
            data = data.astype(np.float32)
            ftype_cur = 0
    else:
        if data.dtype != np.float32:
            print("  Converting to float32")
            data = data.astype(np.float32)
            ftype_cur = 0

    # header
    name_bytes = name.encode('utf-8')
    fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype_cur))
    for i in range(n_dims):
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(name_bytes)

    # data
    data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")