convert-pt-to-ggml.py
# Convert UnitY model from PyTorch to ggml format
#
# Usage: python3.8 /private/home/dnn/ggml/ggml/examples/unity/convert-pt-to-ggml.py /large_experiments/seamless/ust/dnn/unity_large_audio_enc.pt /private/home/dnn/ggml/ggml/examples/unity/models/unity-large
#
import io
import struct
import sys
from pathlib import Path

import numpy as np
import torch

from convert_pt_states import generate_mapping

if len(sys.argv) < 3:
    print("Usage: convert-pt-to-ggml.py model.pt dir-output [use-f32]\n")
    sys.exit(1)

fname_inp = Path(sys.argv[1])
dir_out = Path(sys.argv[2])

# try to load PyTorch binary data
try:
    model_bytes = open(fname_inp, "rb").read()
    with io.BytesIO(model_bytes) as fp:
        checkpoint = torch.load(fp, map_location="cpu")
except Exception:
    print("Error: failed to load PyTorch model file:", fname_inp)
    sys.exit(1)

hparams = {
    "n_text_vocab": 256064,
    "n_audio_enc_dim": 1024,
    "n_audio_enc_ffn_dim": 4096,
    "n_audio_enc_feat_dim": 160,
    "n_audio_enc_layer": 24,
    "n_audio_enc_head": 16,
}
print("hparams:", hparams)

list_vars = checkpoint["model"]
state_map = generate_mapping(list_vars)

# output in the same directory as the model
fname_out = dir_out / "ggml-model.bin"

# use 16-bit or 32-bit floats
use_f16 = True
# the optional third argument ([use-f32] in the usage string) selects 32-bit output
if len(sys.argv) > 3:
    use_f16 = False
    fname_out = dir_out / "ggml-model-f32.bin"

fout = fname_out.open("wb")

fout.write(struct.pack("i", 0x67676D6C))  # magic: ggml in hex
for key in hparams.keys():
    fout.write(struct.pack("i", hparams[key]))
fout.write(struct.pack("i", use_f16))
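
# At this point the file header is: the int32 magic, the six hparams above in
# insertion order, then use_f16 as an int32 flag.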

# batch-norm bookkeeping counters are not needed for inference, so skip them
exclude_list = []
exclude_list += [
    f"encoder.w2v_encoder.w2v_model.encoder.layers.{i}.conv_module.batch_norm.num_batches_tracked"
    for i in range(24)
]

for name in list_vars.keys():
    if (
        list_vars[name] is None
        or name in exclude_list
        or "adaptor" in name
        or "mask_emb" in name
    ):
        continue

    data = list_vars[name].squeeze().numpy()
    print("Processing variable: ", name, " with shape: ", data.shape)
    n_dims = len(data.shape)

    # TODO: Convert to fp16 when necessary!
    ftype = 0
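    # Hedged sketch (not enabled): other ggml convert scripts store 2-D weights
    # as float16 and mark them with ftype = 1; whether the UnitY loader expects
    # that here is an assumption, so the conversion is left as a comment:
    #     if use_f16 and n_dims > 1:
    #         data = data.astype(np.float16)
    #         ftype = 1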

    if name not in state_map:
        continue

    # header
    str_ = state_map[name].encode("utf-8")
    fout.write(struct.pack("iii", n_dims, len(str_), ftype))
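    # dims are written innermost-first (reversed relative to the numpy shape),
    # the order other ggml converters emit; the matching loader is assumed to
    # expect it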
    for i in range(n_dims):
        if ".layer_norm.weight" in name:
            print(data.shape)  # debug output for layer-norm weights
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(str_)

    # data
    data.tofile(fout)

fout.close()

print("Done. Output file: ", fname_out)
print("")
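
# Hedged sanity check (hypothetical, not part of the conversion): the header
# written above can be read back with struct.unpack to confirm the magic and
# hparams, e.g.:
#     with open(fname_out, "rb") as f:
#         magic, = struct.unpack("i", f.read(4))
#         vals = struct.unpack("7i", f.read(28))  # 6 hparams + use_f16 flag
#         assert magic == 0x67676D6C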