@@ -46,31 +46,38 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         state_dict = checkpoint["model"]
 
+        keys_to_delete = []
+
         # Use the built-in version attribute of `torch.Module`.
-        del state_dict["target_letter_decoder.version"]
-        del state_dict["target_letter_decoder.embed_positions._float_tensor"]
+        if config.t2u_config is None:
+            keys_to_delete.append("decoder.version")
+            keys_to_delete.append("decoder.embed_positions._float_tensor")
+        else:
+            keys_to_delete.append("target_letter_decoder.version")
+            keys_to_delete.append("target_letter_decoder.embed_positions._float_tensor")
 
         if config.use_text_encoder:
-            if "text_encoder.version" in state_dict:
-                del state_dict["text_encoder.version"]
-            if "text_encoder.embed_positions._float_tensor" in state_dict:
-                del state_dict["text_encoder.embed_positions._float_tensor"]
+            keys_to_delete.append("text_encoder.version")
+            keys_to_delete.append("text_encoder.embed_positions._float_tensor")
 
         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
-        del state_dict["encoder.w2v_encoder.w2v_model.mask_emb"]
+        keys_to_delete.append("encoder.w2v_encoder.w2v_model.mask_emb")
 
         # Delete AlignmentEncoder keys for inference.
         alignment_encoder_keys = [
             key for key in state_dict if key.startswith("decoder.alignment_encoder.")
         ]
-        for key in alignment_encoder_keys:
-            del state_dict[key]
+        keys_to_delete.extend(alignment_encoder_keys)
 
         # Delete character-level projection for inference.
-        for key in [
-            "decoder_target_letter_decoder.proj.weight",
-            "decoder_target_letter_decoder.proj.bias",
-        ]:
+        keys_to_delete.extend(
+            [
+                "decoder_target_letter_decoder.proj.weight",
+                "decoder_target_letter_decoder.proj.bias",
+            ]
+        )
+
+        for key in keys_to_delete:
             if key in state_dict:
                 del state_dict[key]
 
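The hunk above funnels every removal through a single guarded loop. Because each deletion now checks `if key in state_dict`, the upgrade step tolerates checkpoints that carry only a subset of these keys, for example S2T-only checkpoints that have no T2U parameters. A minimal sketch of the pattern, with an illustrative dict standing in for a real checkpoint:

# Minimal sketch of the collect-then-delete pattern; the keys and values
# here are illustrative stand-ins for a real fairseq checkpoint.
state_dict = {
    "decoder.version": 2,
    "decoder.embed_positions._float_tensor": 0.0,
    "decoder.layers.0.self_attn.q_proj.weight": "tensor...",
}

# Keys that may or may not be present, depending on the model variant.
keys_to_delete = [
    "decoder.version",
    "decoder.embed_positions._float_tensor",
    "text_encoder.version",  # absent here; the guard below skips it
]

for key in keys_to_delete:
    if key in state_dict:  # no KeyError when a variant lacks the key
        del state_dict[key]

assert list(state_dict) == ["decoder.layers.0.self_attn.q_proj.weight"]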
@@ -131,6 +138,7 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             r"^encoder\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.": r"speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
 
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.": r"speech_encoder.inner.layers.\1.conv.batch_norm.",
+            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.": r"speech_encoder.inner.layers.\1.conv.layer_norm.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.": r"speech_encoder.inner.layers.\1.conv.depthwise_conv.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.": r"speech_encoder.inner.layers.\1.conv_layer_norm.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv1.",
@@ -143,6 +151,11 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.": r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.": r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.": r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
+            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.": r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
+            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.": r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
+            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.": r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
+            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.": r"speech_encoder.inner.layers.\1.self_attn.sdpa.rel_k_embed.",
+            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.": r"speech_encoder.inner.layers.\1.self_attn.sdpa.r_proj.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u": r"speech_encoder.inner.layers.\1.self_attn.sdpa.u_bias",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v": r"speech_encoder.inner.layers.\1.self_attn.sdpa.v_bias",
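Each entry in this table maps a legacy fairseq parameter name to its fairseq2 counterpart; the `([0-9]+)` group carries the layer index through to the new name via `\1`. The loader applies the map internally, but the mechanics look roughly like the sketch below; `convert_key` is a hypothetical helper for illustration, not part of the actual API:

import re

# One entry from the table above; `\1` re-inserts the captured layer index.
key_map = {
    r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.": r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
}

def convert_key(key: str) -> str:
    # Try each pattern; the first one that matches rewrites the key.
    for pattern, replacement in key_map.items():
        new_key, num_subs = re.subn(pattern, replacement, key)
        if num_subs > 0:
            return new_key
    return key  # keys that match no pattern pass through unchanged

old_key = "encoder.w2v_encoder.w2v_model.encoder.layers.3.self_attn.q_proj.weight"
print(convert_key(old_key))
# -> speech_encoder.inner.layers.3.self_attn.q_proj.weight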
@@ -166,54 +179,6 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             r"^text_encoder\.layers\.([0-9]+)\.fc2\.": r"text_encoder.layers.\1.ffn.output_proj.",
             r"^text_encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"text_encoder.layers.\1.ffn_layer_norm.",
             r"^text_encoder\.layer_norm\.": r"text_encoder.layer_norm.",
-
-            # Text Decoder
-            r"^target_letter_decoder\.embed_tokens\.": r"text_decoder_frontend.embed.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"text_decoder.layers.\1.self_attn.output_proj.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn\.": r"text_decoder.layers.\1.self_attn.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"text_decoder.layers.\1.self_attn_layer_norm.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn\.": r"text_decoder.layers.\1.encoder_decoder_attn.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.fc1\.": r"text_decoder.layers.\1.ffn.inner_proj.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.fc2\.": r"text_decoder.layers.\1.ffn.output_proj.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"text_decoder.layers.\1.ffn_layer_norm.",
-            r"^target_letter_decoder\.layer_norm\.": r"text_decoder.layer_norm.",
-            r"^target_letter_decoder\.output_projection\.": r"final_proj.",
-
-            # T2U Encoder
-            r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"t2u_model.encoder.layers.\1.self_attn.output_proj.",
-            r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.": r"t2u_model.encoder.layers.\1.self_attn.",
-            r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"t2u_model.encoder.layers.\1.self_attn_layer_norm.",
-            r"^synthesizer_encoder\.layers\.([0-9]+)\.fc1\.": r"t2u_model.encoder.layers.\1.ffn.inner_proj.",
-            r"^synthesizer_encoder\.layers\.([0-9]+)\.fc2\.": r"t2u_model.encoder.layers.\1.ffn.output_proj.",
-            r"^synthesizer_encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"t2u_model.encoder.layers.\1.ffn_layer_norm.",
-            r"^synthesizer_encoder\.layer_norm\.": r"t2u_model.encoder.layer_norm.",
-
-            # T2U Decoder frontend
-            r"^decoder\.embed_tokens_text\.": r"t2u_model.decoder_frontend.embed_char.",
-            r"^decoder\.embed_tokens_unit\.": r"t2u_model.decoder_frontend.embed.",
-            r"^decoder\.embed_tokens\.": r"t2u_model.decoder_frontend.embed.",
-            r"^decoder\.var_adaptor\.duration_predictor\.": r"t2u_model.decoder_frontend.variance_adaptor.duration_predictor.",
-            r"^decoder\.dec_pos_emb_alpha": r"t2u_model.decoder_frontend.pos_emb_alpha",
-            r"^decoder\.dec_pos_emb_alpha_char": r"t2u_model.decoder_frontend.pos_emb_alpha_char",
-
-            # T2U Decoder
-            r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"t2u_model.decoder.layers.\1.self_attn.output_proj.",
-            r"^decoder\.layers\.([0-9]+)\.self_attn\.": r"t2u_model.decoder.layers.\1.self_attn.",
-            r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
-            r"^decoder\.layers\.([0-9]+)\.layer_norm\.": r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
-            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn.output_proj.",
-            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn.",
-            r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn_layer_norm.",
-            r"^decoder\.layers\.([0-9]+)\.fc1\.": r"t2u_model.decoder.layers.\1.ffn.inner_proj.",
-            r"^decoder\.layers\.([0-9]+)\.fc2\.": r"t2u_model.decoder.layers.\1.ffn.output_proj.",
-            r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"t2u_model.decoder.layers.\1.ffn_layer_norm.",
-            r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.0\.": r"t2u_model.decoder.layers.\1.conv1d.conv1.",
-            r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.2\.": r"t2u_model.decoder.layers.\1.conv1d.conv2.",
-            r"^decoder\.layers\.([0-9]+)\.ffn\.layer_norm\.": r"t2u_model.decoder.layers.\1.conv1d_layer_norm.",
-            r"^decoder\.layer_norm\.": r"t2u_model.decoder.layer_norm.",
-            r"^decoder\.output_projection\.": r"t2u_model.final_proj.",
             # fmt: on
         }
 
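None of the mappings removed above are lost: the next hunk re-adds them verbatim inside the `else` branch of a new `config.t2u_config` check, alongside an S2T-only variant in which the `decoder.` prefix maps to the text decoder instead of the T2U decoder.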
@@ -269,6 +234,79 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
                 r"^encoder\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.": r"speech_encoder.adaptor_layers.\1.ffn_layer_norm.",
             }
         )
+
+        # S2T model.
+        if config.t2u_config is None:
+            key_map.update(
+                {
+                    # Text Decoder
+                    r"^decoder\.embed_tokens\.": r"text_decoder_frontend.embed.",
+                    r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"text_decoder.layers.\1.self_attn.output_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.self_attn\.": r"text_decoder.layers.\1.self_attn.",
+                    r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"text_decoder.layers.\1.self_attn_layer_norm.",
+                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.": r"text_decoder.layers.\1.encoder_decoder_attn.",
+                    r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+                    r"^decoder\.layers\.([0-9]+)\.fc1\.": r"text_decoder.layers.\1.ffn.inner_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.fc2\.": r"text_decoder.layers.\1.ffn.output_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"text_decoder.layers.\1.ffn_layer_norm.",
+                    r"^decoder\.layer_norm\.": r"text_decoder.layer_norm.",
+                    r"^decoder\.output_projection\.": r"final_proj.",
+                }
+            )
+        # S2T + T2U model.
+        else:
+            key_map.update(
+                {
+                    # Text Decoder
+                    r"^target_letter_decoder\.embed_tokens\.": r"text_decoder_frontend.embed.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"text_decoder.layers.\1.self_attn.output_proj.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn\.": r"text_decoder.layers.\1.self_attn.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"text_decoder.layers.\1.self_attn_layer_norm.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn\.": r"text_decoder.layers.\1.encoder_decoder_attn.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.fc1\.": r"text_decoder.layers.\1.ffn.inner_proj.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.fc2\.": r"text_decoder.layers.\1.ffn.output_proj.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"text_decoder.layers.\1.ffn_layer_norm.",
+                    r"^target_letter_decoder\.layer_norm\.": r"text_decoder.layer_norm.",
+                    r"^target_letter_decoder\.output_projection\.": r"final_proj.",
+
+                    # T2U Encoder
+                    r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"t2u_model.encoder.layers.\1.self_attn.output_proj.",
+                    r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.": r"t2u_model.encoder.layers.\1.self_attn.",
+                    r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"t2u_model.encoder.layers.\1.self_attn_layer_norm.",
+                    r"^synthesizer_encoder\.layers\.([0-9]+)\.fc1\.": r"t2u_model.encoder.layers.\1.ffn.inner_proj.",
+                    r"^synthesizer_encoder\.layers\.([0-9]+)\.fc2\.": r"t2u_model.encoder.layers.\1.ffn.output_proj.",
+                    r"^synthesizer_encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"t2u_model.encoder.layers.\1.ffn_layer_norm.",
+                    r"^synthesizer_encoder\.layer_norm\.": r"t2u_model.encoder.layer_norm.",
+
+                    # T2U Decoder frontend
+                    r"^decoder\.embed_tokens_text\.": r"t2u_model.decoder_frontend.embed_char.",
+                    r"^decoder\.embed_tokens_unit\.": r"t2u_model.decoder_frontend.embed.",
+                    r"^decoder\.embed_tokens\.": r"t2u_model.decoder_frontend.embed.",
+                    r"^decoder\.var_adaptor\.duration_predictor\.": r"t2u_model.decoder_frontend.variance_adaptor.duration_predictor.",
+                    r"^decoder\.dec_pos_emb_alpha": r"t2u_model.decoder_frontend.pos_emb_alpha",
+                    r"^decoder\.dec_pos_emb_alpha_char": r"t2u_model.decoder_frontend.pos_emb_alpha_char",
+
+                    # T2U Decoder
+                    r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"t2u_model.decoder.layers.\1.self_attn.output_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.self_attn\.": r"t2u_model.decoder.layers.\1.self_attn.",
+                    r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
+                    r"^decoder\.layers\.([0-9]+)\.layer_norm\.": r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
+                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn.output_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn.",
+                    r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+                    r"^decoder\.layers\.([0-9]+)\.fc1\.": r"t2u_model.decoder.layers.\1.ffn.inner_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.fc2\.": r"t2u_model.decoder.layers.\1.ffn.output_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"t2u_model.decoder.layers.\1.ffn_layer_norm.",
+                    r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.0\.": r"t2u_model.decoder.layers.\1.conv1d.conv1.",
+                    r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.2\.": r"t2u_model.decoder.layers.\1.conv1d.conv2.",
+                    r"^decoder\.layers\.([0-9]+)\.ffn\.layer_norm\.": r"t2u_model.decoder.layers.\1.conv1d_layer_norm.",
+                    r"^decoder\.layer_norm\.": r"t2u_model.decoder.layer_norm.",
+                    r"^decoder\.output_projection\.": r"t2u_model.final_proj.",
+                }
+            )
         # fmt: on
 
         return key_map
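The key point in the new conditional is that the same source prefix resolves to different targets: in an S2T-only checkpoint, `decoder.` is the text decoder, while in an S2T + T2U checkpoint it is the T2U decoder and the text decoder lives under `target_letter_decoder.`. A condensed sketch of that dispatch (`build_key_map` is a hypothetical stand-alone function, and only two mappings per branch are shown):

from typing import Any, Dict, Optional

def build_key_map(t2u_config: Optional[Any]) -> Dict[str, str]:
    """Condensed sketch of the conditional remapping in the hunk above."""
    key_map: Dict[str, str] = {}  # shared speech-encoder mappings elided

    if t2u_config is None:
        # S2T model: "decoder." is the text decoder.
        key_map[r"^decoder\.embed_tokens\."] = r"text_decoder_frontend.embed."
        key_map[r"^decoder\.output_projection\."] = r"final_proj."
    else:
        # S2T + T2U model: the text decoder lives under
        # "target_letter_decoder.", and "decoder." is the T2U decoder.
        key_map[r"^target_letter_decoder\.embed_tokens\."] = r"text_decoder_frontend.embed."
        key_map[r"^decoder\.output_projection\."] = r"t2u_model.final_proj."

    return key_map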