# convert-pt-to-ggml.py
# Convert a UnitY model checkpoint from PyTorch to ggml format.
#
# Usage: python3 convert-pt-to-ggml.py model.pt dir-output [use-f32]
#
# Example: python3.8 /private/home/dnn/ggml/ggml/examples/unity/convert-pt-to-ggml.py /large_experiments/seamless/ust/dnn/unity_large_audio_enc.pt /private/home/dnn/ggml/ggml/examples/unity/models/unity-large
#
import io
import struct
import sys
from pathlib import Path

import numpy as np
import torch

from convert_pt_states import generate_mapping
  12. if len(sys.argv) < 3:
  13. print("Usage: convert-pt-to-ggml.py model.pt dir-output [use-f32]\n")
  14. sys.exit(1)
  15. fname_inp = Path(sys.argv[1])
  16. dir_out = Path(sys.argv[2])
  17. # try to load PyTorch binary data
  18. try:
  19. model_bytes = open(fname_inp, "rb").read()
  20. with io.BytesIO(model_bytes) as fp:
  21. checkpoint = torch.load(fp, map_location="cpu")
  22. except Exception:
  23. print("Error: failed to load PyTorch model file:" , fname_inp)
  24. sys.exit(1)
  25. hparams = {"n_text_vocab": 256064, "n_audio_enc_dim": 1024, "n_audio_enc_ffn_dim": 4096, "n_audio_enc_feat_dim": 160, "n_audio_enc_layer": 24, "n_audio_enc_head": 16}
  26. print("hparams:", hparams)
  27. list_vars = checkpoint["model"]
  28. state_map = generate_mapping(list_vars)
  29. # output in the same directory as the model
  30. fname_out = dir_out / "ggml-model.bin"
  31. # use 16-bit or 32-bit floats
  32. use_f16 = True
  33. if len(sys.argv) > 4:
  34. use_f16 = False
  35. fname_out = dir_out / "ggml-model-f32.bin"
  36. fout = fname_out.open("wb")
  37. fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
  38. for key in hparams.keys():
  39. fout.write(struct.pack("i", hparams[key]))
  40. fout.write(struct.pack("i", use_f16))
  41. exclude_list = []
  42. exclude_list += [f"encoder.w2v_encoder.w2v_model.encoder.layers.{i}.conv_module.batch_norm.num_batches_tracked" for i in range(24)]
  43. for name in list_vars.keys():
  44. if list_vars[name] is None or name in exclude_list or "adaptor" in name:
  45. continue
  46. data = list_vars[name].squeeze().numpy()
  47. print("Processing variable: " , name , " with shape: ", data.shape)
  48. n_dims = len(data.shape)
  49. # TODO: Convert to fp16 when necessary!
  50. ftype = 0
  51. if name not in state_map:
  52. continue
  53. # header
  54. # if 'pos_bias' in name:
  55. # import pdb; pdb.set_trace()
  56. # print(data.shape)
  57. str_ = state_map[name].encode('utf-8')
  58. fout.write(struct.pack("iii", n_dims, len(str_), ftype))
  59. for i in range(n_dims):
  60. fout.write(struct.pack("i", data.shape[i]))
  61. fout.write(str_)
  62. # data
  63. data.tofile(fout)
  64. fout.close()
  65. print("Done. Output file: " , fname_out)
  66. print("")