@@ -48,20 +48,29 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):

         keys_to_delete = []

-        # Use the built-in version attribute of `torch.Module`.
-        if config.t2u_config is None:
-            keys_to_delete.append("decoder.version")
-            keys_to_delete.append("decoder.embed_positions._float_tensor")
+        # X2T/S2T + T2U model.
+        if config.t2u_config is not None:
+            encoder_key = "encoder"
+            decoder_key = "target_letter_decoder"
+        # X2T model.
+        elif config.use_text_encoder:
+            encoder_key = "speech_encoder"
+            decoder_key = "shared_decoder"
+        # S2T model.
         else:
-            keys_to_delete.append("target_letter_decoder.version")
-            keys_to_delete.append("target_letter_decoder.embed_positions._float_tensor")
+            encoder_key = "encoder"
+            decoder_key = "decoder"
+
+        # Use the built-in version attribute of `torch.nn.Module`.
+        keys_to_delete.append(f"{decoder_key}.version")
+        keys_to_delete.append(f"{decoder_key}.embed_positions._float_tensor")

         if config.use_text_encoder:
             keys_to_delete.append("text_encoder.version")
             keys_to_delete.append("text_encoder.embed_positions._float_tensor")

         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
-        keys_to_delete.append("encoder.w2v_encoder.w2v_model.mask_emb")
+        keys_to_delete.append(f"{encoder_key}.w2v_encoder.w2v_model.mask_emb")

         # Delete AlignmentEncoder keys for inference.
         alignment_encoder_keys = [
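The hunk above collapses two near-duplicate pruning branches into a single `encoder_key`/`decoder_key` selection: fairseq checkpoints store the speech encoder and text decoder under different top-level module names depending on whether the model is X2T/S2T + T2U (`encoder`/`target_letter_decoder`), X2T (`speech_encoder`/`shared_decoder`), or plain S2T (`encoder`/`decoder`). Below is a minimal sketch of the resulting pruning step, assuming a flat fairseq-style state dict; `prune_checkpoint` and its arguments are illustrative stand-ins, not part of the actual loader.

```python
from typing import Any, Dict


def prune_checkpoint(
    state_dict: Dict[str, Any], encoder_key: str, decoder_key: str
) -> None:
    """Drop keys the converted model does not need (hypothetical helper)."""
    keys_to_delete = [
        # Superseded by the built-in version attribute of `torch.nn.Module`.
        f"{decoder_key}.version",
        f"{decoder_key}.embed_positions._float_tensor",
        # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
        f"{encoder_key}.w2v_encoder.w2v_model.mask_emb",
    ]
    for key in keys_to_delete:
        # Not every checkpoint variant contains every key.
        state_dict.pop(key, None)
```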
@@ -126,46 +135,59 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):

     @staticmethod
     def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
+        # X2T/S2T + T2U model.
+        if config.t2u_config is not None:
+            encoder_key = "encoder"
+            decoder_key = "target_letter_decoder"
+        # X2T model.
+        elif config.use_text_encoder:
+            encoder_key = "speech_encoder"
+            decoder_key = "shared_decoder"
+        # S2T model.
+        else:
+            encoder_key = "encoder"
+            decoder_key = "decoder"
+
         key_map = {
             # fmt: off

             # Speech Encoder
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.pos_conv\.0\.": r"speech_encoder_frontend.pos_encoder.conv.",
-            r"^encoder\.w2v_encoder\.w2v_model\.layer_norm\.": r"speech_encoder_frontend.post_extract_layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.post_extract_proj\.": r"speech_encoder_frontend.model_dim_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.0\.": r"speech_encoder_frontend.feature_extractor.layers.\1.conv.",
-            r"^encoder\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.": r"speech_encoder_frontend.feature_extractor.layers.\1.layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.": r"speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
-
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.": r"speech_encoder.inner.layers.\1.conv.batch_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.": r"speech_encoder.inner.layers.\1.conv.layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.": r"speech_encoder.inner.layers.\1.conv.depthwise_conv.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.": r"speech_encoder.inner.layers.\1.conv_layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv1.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv2.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.": r"speech_encoder.inner.layers.\1.ffn\2_layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.": r"speech_encoder.inner.layers.\1.ffn\2.inner_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.": r"speech_encoder.inner.layers.\1.ffn\2.output_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.inner.layers.\1.self_attn_layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.": r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.": r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.": r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.": r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.": r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.": r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.": r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.": r"speech_encoder.inner.layers.\1.self_attn.sdpa.rel_k_embed.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.": r"speech_encoder.inner.layers.\1.self_attn.sdpa.r_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u": r"speech_encoder.inner.layers.\1.self_attn.sdpa.u_bias",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v": r"speech_encoder.inner.layers.\1.self_attn.sdpa.v_bias",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"speech_encoder.inner.layers.\1.layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.pos_conv\.0\.": r"speech_encoder_frontend.pos_encoder.conv.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.layer_norm\.": r"speech_encoder_frontend.post_extract_layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.post_extract_proj\.": r"speech_encoder_frontend.model_dim_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.0\.": r"speech_encoder_frontend.feature_extractor.layers.\1.conv.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.": r"speech_encoder_frontend.feature_extractor.layers.\1.layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.": r"speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
+
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.": r"speech_encoder.inner.layers.\1.conv.batch_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.": r"speech_encoder.inner.layers.\1.conv.layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.": r"speech_encoder.inner.layers.\1.conv.depthwise_conv.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.": r"speech_encoder.inner.layers.\1.conv_layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv1.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv2.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.": r"speech_encoder.inner.layers.\1.ffn\2_layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.": r"speech_encoder.inner.layers.\1.ffn\2.inner_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.": r"speech_encoder.inner.layers.\1.ffn\2.output_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.inner.layers.\1.self_attn_layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.": r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.": r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.": r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.": r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.": r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.": r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.": r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.": r"speech_encoder.inner.layers.\1.self_attn.sdpa.rel_k_embed.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.": r"speech_encoder.inner.layers.\1.self_attn.sdpa.r_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u": r"speech_encoder.inner.layers.\1.self_attn.sdpa.u_bias",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v": r"speech_encoder.inner.layers.\1.self_attn.sdpa.v_bias",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"speech_encoder.inner.layers.\1.layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm.",

             # Speech Encoder Adaptor
-            r"^encoder\.adaptor\.proj\.0\.": r"speech_encoder.proj1.",
-            r"^encoder\.adaptor\.proj\.2\.": r"speech_encoder.proj2.",
-            r"^encoder\.adaptor\.out_ln\.": r"speech_encoder.layer_norm.",
+            fr"^{encoder_key}\.adaptor\.proj\.0\.": r"speech_encoder.proj1.",
+            fr"^{encoder_key}\.adaptor\.proj\.2\.": r"speech_encoder.proj2.",
+            fr"^{encoder_key}\.adaptor\.out_ln\.": r"speech_encoder.layer_norm.",

             # Text Encoder
             r"^text_encoder\.embed_tokens\.": r"text_encoder_frontend.embed.",
@@ -188,90 +210,82 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         # a redundant `LayerNorm` right after the Conformer blocks. We mitigate
         # that issue here by moving that `LayerNorm` to the adaptor block.
         if config.w2v2_encoder_config.use_conformer:
+            # fmt: off
             key_map.update(
                 {
-                    r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner_layer_norm."
+                    fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner_layer_norm."
                 }
             )
         else:
             key_map.update(
                 {
-                    r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm."
+                    fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm."
                 }
             )
+        # fmt: on

-        # fmt: off
         if config.use_conformer_adaptor:
             key_map.update(
                 {
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"speech_encoder.adaptor_layers.\1.block.self_attn.output_proj.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.self_attn\.": r"speech_encoder.adaptor_layers.\1.block.self_attn.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.adaptor_layers.\1.block.self_attn_layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.": r"speech_encoder.adaptor_layers.\1.block.ffn\2_layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.": r"speech_encoder.adaptor_layers.\1.block.ffn\2.inner_proj.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.": r"speech_encoder.adaptor_layers.\1.block.ffn\2.output_proj.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_module\.batch_norm\.": r"speech_encoder.adaptor_layers.\1.block.conv.batch_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.": r"speech_encoder.adaptor_layers.\1.block.conv.depthwise_conv.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_module\.layer_norm\.": r"speech_encoder.adaptor_layers.\1.block.conv_layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv1.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv2.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.": r"speech_encoder.adaptor_layers.\1.block.layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_ln\.": r"speech_encoder.adaptor_layers.\1.layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_pool\.1\.": r"speech_encoder.adaptor_layers.\1.conv.",
+                    # fmt: off
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"speech_encoder.adaptor_layers.\1.block.self_attn.output_proj.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.": r"speech_encoder.adaptor_layers.\1.block.self_attn.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.adaptor_layers.\1.block.self_attn_layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.": r"speech_encoder.adaptor_layers.\1.block.ffn\2_layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.": r"speech_encoder.adaptor_layers.\1.block.ffn\2.inner_proj.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.": r"speech_encoder.adaptor_layers.\1.block.ffn\2.output_proj.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.batch_norm\.": r"speech_encoder.adaptor_layers.\1.block.conv.batch_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.": r"speech_encoder.adaptor_layers.\1.block.conv.depthwise_conv.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.layer_norm\.": r"speech_encoder.adaptor_layers.\1.block.conv_layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv1.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv2.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.": r"speech_encoder.adaptor_layers.\1.block.layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_ln\.": r"speech_encoder.adaptor_layers.\1.layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_pool\.1\.": r"speech_encoder.adaptor_layers.\1.conv.",
+                    # fmt: on
                 }
             )
         else:
             key_map.update(
                 {
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.residual_layer_norm\.": r"speech_encoder.adaptor_layers.\1.residual_layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.residual_pool\.1\.": r"speech_encoder.adaptor_layers.\1.residual_conv.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.attn_pool\.1\.": r"speech_encoder.adaptor_layers.\1.self_attn_conv.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"speech_encoder.adaptor_layers.\1.self_attn.output_proj.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.self_attn\.": r"speech_encoder.adaptor_layers.\1.self_attn.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.adaptor_layers.\1.self_attn_layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.fc1\.": r"speech_encoder.adaptor_layers.\1.ffn.inner_proj.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.fc2\.": r"speech_encoder.adaptor_layers.\1.ffn.output_proj.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.": r"speech_encoder.adaptor_layers.\1.ffn_layer_norm.",
+                    # fmt: off
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_layer_norm\.": r"speech_encoder.adaptor_layers.\1.residual_layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_pool\.1\.": r"speech_encoder.adaptor_layers.\1.residual_conv.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.attn_pool\.1\.": r"speech_encoder.adaptor_layers.\1.self_attn_conv.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"speech_encoder.adaptor_layers.\1.self_attn.output_proj.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.": r"speech_encoder.adaptor_layers.\1.self_attn.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.adaptor_layers.\1.self_attn_layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc1\.": r"speech_encoder.adaptor_layers.\1.ffn.inner_proj.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc2\.": r"speech_encoder.adaptor_layers.\1.ffn.output_proj.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.": r"speech_encoder.adaptor_layers.\1.ffn_layer_norm.",
+                    # fmt: on
                 }
             )

-        # S2T model.
-        if config.t2u_config is None:
-            key_map.update(
-                {
-                    # Text Decoder
-                    r"^decoder\.embed_tokens\.": r"text_decoder_frontend.embed.",
-                    r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"text_decoder.layers.\1.self_attn.output_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.self_attn\.": r"text_decoder.layers.\1.self_attn.",
-                    r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"text_decoder.layers.\1.self_attn_layer_norm.",
-                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.": r"text_decoder.layers.\1.encoder_decoder_attn.",
-                    r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
-                    r"^decoder\.layers\.([0-9]+)\.fc1\.": r"text_decoder.layers.\1.ffn.inner_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.fc2\.": r"text_decoder.layers.\1.ffn.output_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"text_decoder.layers.\1.ffn_layer_norm.",
-                    r"^decoder\.layer_norm\.": r"text_decoder.layer_norm.",
-                    r"^decoder\.output_projection\.": r"final_proj.",
-                }
-            )
-        # S2T + T2U model.
-        else:
+        key_map.update(
+            {
+                # fmt: off
+                # Text Decoder
+                fr"^{decoder_key}\.embed_tokens\.": r"text_decoder_frontend.embed.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"text_decoder.layers.\1.self_attn.output_proj.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.self_attn\.": r"text_decoder.layers.\1.self_attn.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"text_decoder.layers.\1.self_attn_layer_norm.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.encoder_attn\.": r"text_decoder.layers.\1.encoder_decoder_attn.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.fc1\.": r"text_decoder.layers.\1.ffn.inner_proj.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.fc2\.": r"text_decoder.layers.\1.ffn.output_proj.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.final_layer_norm\.": r"text_decoder.layers.\1.ffn_layer_norm.",
+                fr"^{decoder_key}\.layer_norm\.": r"text_decoder.layer_norm.",
+                fr"^{decoder_key}\.output_projection\.": r"final_proj.",
+                # fmt: on
+            }
+        )
+        # X2T/S2T + T2U model.
+        if config.t2u_config is not None:
             key_map.update(
                 {
-                    # Text Decoder
-                    r"^target_letter_decoder\.embed_tokens\.": r"text_decoder_frontend.embed.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"text_decoder.layers.\1.self_attn.output_proj.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn\.": r"text_decoder.layers.\1.self_attn.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"text_decoder.layers.\1.self_attn_layer_norm.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn\.": r"text_decoder.layers.\1.encoder_decoder_attn.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.fc1\.": r"text_decoder.layers.\1.ffn.inner_proj.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.fc2\.": r"text_decoder.layers.\1.ffn.output_proj.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"text_decoder.layers.\1.ffn_layer_norm.",
-                    r"^target_letter_decoder\.layer_norm\.": r"text_decoder.layer_norm.",
-                    r"^target_letter_decoder\.output_projection\.": r"final_proj.",
-
+                    # fmt: off
                     # T2U Encoder
                     r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"t2u_model.encoder.layers.\1.self_attn.output_proj.",
                     r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.": r"t2u_model.encoder.layers.\1.self_attn.",
@@ -305,9 +319,9 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
                     r"^decoder\.layers\.([0-9]+)\.ffn\.layer_norm\.": r"t2u_model.decoder.layers.\1.conv1d_layer_norm.",
                     r"^decoder\.layer_norm\.": r"t2u_model.decoder.layer_norm.",
                     r"^decoder\.output_projection\.": r"t2u_model.final_proj.",
+                    # fmt: on
                 }
             )
-        # fmt: on

         return key_map
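A key map like the one returned here is applied pattern-by-pattern to every flat checkpoint key, and insertion order doubles as match priority: the specific `self_attn\.out_proj\.` entries are registered before the catch-all `self_attn\.` entries so that the longer match wins. The sketch below shows one way such a renamer could look; the real conversion happens inside fairseq2's `ModelLoader` machinery, and `rename_keys` is an assumed helper, not its actual implementation.

```python
import re
from typing import Any, Dict


def rename_keys(
    state_dict: Dict[str, Any], key_map: Dict[str, str]
) -> Dict[str, Any]:
    """Rename checkpoint keys by the first matching pattern (hypothetical helper)."""
    renamed: Dict[str, Any] = {}
    for old_key, value in state_dict.items():
        new_key = old_key
        for pattern, repl in key_map.items():
            new_key, n_subs = re.subn(pattern, repl, old_key)
            if n_subs > 0:
                break  # First match wins; dict insertion order is significant.
        renamed[new_key] = value
    return renamed
```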