|
@@ -1,22 +1,23 @@
|
|
# Convert UnitY model from PyTorch to ggml format
|
|
# Convert UnitY model from PyTorch to ggml format
|
|
#
|
|
#
|
|
# Usage: python3.8 /private/home/dnn/ggml/ggml/examples/unity/convert-pt-to-ggml.py /large_experiments/seamless/ust/dnn/unity_large_audio_enc.pt /private/home/dnn/ggml/ggml/examples/unity/models/unity-large
|
|
# Usage: python3.8 /private/home/dnn/ggml/ggml/examples/unity/convert-pt-to-ggml.py /large_experiments/seamless/ust/dnn/unity_large_audio_enc.pt /private/home/dnn/ggml/ggml/examples/unity/models/unity-large
|
|
-#
|
|
|
|
|
|
+#
|
|
import io
|
|
import io
|
|
-import sys
|
|
|
|
import struct
|
|
import struct
|
|
-import torch
|
|
|
|
-import numpy as np
|
|
|
|
|
|
+import sys
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
-from convert_pt_states import generate_mapping
|
|
|
|
|
|
|
|
|
|
+import numpy as np
|
|
|
|
+import torch
|
|
|
|
+
|
|
|
|
+from convert_pt_states import generate_mapping
|
|
|
|
|
|
if len(sys.argv) < 3:
|
|
if len(sys.argv) < 3:
|
|
print("Usage: convert-pt-to-ggml.py model.pt dir-output [use-f32]\n")
|
|
print("Usage: convert-pt-to-ggml.py model.pt dir-output [use-f32]\n")
|
|
sys.exit(1)
|
|
sys.exit(1)
|
|
|
|
|
|
-fname_inp = Path(sys.argv[1])
|
|
|
|
-dir_out = Path(sys.argv[2])
|
|
|
|
|
|
+fname_inp = Path(sys.argv[1])
|
|
|
|
+dir_out = Path(sys.argv[2])
|
|
|
|
|
|
# try to load PyTorch binary data
|
|
# try to load PyTorch binary data
|
|
try:
|
|
try:
|
|
@@ -24,10 +25,17 @@ try:
|
|
with io.BytesIO(model_bytes) as fp:
|
|
with io.BytesIO(model_bytes) as fp:
|
|
checkpoint = torch.load(fp, map_location="cpu")
|
|
checkpoint = torch.load(fp, map_location="cpu")
|
|
except Exception:
|
|
except Exception:
|
|
- print("Error: failed to load PyTorch model file:" , fname_inp)
|
|
|
|
|
|
+ print("Error: failed to load PyTorch model file:", fname_inp)
|
|
sys.exit(1)
|
|
sys.exit(1)
|
|
|
|
|
|
-hparams = {"n_text_vocab": 256064, "n_audio_enc_dim": 1024, "n_audio_enc_ffn_dim": 4096, "n_audio_enc_feat_dim": 160, "n_audio_enc_layer": 24, "n_audio_enc_head": 16}
|
|
|
|
|
|
+hparams = {
|
|
|
|
+ "n_text_vocab": 256064,
|
|
|
|
+ "n_audio_enc_dim": 1024,
|
|
|
|
+ "n_audio_enc_ffn_dim": 4096,
|
|
|
|
+ "n_audio_enc_feat_dim": 160,
|
|
|
|
+ "n_audio_enc_layer": 24,
|
|
|
|
+ "n_audio_enc_head": 16,
|
|
|
|
+}
|
|
print("hparams:", hparams)
|
|
print("hparams:", hparams)
|
|
|
|
|
|
list_vars = checkpoint["model"]
|
|
list_vars = checkpoint["model"]
|
|
@@ -44,19 +52,27 @@ if len(sys.argv) > 4:
|
|
|
|
|
|
fout = fname_out.open("wb")
|
|
fout = fname_out.open("wb")
|
|
|
|
|
|
-fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
|
|
|
|
|
|
+fout.write(struct.pack("i", 0x67676D6C)) # magic: ggml in hex
|
|
for key in hparams.keys():
|
|
for key in hparams.keys():
|
|
fout.write(struct.pack("i", hparams[key]))
|
|
fout.write(struct.pack("i", hparams[key]))
|
|
fout.write(struct.pack("i", use_f16))
|
|
fout.write(struct.pack("i", use_f16))
|
|
|
|
|
|
exclude_list = []
|
|
exclude_list = []
|
|
-exclude_list += [f"encoder.w2v_encoder.w2v_model.encoder.layers.{i}.conv_module.batch_norm.num_batches_tracked" for i in range(24)]
|
|
|
|
|
|
+exclude_list += [
|
|
|
|
+ f"encoder.w2v_encoder.w2v_model.encoder.layers.{i}.conv_module.batch_norm.num_batches_tracked"
|
|
|
|
+ for i in range(24)
|
|
|
|
+]
|
|
|
|
|
|
for name in list_vars.keys():
|
|
for name in list_vars.keys():
|
|
- if list_vars[name] is None or name in exclude_list or "adaptor" in name or "mask_emb" in name:
|
|
|
|
|
|
+ if (
|
|
|
|
+ list_vars[name] is None
|
|
|
|
+ or name in exclude_list
|
|
|
|
+ or "adaptor" in name
|
|
|
|
+ or "mask_emb" in name
|
|
|
|
+ ):
|
|
continue
|
|
continue
|
|
data = list_vars[name].squeeze().numpy()
|
|
data = list_vars[name].squeeze().numpy()
|
|
- print("Processing variable: " , name , " with shape: ", data.shape)
|
|
|
|
|
|
+ print("Processing variable: ", name, " with shape: ", data.shape)
|
|
|
|
|
|
n_dims = len(data.shape)
|
|
n_dims = len(data.shape)
|
|
|
|
|
|
@@ -68,12 +84,12 @@ for name in list_vars.keys():
|
|
# if 'pos_bias' in name:
|
|
# if 'pos_bias' in name:
|
|
# import pdb; pdb.set_trace()
|
|
# import pdb; pdb.set_trace()
|
|
# print(data.shape)
|
|
# print(data.shape)
|
|
- str_ = state_map[name].encode('utf-8')
|
|
|
|
|
|
+ str_ = state_map[name].encode("utf-8")
|
|
fout.write(struct.pack("iii", n_dims, len(str_), ftype))
|
|
fout.write(struct.pack("iii", n_dims, len(str_), ftype))
|
|
for i in range(n_dims):
|
|
for i in range(n_dims):
|
|
- if '.layer_norm.weight' in name:
|
|
|
|
|
|
+ if ".layer_norm.weight" in name:
|
|
print(data.shape)
|
|
print(data.shape)
|
|
- fout.write(struct.pack("i", data.shape[n_dims-1-i]))
|
|
|
|
|
|
+ fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
|
|
fout.write(str_)
|
|
fout.write(str_)
|
|
|
|
|
|
# data
|
|
# data
|
|
@@ -81,6 +97,5 @@ for name in list_vars.keys():
|
|
|
|
|
|
fout.close()
|
|
fout.close()
|
|
|
|
|
|
-print("Done. Output file: " , fname_out)
|
|
|
|
|
|
+print("Done. Output file: ", fname_out)
|
|
print("")
|
|
print("")
|
|
-
|
|
|