# convert_pt_states.py
import re

import torch
  2. def map_state_key(pytorch_key, layer_idx=None):
  3. # Replace the layer index first
  4. if layer_idx is not None:
  5. pytorch_key = pytorch_key.replace(f".layers.{layer_idx}.", "/")
  6. # Replace common patterns in the state key
  7. translation_dict = {
  8. ".weight": "/w",
  9. ".bias": "/b",
  10. ".running_mean": "/m", # /running_mean doesn't work
  11. ".running_var": "/v",
  12. ".num_batches_tracked": "/n",
  13. "self_attn.": "self_attn_",
  14. "conv_module.": "conv_",
  15. "ffn1.": "ffn1_",
  16. "ffn2.": "ffn2_",
  17. "pos_conv.0": "pos_conv",
  18. }
  19. # Special mapping for pos_bias_u and pos_bias_v
  20. if "self_attn.pos_bias_u" in pytorch_key:
  21. pytorch_key = pytorch_key.replace(
  22. "self_attn.pos_bias_u", "self_attn_pos_bias/u"
  23. )
  24. elif "self_attn.pos_bias_v" in pytorch_key:
  25. pytorch_key = pytorch_key.replace(
  26. "self_attn.pos_bias_v", "self_attn_pos_bias/v"
  27. )
  28. for pytorch_pattern, model_pattern in translation_dict.items():
  29. pytorch_key = pytorch_key.replace(pytorch_pattern, model_pattern)
  30. # Replace the leading pattern and add layer index
  31. if layer_idx is not None:
  32. pytorch_key = pytorch_key.replace(
  33. "encoder.w2v_encoder.w2v_model.encoder/", f"model/enc/h{layer_idx}/"
  34. )
  35. else:
  36. pytorch_key = pytorch_key.replace(
  37. "encoder.w2v_encoder.w2v_model.encoder.", f"model/enc/"
  38. )
  39. pytorch_key = pytorch_key.replace("encoder.w2v_encoder.w2v_model.", f"model/")
  40. return pytorch_key
  41. def generate_mapping(state_dict):
  42. mapping = {}
  43. for state in state_dict.keys():
  44. for layer_idx in range(24):
  45. if f".layers.{layer_idx}" in state:
  46. new_key = map_state_key(state, layer_idx)
  47. mapping[state] = new_key
  48. if "layers" not in state:
  49. mapping[state] = map_state_key(state)
  50. return mapping
  51. # Testing
  52. ckpt = torch.load("/large_experiments/seamless/ust/dnn/unity_large_audio_enc.pt")
  53. state_dict = {}
  54. for key in ckpt["model"]:
  55. if ckpt["model"][key] is not None:
  56. state_dict[key] = ckpt["model"][key]
  57. mapped_keys = generate_mapping(state_dict)
  58. for old_key, new_key in mapped_keys.items():
  59. print(old_key, "=>", new_key)