# convert-pth-to-ggml.py
#
# Convert a SAM model checkpoint to a ggml-compatible file.
#
import os
import struct
import sys

import numpy as np
import torch
  7. if len(sys.argv) < 3:
  8. print("Usage: convert-pth-to-ggml.py file-model dir-output [ftype]\n")
  9. print(" ftype == 0 -> float32")
  10. print(" ftype == 1 -> float16")
  11. sys.exit(1)
  12. # output in the same directory as the model
  13. fname_model = sys.argv[1]
  14. dir_out = sys.argv[2]
  15. fname_out = dir_out + "/ggml-model.bin"
  16. # possible data types
  17. # ftype == 0 -> float32
  18. # ftype == 1 -> float16
  19. #
  20. # map from ftype to string
  21. ftype_str = ["f32", "f16"]
  22. ftype = 1
  23. if len(sys.argv) > 3:
  24. ftype = int(sys.argv[3])
  25. if ftype < 0 or ftype > 1:
  26. print("Invalid ftype: " + str(ftype))
  27. sys.exit(1)
  28. fname_out = fname_out.replace(".bin", "-" + ftype_str[ftype] + ".bin")
  29. # Default params are set to sam_vit_b checkpoint
  30. n_enc_state = 768
  31. n_enc_layers = 12
  32. n_enc_heads = 12
  33. n_enc_out_chans = 256
  34. n_pt_embd = 4
  35. model = torch.load(fname_model, map_location="cpu")
  36. for k, v in model.items():
  37. print(k, v.shape)
  38. if k == "image_encoder.blocks.0.norm1.weight":
  39. n_enc_state = v.shape[0]
  40. if n_enc_state == 1024: # sam_vit_l
  41. n_enc_layers = 24
  42. n_enc_heads = 16
  43. elif n_enc_state == 1280: # sam_vit_h
  44. n_enc_layers = 32
  45. n_enc_heads = 16
  46. hparams = {
  47. "n_enc_state": n_enc_state,
  48. "n_enc_layers": n_enc_layers,
  49. "n_enc_heads": n_enc_heads,
  50. "n_enc_out_chans": n_enc_out_chans,
  51. "n_pt_embd": n_pt_embd,
  52. }
  53. print(hparams)
  54. for k, v in model.items():
  55. print(k, v.shape)
  56. #exit()
  57. #code.interact(local=locals())
  58. fout = open(fname_out, "wb")
  59. fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
  60. fout.write(struct.pack("i", hparams["n_enc_state"]))
  61. fout.write(struct.pack("i", hparams["n_enc_layers"]))
  62. fout.write(struct.pack("i", hparams["n_enc_heads"]))
  63. fout.write(struct.pack("i", hparams["n_enc_out_chans"]))
  64. fout.write(struct.pack("i", hparams["n_pt_embd"]))
  65. fout.write(struct.pack("i", ftype))
  66. for k, v in model.items():
  67. name = k
  68. shape = v.shape
  69. if name[:19] == "prompt_encoder.mask":
  70. continue
  71. print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
  72. #data = tf.train.load_variable(dir_model, name).squeeze()
  73. #data = v.numpy().squeeze()
  74. data = v.numpy()
  75. n_dims = len(data.shape)
  76. # for efficiency - transpose some matrices
  77. # "model/h.*/attn/c_attn/w"
  78. # "model/h.*/attn/c_proj/w"
  79. # "model/h.*/mlp/c_fc/w"
  80. # "model/h.*/mlp/c_proj/w"
  81. #if name[-14:] == "/attn/c_attn/w" or \
  82. # name[-14:] == "/attn/c_proj/w" or \
  83. # name[-11:] == "/mlp/c_fc/w" or \
  84. # name[-13:] == "/mlp/c_proj/w":
  85. # print(" Transposing")
  86. # data = data.transpose()
  87. dshape = data.shape
  88. # default type is fp16
  89. ftype_cur = 1
  90. if ftype == 0 or n_dims == 1 or \
  91. name == "image_encoder.pos_embed" or \
  92. name.startswith("prompt_encoder") or \
  93. name.startswith("mask_decoder.iou_token") or \
  94. name.startswith("mask_decoder.mask_tokens"):
  95. print(" Converting to float32")
  96. data = data.astype(np.float32)
  97. ftype_cur = 0
  98. else:
  99. print(" Converting to float16")
  100. data = data.astype(np.float16)
  101. # reshape the 1D bias into a 4D tensor so we can use ggml_repeat
  102. # keep it in F32 since the data is small
  103. if name == "image_encoder.patch_embed.proj.bias":
  104. data = data.reshape(1, data.shape[0], 1, 1)
  105. n_dims = len(data.shape)
  106. dshape = data.shape
  107. print(" New shape: ", dshape)
  108. # header
  109. str = name.encode('utf-8')
  110. fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
  111. for i in range(n_dims):
  112. fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
  113. fout.write(str)
  114. # data
  115. data.tofile(fout)
  116. fout.close()
  117. print("Done. Output file: " + fname_out)
  118. print("")