
More encoder states loading & pos enc

cndn 1 year ago
parent
commit
ff5cbf57c9

+ 3 - 1
ggml/examples/unity/convert-pt-to-ggml.py

@@ -53,7 +53,7 @@ exclude_list = []
 exclude_list += [f"encoder.w2v_encoder.w2v_model.encoder.layers.{i}.conv_module.batch_norm.num_batches_tracked" for i in range(24)]
 
 for name in list_vars.keys():
-    if list_vars[name] is None or name in exclude_list or "adaptor" in name:
+    if list_vars[name] is None or name in exclude_list or "adaptor" in name or "mask_emb" in name:
         continue
     data = list_vars[name].squeeze().numpy()
     print("Processing variable: " , name ,  " with shape: ", data.shape)
@@ -71,6 +71,8 @@ for name in list_vars.keys():
     str_ = state_map[name].encode('utf-8')
     fout.write(struct.pack("iii", n_dims, len(str_), ftype))
     for i in range(n_dims):
+        if '.layer_norm.weight' in name:
+            print(data.shape)
         fout.write(struct.pack("i", data.shape[n_dims-1-i]))
     fout.write(str_)
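
For reference, the per-tensor serialization this loop performs can be sketched in isolation. This is a minimal reconstruction from the hunk above, not the converter itself; the trailing payload write (`data.tofile(fout)`) is an assumption, since the hunk ends before that point:

import struct
import numpy as np

def write_ggml_tensor(fout, name: str, data: np.ndarray, ftype: int = 0) -> None:
    # Header: number of dims, byte length of the encoded name, tensor type.
    n_dims = data.ndim
    str_ = name.encode("utf-8")
    fout.write(struct.pack("iii", n_dims, len(str_), ftype))
    # Dimensions go innermost-first, i.e. reversed relative to numpy's shape tuple.
    for i in range(n_dims):
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(str_)
    # Assumed: the raw tensor bytes follow the name (not shown in the hunk).
    data.tofile(fout)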
 

+ 13 - 4
ggml/examples/unity/convert_pt_states.py

@@ -1,7 +1,8 @@
 import torch
-def map_state_key(pytorch_key, layer_idx):
+def map_state_key(pytorch_key, layer_idx=None):
     # Replace the layer index first
-    pytorch_key = pytorch_key.replace(f".layers.{layer_idx}.", "/")
+    if layer_idx is not None:
+        pytorch_key = pytorch_key.replace(f".layers.{layer_idx}.", "/")
     
     # Replace common patterns in the state key
     translation_dict = {
@@ -14,6 +15,7 @@ def map_state_key(pytorch_key, layer_idx):
         "conv_module.": "conv_",
         "ffn1.": "ffn1_",
         "ffn2.": "ffn2_",
+        "pos_conv.0": "pos_conv"
     }
     
     
@@ -26,16 +28,23 @@ def map_state_key(pytorch_key, layer_idx):
         pytorch_key = pytorch_key.replace(pytorch_pattern, model_pattern)
     
     # Replace the leading pattern and add layer index
-    return pytorch_key.replace("encoder.w2v_encoder.w2v_model.encoder/", f"model/h{layer_idx}/")
+    if layer_idx is not None:
+        pytorch_key = pytorch_key.replace("encoder.w2v_encoder.w2v_model.encoder/", f"model/enc/h{layer_idx}/")
+    else:
+        pytorch_key = pytorch_key.replace("encoder.w2v_encoder.w2v_model.encoder.", f"model/enc/")
+    pytorch_key = pytorch_key.replace("encoder.w2v_encoder.w2v_model.", f"model/")
+    return pytorch_key
 
 
 def generate_mapping(state_dict):
     mapping = {}
     for state in state_dict.keys():
         for layer_idx in range(24):
-            if f".layers.{layer_idx}." in state:
+            if f".layers.{layer_idx}" in state:
                 new_key = map_state_key(state, layer_idx)
                 mapping[state] = new_key
+        if "layers" not in state:
+            mapping[state] = map_state_key(state)
     return mapping
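
A small driver shows both branches of the updated generate_mapping at work. The keys below are hypothetical illustrations, and the exact output also depends on translation_dict entries outside this hunk:

if __name__ == "__main__":
    # Hypothetical checkpoint keys, for illustration only.
    demo_state = {
        "encoder.w2v_encoder.w2v_model.encoder.layers.3.ffn1.layer_norm.weight": None,
        "encoder.w2v_encoder.w2v_model.encoder.pos_conv.0.weight": None,
    }
    for old, new in generate_mapping(demo_state).items():
        print(f"{old} -> {new}")
    # Assuming no unseen translation_dict entry touches these substrings:
    # the layered key maps via the f".layers.{layer_idx}" branch to
    # model/enc/h3/ffn1_layer_norm.weight, while the layerless pos_conv key
    # falls through to map_state_key(state) and yields model/enc/pos_conv.weight
    # via the new layer_idx=None path.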