@@ -4,14 +4,14 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, List, Mapping, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
-from fairseq2.assets import AssetStore, asset_store, download_manager
+from fairseq2.assets import AssetStore, asset_store
 from fairseq2.assets.card import AssetCard, AssetCardFieldNotFoundError
-from fairseq2.models.nllb import NllbConfig
-from fairseq2.models.nllb.loader import NllbTokenizerLoader
-from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models.nllb import NllbConfig, load_nllb_tokenizer
+from fairseq2.models import setup_model_family
+from fairseq2.data.text import register_text_tokenizer
 from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
 
 from seamless_communication.models.unity.builder import (
@@ -20,13 +20,13 @@ from seamless_communication.models.unity.builder import (
     unity_archs,
 )
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
-from seamless_communication.models.unity.model import UnitYModel
+from seamless_communication.models.unity.model import UNITY_FAMILY
 from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer
 
 
 def convert_unity_checkpoint(
-    checkpoint: Mapping[str, Any], config: UnitYConfig
-) -> Mapping[str, Any]:
+    checkpoint: Dict[str, Any], config: UnitYConfig
+) -> Dict[str, Any]:
     state_dict = checkpoint["model"]
 
     # Check if we have a fairseq2 checkpoint.
@@ -39,7 +39,11 @@ def convert_unity_checkpoint(
 
     state_dict = checkpoint["model"]
 
-    keys_to_delete = []
+    keys_to_delete = [
+        "speech_encoder_frontend.pos_encoder.conv.bias",
+        "speech_encoder_frontend.pos_encoder.conv.weight_g",
+        "speech_encoder_frontend.pos_encoder.conv.weight_v",
+    ]
 
     # ExpressiveUnitY model (from multi_arch codebase)
     if config.prosody_encoder_config is not None:
@@ -203,42 +207,42 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
         # fmt: off
 
         # Speech Encoder
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.pos_conv\.0\.": r"speech_encoder_frontend.pos_encoder.conv.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.layer_norm\.": r"speech_encoder_frontend.post_extract_layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.post_extract_proj\.": r"speech_encoder_frontend.model_dim_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.0\.": r"speech_encoder_frontend.feature_extractor.layers.\1.conv.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.": r"speech_encoder_frontend.feature_extractor.layers.\1.layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.": r"speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
-
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.": r"speech_encoder.inner.layers.\1.conv.batch_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.": r"speech_encoder.inner.layers.\1.conv.layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.": r"speech_encoder.inner.layers.\1.conv.depthwise_conv.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.": r"speech_encoder.inner.layers.\1.conv_layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv1.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv2.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.": r"speech_encoder.inner.layers.\1.ffn\2_layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.": r"speech_encoder.inner.layers.\1.ffn\2.inner_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.": r"speech_encoder.inner.layers.\1.ffn\2.output_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.inner.layers.\1.self_attn_layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.": r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.": r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.": r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.": r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.": r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.": r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.": r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.": r"speech_encoder.inner.layers.\1.self_attn.sdpa.rel_k_embed.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.": r"speech_encoder.inner.layers.\1.self_attn.sdpa.r_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u": r"speech_encoder.inner.layers.\1.self_attn.sdpa.u_bias",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v": r"speech_encoder.inner.layers.\1.self_attn.sdpa.v_bias",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"speech_encoder.inner.layers.\1.layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.pos_conv\.0\.": "speech_encoder_frontend.pos_encoder.conv.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.layer_norm\.": "speech_encoder_frontend.post_extract_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.post_extract_proj\.": "speech_encoder_frontend.model_dim_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.0\.": "speech_encoder_frontend.feature_extractor.layers.\\1.conv.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.": "speech_encoder_frontend.feature_extractor.layers.\\1.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.": "speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
+
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.": "speech_encoder.inner.layers.\\1.conv.batch_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.": "speech_encoder.inner.layers.\\1.conv.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.": "speech_encoder.inner.layers.\\1.conv.depthwise_conv.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.": "speech_encoder.inner.layers.\\1.conv_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": "speech_encoder.inner.layers.\\1.conv.pointwise_conv1.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": "speech_encoder.inner.layers.\\1.conv.pointwise_conv2.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.": "speech_encoder.inner.layers.\\1.ffn\\2_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.": "speech_encoder.inner.layers.\\1.ffn\\2.inner_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.": "speech_encoder.inner.layers.\\1.ffn\\2.output_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": "speech_encoder.inner.layers.\\1.self_attn_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.": "speech_encoder.inner.layers.\\1.self_attn.q_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.": "speech_encoder.inner.layers.\\1.self_attn.k_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.": "speech_encoder.inner.layers.\\1.self_attn.v_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.": "speech_encoder.inner.layers.\\1.self_attn.output_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.": "speech_encoder.inner.layers.\\1.self_attn.q_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.": "speech_encoder.inner.layers.\\1.self_attn.k_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.": "speech_encoder.inner.layers.\\1.self_attn.v_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.": "speech_encoder.inner.layers.\\1.self_attn.sdpa.rel_k_embed.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": "speech_encoder.inner.layers.\\1.self_attn.output_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.": "speech_encoder.inner.layers.\\1.self_attn.sdpa.r_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u": "speech_encoder.inner.layers.\\1.self_attn.sdpa.u_bias",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v": "speech_encoder.inner.layers.\\1.self_attn.sdpa.v_bias",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.final_layer_norm\.": "speech_encoder.inner.layers.\\1.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": "speech_encoder.inner.layer_norm.",
 
         # Speech Encoder Adaptor
-        fr"^{encoder_key}\.adaptor\.proj\.0\.": r"speech_encoder.proj1.",
-        fr"^{encoder_key}\.adaptor\.proj\.2\.": r"speech_encoder.proj2.",
-        fr"^{encoder_key}\.adaptor\.out_ln\.": r"speech_encoder.layer_norm.",
+        fr"^{encoder_key}\.adaptor\.proj\.0\.": "speech_encoder.proj1.",
+        fr"^{encoder_key}\.adaptor\.proj\.2\.": "speech_encoder.proj2.",
+        fr"^{encoder_key}\.adaptor\.out_ln\.": "speech_encoder.layer_norm.",
 
         # Text Encoder
         r"^text_encoder\.embed_tokens\.": r"text_encoder_frontend.embed.",
@@ -264,13 +268,13 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
     if config.w2v2_encoder_config.use_conformer:
         key_map.update(
             {
-                fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner_layer_norm."
+                fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": "speech_encoder.inner_layer_norm."
             }
         )
     else:
         key_map.update(
             {
-                rf"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm."
+                rf"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": "speech_encoder.inner.layer_norm."
             }
         )
     # fmt: on
@@ -279,20 +283,20 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
         key_map.update(
             {
                 # fmt: off
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"speech_encoder.adaptor_layers.\1.block.self_attn.output_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.": r"speech_encoder.adaptor_layers.\1.block.self_attn.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.adaptor_layers.\1.block.self_attn_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.": r"speech_encoder.adaptor_layers.\1.block.ffn\2_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.": r"speech_encoder.adaptor_layers.\1.block.ffn\2.inner_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.": r"speech_encoder.adaptor_layers.\1.block.ffn\2.output_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.batch_norm\.": r"speech_encoder.adaptor_layers.\1.block.conv.batch_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.": r"speech_encoder.adaptor_layers.\1.block.conv.depthwise_conv.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.layer_norm\.": r"speech_encoder.adaptor_layers.\1.block.conv_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv1.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv2.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.": r"speech_encoder.adaptor_layers.\1.block.layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_ln\.": r"speech_encoder.adaptor_layers.\1.layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_pool\.1\.": r"speech_encoder.adaptor_layers.\1.conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.": "speech_encoder.adaptor_layers.\\1.block.self_attn.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.": "speech_encoder.adaptor_layers.\\1.block.self_attn.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": "speech_encoder.adaptor_layers.\\1.block.self_attn_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.": "speech_encoder.adaptor_layers.\\1.block.ffn\\2_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.": "speech_encoder.adaptor_layers.\\1.block.ffn\\2.inner_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.": "speech_encoder.adaptor_layers.\\1.block.ffn\\2.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.batch_norm\.": "speech_encoder.adaptor_layers.\\1.block.conv.batch_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.": "speech_encoder.adaptor_layers.\\1.block.conv.depthwise_conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.layer_norm\.": "speech_encoder.adaptor_layers.\\1.block.conv_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": "speech_encoder.adaptor_layers.\\1.block.conv.pointwise_conv1.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": "speech_encoder.adaptor_layers.\\1.block.conv.pointwise_conv2.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.": "speech_encoder.adaptor_layers.\\1.block.layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_ln\.": "speech_encoder.adaptor_layers.\\1.layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_pool\.1\.": "speech_encoder.adaptor_layers.\\1.conv.",
                 # fmt: on
             }
         )
@@ -300,15 +304,15 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
         key_map.update(
             {
                 # fmt: off
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_layer_norm\.": r"speech_encoder.adaptor_layers.\1.residual_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_pool\.1\.": r"speech_encoder.adaptor_layers.\1.residual_conv.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.attn_pool\.1\.": r"speech_encoder.adaptor_layers.\1.self_attn_conv.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"speech_encoder.adaptor_layers.\1.self_attn.output_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.": r"speech_encoder.adaptor_layers.\1.self_attn.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.adaptor_layers.\1.self_attn_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc1\.": r"speech_encoder.adaptor_layers.\1.ffn.inner_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc2\.": r"speech_encoder.adaptor_layers.\1.ffn.output_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.": r"speech_encoder.adaptor_layers.\1.ffn_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_layer_norm\.": "speech_encoder.adaptor_layers.\\1.residual_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_pool\.1\.": "speech_encoder.adaptor_layers.\\1.residual_conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.attn_pool\.1\.": "speech_encoder.adaptor_layers.\\1.self_attn_conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.": "speech_encoder.adaptor_layers.\\1.self_attn.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.": "speech_encoder.adaptor_layers.\\1.self_attn.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": "speech_encoder.adaptor_layers.\\1.self_attn_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc1\.": "speech_encoder.adaptor_layers.\\1.ffn.inner_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc2\.": "speech_encoder.adaptor_layers.\\1.ffn.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.": "speech_encoder.adaptor_layers.\\1.ffn_layer_norm.",
                 # fmt: on
             }
         )
@@ -389,20 +393,18 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
     return key_map
 
 
-load_unity_config = ConfigLoader[UnitYConfig](asset_store, unity_archs)
-
-
-load_unity_model = ModelLoader[UnitYModel, UnitYConfig](
-    asset_store,
-    download_manager,
-    load_unity_config,
+load_unity_model, load_unity_config = setup_model_family(
+    UNITY_FAMILY,
+    UnitYConfig,
     create_unity_model,
+    unity_archs,
     convert_unity_checkpoint,
     restrict_checkpoints=False,
 )
 
 
-load_unity_text_tokenizer = NllbTokenizerLoader(asset_store, download_manager)
+load_unity_text_tokenizer = load_nllb_tokenizer
+register_text_tokenizer(UNITY_FAMILY, load_unity_text_tokenizer)
 
 
 class UnitYUnitTokenizerLoader:
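
A minimal usage sketch of the migrated entry points (not part of the patch above; it assumes the loader pair returned by setup_model_family keeps the old ModelLoader-style call signature, and "seamlessM4T_v2_large" stands in for any UnitY asset card):

    import torch

    from seamless_communication.models.unity import (
        load_unity_config,
        load_unity_model,
    )

    # Resolve the asset card into a UnitYConfig, then materialize the model.
    config = load_unity_config("seamlessM4T_v2_large")
    model = load_unity_model(
        "seamlessM4T_v2_large", device=torch.device("cpu"), dtype=torch.float32
    )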