소스 검색

black + isort

Guillaume Wenzek 1 년 전
부모
커밋
d6425f84b3
2개의 변경된 파일, 57개의 추가 그리고 33개의 삭제
  1. 33 18
      ggml/examples/unity/convert-pt-to-ggml.py
  2. 24 15
      ggml/examples/unity/convert_pt_states.py

+ 33 - 18
ggml/examples/unity/convert-pt-to-ggml.py

@@ -1,22 +1,23 @@
 # Convert UnitY model from PyTorch to ggml format
 #
 # Usage: python3.8 /private/home/dnn/ggml/ggml/examples/unity/convert-pt-to-ggml.py /large_experiments/seamless/ust/dnn/unity_large_audio_enc.pt /private/home/dnn/ggml/ggml/examples/unity/models/unity-large
-# 
+#
 import io
-import sys
 import struct
-import torch
-import numpy as np
+import sys
 from pathlib import Path
-from convert_pt_states import generate_mapping
 
+import numpy as np
+import torch
+
+from convert_pt_states import generate_mapping
 
 if len(sys.argv) < 3:
     print("Usage: convert-pt-to-ggml.py model.pt dir-output [use-f32]\n")
     sys.exit(1)
 
-fname_inp   = Path(sys.argv[1])
-dir_out     = Path(sys.argv[2])
+fname_inp = Path(sys.argv[1])
+dir_out = Path(sys.argv[2])
 
 # try to load PyTorch binary data
 try:
@@ -24,10 +25,17 @@ try:
     with io.BytesIO(model_bytes) as fp:
         checkpoint = torch.load(fp, map_location="cpu")
 except Exception:
-    print("Error: failed to load PyTorch model file:" , fname_inp)
+    print("Error: failed to load PyTorch model file:", fname_inp)
     sys.exit(1)
 
-hparams = {"n_text_vocab": 256064, "n_audio_enc_dim": 1024, "n_audio_enc_ffn_dim": 4096, "n_audio_enc_feat_dim": 160, "n_audio_enc_layer": 24, "n_audio_enc_head": 16}
+hparams = {
+    "n_text_vocab": 256064,
+    "n_audio_enc_dim": 1024,
+    "n_audio_enc_ffn_dim": 4096,
+    "n_audio_enc_feat_dim": 160,
+    "n_audio_enc_layer": 24,
+    "n_audio_enc_head": 16,
+}
 print("hparams:", hparams)
 
 list_vars = checkpoint["model"]
@@ -44,19 +52,27 @@ if len(sys.argv) > 4:
 
 fout = fname_out.open("wb")
 
-fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
+fout.write(struct.pack("i", 0x67676D6C))  # magic: ggml in hex
 for key in hparams.keys():
     fout.write(struct.pack("i", hparams[key]))
 fout.write(struct.pack("i", use_f16))
 
 exclude_list = []
-exclude_list += [f"encoder.w2v_encoder.w2v_model.encoder.layers.{i}.conv_module.batch_norm.num_batches_tracked" for i in range(24)]
+exclude_list += [
+    f"encoder.w2v_encoder.w2v_model.encoder.layers.{i}.conv_module.batch_norm.num_batches_tracked"
+    for i in range(24)
+]
 
 for name in list_vars.keys():
-    if list_vars[name] is None or name in exclude_list or "adaptor" in name or "mask_emb" in name:
+    if (
+        list_vars[name] is None
+        or name in exclude_list
+        or "adaptor" in name
+        or "mask_emb" in name
+    ):
         continue
     data = list_vars[name].squeeze().numpy()
-    print("Processing variable: " , name ,  " with shape: ", data.shape)
+    print("Processing variable: ", name, " with shape: ", data.shape)
 
     n_dims = len(data.shape)
 
@@ -68,12 +84,12 @@ for name in list_vars.keys():
     # if 'pos_bias' in name:
     #     import pdb; pdb.set_trace()
     #     print(data.shape)
-    str_ = state_map[name].encode('utf-8')
+    str_ = state_map[name].encode("utf-8")
     fout.write(struct.pack("iii", n_dims, len(str_), ftype))
     for i in range(n_dims):
-        if '.layer_norm.weight' in name:
+        if ".layer_norm.weight" in name:
             print(data.shape)
-        fout.write(struct.pack("i", data.shape[n_dims-1-i]))
+        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
     fout.write(str_)
 
     # data
@@ -81,6 +97,5 @@ for name in list_vars.keys():
 
 fout.close()
 
-print("Done. Output file: " , fname_out)
+print("Done. Output file: ", fname_out)
 print("")
-

+ 24 - 15
ggml/examples/unity/convert_pt_states.py

@@ -1,37 +1,46 @@
 import torch
+
+
 def map_state_key(pytorch_key, layer_idx=None):
     # Replace the layer index first
     if layer_idx is not None:
         pytorch_key = pytorch_key.replace(f".layers.{layer_idx}.", "/")
-    
+
     # Replace common patterns in the state key
     translation_dict = {
         ".weight": "/w",
         ".bias": "/b",
-        ".running_mean": "/m", # /running_mean doesn't work
+        ".running_mean": "/m",  # /running_mean doesn't work
         ".running_var": "/v",
         ".num_batches_tracked": "/n",
         "self_attn.": "self_attn_",
         "conv_module.": "conv_",
         "ffn1.": "ffn1_",
         "ffn2.": "ffn2_",
-        "pos_conv.0": "pos_conv"
+        "pos_conv.0": "pos_conv",
     }
-    
-    
+
     # Special mapping for pos_bias_u and pos_bias_v
     if "self_attn.pos_bias_u" in pytorch_key:
-        pytorch_key = pytorch_key.replace("self_attn.pos_bias_u", "self_attn_pos_bias/u")
+        pytorch_key = pytorch_key.replace(
+            "self_attn.pos_bias_u", "self_attn_pos_bias/u"
+        )
     elif "self_attn.pos_bias_v" in pytorch_key:
-        pytorch_key = pytorch_key.replace("self_attn.pos_bias_v", "self_attn_pos_bias/v")
+        pytorch_key = pytorch_key.replace(
+            "self_attn.pos_bias_v", "self_attn_pos_bias/v"
+        )
     for pytorch_pattern, model_pattern in translation_dict.items():
         pytorch_key = pytorch_key.replace(pytorch_pattern, model_pattern)
-    
+
     # Replace the leading pattern and add layer index
     if layer_idx is not None:
-        pytorch_key = pytorch_key.replace("encoder.w2v_encoder.w2v_model.encoder/", f"model/enc/h{layer_idx}/")
+        pytorch_key = pytorch_key.replace(
+            "encoder.w2v_encoder.w2v_model.encoder/", f"model/enc/h{layer_idx}/"
+        )
     else:
-        pytorch_key = pytorch_key.replace("encoder.w2v_encoder.w2v_model.encoder.", f"model/enc/")
+        pytorch_key = pytorch_key.replace(
+            "encoder.w2v_encoder.w2v_model.encoder.", f"model/enc/"
+        )
     pytorch_key = pytorch_key.replace("encoder.w2v_encoder.w2v_model.", f"model/")
     return pytorch_key
 
@@ -49,12 +58,12 @@ def generate_mapping(state_dict):
 
 
 # Testing
-ckpt = torch.load('/large_experiments/seamless/ust/dnn/unity_large_audio_enc.pt')
+ckpt = torch.load("/large_experiments/seamless/ust/dnn/unity_large_audio_enc.pt")
 state_dict = {}
-for key in ckpt['model']:
-    if ckpt['model'][key] is not None:
-        state_dict[key] = ckpt['model'][key]
+for key in ckpt["model"]:
+    if ckpt["model"][key] is not None:
+        state_dict[key] = ckpt["model"][key]
 
 mapped_keys = generate_mapping(state_dict)
 for old_key, new_key in mapped_keys.items():
-    print(old_key, "=>", new_key)
+    print(old_key, "=>", new_key)