|
@@ -1,7 +1,8 @@
|
|
|
import torch
|
|
|
-def map_state_key(pytorch_key, layer_idx):
|
|
|
+def map_state_key(pytorch_key, layer_idx=None):
|
|
|
# Replace the layer index first
|
|
|
- pytorch_key = pytorch_key.replace(f".layers.{layer_idx}.", "/")
|
|
|
+ if layer_idx is not None:
|
|
|
+ pytorch_key = pytorch_key.replace(f".layers.{layer_idx}.", "/")
|
|
|
|
|
|
# Replace common patterns in the state key
|
|
|
translation_dict = {
|
|
@@ -14,6 +15,7 @@ def map_state_key(pytorch_key, layer_idx):
|
|
|
"conv_module.": "conv_",
|
|
|
"ffn1.": "ffn1_",
|
|
|
"ffn2.": "ffn2_",
|
|
|
+ "pos_conv.0": "pos_conv"
|
|
|
}
|
|
|
|
|
|
|
|
@@ -26,16 +28,23 @@ def map_state_key(pytorch_key, layer_idx):
|
|
|
pytorch_key = pytorch_key.replace(pytorch_pattern, model_pattern)
|
|
|
|
|
|
# Replace the leading pattern and add layer index
|
|
|
- return pytorch_key.replace("encoder.w2v_encoder.w2v_model.encoder/", f"model/h{layer_idx}/")
|
|
|
+ if layer_idx is not None:
|
|
|
+ pytorch_key = pytorch_key.replace("encoder.w2v_encoder.w2v_model.encoder/", f"model/enc/h{layer_idx}/")
|
|
|
+ else:
|
|
|
+ pytorch_key = pytorch_key.replace("encoder.w2v_encoder.w2v_model.encoder.", f"model/enc/")
|
|
|
+ pytorch_key = pytorch_key.replace("encoder.w2v_encoder.w2v_model.", f"model/")
|
|
|
+ return pytorch_key
|
|
|
|
|
|
|
|
|
def generate_mapping(state_dict):
|
|
|
mapping = {}
|
|
|
for state in state_dict.keys():
|
|
|
for layer_idx in range(24):
|
|
|
- if f".layers.{layer_idx}." in state:
|
|
|
+ if f".layers.{layer_idx}" in state:
|
|
|
new_key = map_state_key(state, layer_idx)
|
|
|
mapping[state] = new_key
|
|
|
+ if "layers" not in state:
|
|
|
+ mapping[state] = map_state_key(state)
|
|
|
return mapping
|
|
|
|
|
|
|