convert_pt_states.py

import torch


def map_state_key(pytorch_key, layer_idx):
    # Replace the layer index first
    pytorch_key = pytorch_key.replace(f".layers.{layer_idx}.", "/")
    # Replace common patterns in the state key
    translation_dict = {
        ".weight": "/w",
        ".bias": "/b",
        ".running_mean": "/m",  # /running_mean doesn't work
        ".running_var": "/v",
        ".num_batches_tracked": "/n",
        "self_attn.": "self_attn_",
        "conv_module.": "conv_",
        "ffn1.": "ffn1_",
        "ffn2.": "ffn2_",
    }
    # Special mapping for pos_bias_u and pos_bias_v
    if "self_attn.pos_bias_u" in pytorch_key:
        pytorch_key = pytorch_key.replace("self_attn.pos_bias_u", "self_attn_pos_bias/u")
    elif "self_attn.pos_bias_v" in pytorch_key:
        pytorch_key = pytorch_key.replace("self_attn.pos_bias_v", "self_attn_pos_bias/v")
    for pytorch_pattern, model_pattern in translation_dict.items():
        pytorch_key = pytorch_key.replace(pytorch_pattern, model_pattern)
    # Replace the leading pattern and add layer index
    return pytorch_key.replace("encoder.w2v_encoder.w2v_model.encoder/", f"model/h{layer_idx}/")
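
# A quick sanity check of the translation, using a hypothetical key that
# follows the fairseq-style naming assumed above (real names depend on the
# checkpoint's module layout):
assert map_state_key(
    "encoder.w2v_encoder.w2v_model.encoder.layers.3.self_attn.linear_q.weight", 3
) == "model/h3/self_attn_linear_q/w"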

def generate_mapping(state_dict):
    mapping = {}
    for state in state_dict.keys():
        # The encoder is assumed to have 24 layers; find the one this key belongs to
        for layer_idx in range(24):
            if f".layers.{layer_idx}." in state:
                new_key = map_state_key(state, layer_idx)
                mapping[state] = new_key
                break  # a key can only match one layer index
    return mapping
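
# A minimal sketch of how the mapping could be applied: re-key the tensors for
# the target model (assumes the consumer wants a flat {new_key: tensor} dict;
# apply_mapping is a hypothetical helper, not part of the original script):
def apply_mapping(state_dict, mapping):
    return {mapping[key]: tensor for key, tensor in state_dict.items() if key in mapping}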

# Testing
# map_location='cpu' keeps the script runnable on machines without a GPU
ckpt = torch.load('/large_experiments/seamless/ust/dnn/unity_large_audio_enc.pt', map_location='cpu')
state_dict = {}
for key in ckpt['model']:
    if ckpt['model'][key] is not None:
        state_dict[key] = ckpt['model'][key]
mapped_keys = generate_mapping(state_dict)
for old_key, new_key in mapped_keys.items():
    print(old_key, "=>", new_key)
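
# Hypothetical follow-up (commented out so the script still just prints the
# mapping): materialize the renamed weights as numpy arrays for a non-PyTorch
# consumer. Assumes the tensors were loaded onto the CPU above.
# converted = {k: v.numpy() for k, v in apply_mapping(state_dict, mapped_keys).items()}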