# convert-h5-to-ggml.py
# Convert an MNIST model (saved as a PyTorch state_dict) to ggml format.
#
# Load the saved state_dict using PyTorch, then iterate over all variables
# and write them to a binary file.
#
# For each variable (after squeezing away size-1 axes), write the following:
# - Number of dimensions (int32)
# - Dimensions (int32[n_dims]), written last axis first
# - Data (float32[product of the dimensions]), in C order
# Variable names are not written.
#
# At the start of the ggml file we write the magic number 0x67676d6c
# ("ggml" in hex).
import json
import os
import re
import struct
import sys

import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
  24. if len(sys.argv) != 2:
  25. print("Usage: convert-h5-to-ggml.py model\n")
  26. sys.exit(1)
  27. state_dict_file = sys.argv[1]
  28. fname_out = "models/mnist/ggml-model-f32.bin"
  29. state_dict = torch.load(state_dict_file, map_location=torch.device('cpu'))
  30. #print (model)
  31. list_vars = state_dict
  32. print (list_vars)
  33. fout = open(fname_out, "wb")
  34. fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
  35. for name in list_vars.keys():
  36. data = list_vars[name].squeeze().numpy()
  37. print("Processing variable: " + name + " with shape: ", data.shape)
  38. n_dims = len(data.shape);
  39. fout.write(struct.pack("i", n_dims))
  40. data = data.astype(np.float32)
  41. for i in range(n_dims):
  42. fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
  43. # data
  44. data.tofile(fout)
  45. fout.close()
  46. print("Done. Output file: " + fname_out)
  47. print("")