Move to generic loaders (#111)

Can Balioglu, 1 year ago
parent commit 2393016090
27 changed files with 562 additions and 608 deletions
  1. +1 -1    src/seamless_communication/cards/monotonic_decoder.yaml
  2. +1 -1    src/seamless_communication/cards/pretssel_v1.yaml
  3. +2 -2    src/seamless_communication/cards/seamlessM4T_v2_large.yaml
  4. +2 -2    src/seamless_communication/cards/seamless_expressivity.yaml
  5. +1 -1    src/seamless_communication/cards/seamless_streaming_monotonic_decoder.yaml
  6. +2 -2    src/seamless_communication/cards/seamless_streaming_unity.yaml
  7. +1 -1    src/seamless_communication/cards/unity_sans_decoder.yaml
  8. +1 -1    src/seamless_communication/cards/vocoder_mel.yaml
  9. +1 -1    src/seamless_communication/cards/vocoder_v2.yaml
  10. +5 -8   src/seamless_communication/models/monotonic_decoder/__init__.py
  11. +2 -5   src/seamless_communication/models/monotonic_decoder/builder.py
  12. +66 -80 src/seamless_communication/models/monotonic_decoder/loader.py
  13. +5 -7   src/seamless_communication/models/monotonic_decoder/model.py
  14. +4 -4   src/seamless_communication/models/monotonic_decoder/p_choose.py
  15. +1 -1   src/seamless_communication/models/pretssel/builder.py
  16. +1 -1   src/seamless_communication/models/pretssel/ecapa_tdnn_builder.py
  17. +79 -81 src/seamless_communication/models/pretssel/loader.py
  18. +0 -1   src/seamless_communication/models/unity/__init__.py
  19. +1 -2   src/seamless_communication/models/unity/builder.py
  20. +0 -1   src/seamless_communication/models/unity/length_regulator.py
  21. +347 -358 src/seamless_communication/models/unity/loader.py
  22. +1 -1   src/seamless_communication/models/unity/t2u_builder.py
  23. +0 -4   src/seamless_communication/models/vocoder/__init__.py
  24. +4 -4   src/seamless_communication/models/vocoder/builder.py
  25. +31 -35 src/seamless_communication/models/vocoder/loader.py
  26. +2 -2   src/seamless_communication/models/vocoder/vocoder.py
  27. +1 -1   src/seamless_communication/models/wav2vec2_chunk/builder.py

+ 1 - 1
src/seamless_communication/cards/monotonic_decoder.yaml

@@ -7,4 +7,4 @@
 name: monotonic_decoder
 model_type: monotonic_decoder
 model_arch: dense_1b
-checkpoint: "file://large_experiments/seamless/ust/krs/fairseq2_checkpoints/monotonic_decoder.pt"
+checkpoint: "file:///large_experiments/seamless/ust/krs/fairseq2_checkpoints/monotonic_decoder.pt"

+ 1 - 1
src/seamless_communication/cards/pretssel_v1.yaml

@@ -7,7 +7,7 @@
 name: pretssel_v1
 model_type: pretssel
 model_arch: base
-checkpoint: "file://checkpoint/mjhwang/experiments/230930-noiseaug_p2v-mls_multilingual_6lang/231005-noiseaug_p2v-mls_multilingual_6lang-alignfix.config_v2.langemb1.vuv_logit1.denoise.ngpu16/checkpoint_best.pt"
+checkpoint: "file:///checkpoint/mjhwang/experiments/230930-noiseaug_p2v-mls_multilingual_6lang/231005-noiseaug_p2v-mls_multilingual_6lang-alignfix.config_v2.langemb1.vuv_logit1.denoise.ngpu16/checkpoint_best.pt"
 num_units: 10000
 languages:
   - cmn

+ 2 - 2
src/seamless_communication/cards/seamlessM4T_v2_large.yaml

@@ -7,8 +7,8 @@
 name: seamlessM4T_v2_large
 base: unity_nllb-100
 model_arch: base_v2
-char_tokenizer: "file://checkpoint/krs/unity2/spm_char_lang38_tc.model"
-checkpoint: "file://large_experiments/seamless/ust/krs/fairseq2_checkpoints/seamlessM4T_v2_large.pt"
+char_tokenizer: "file:///checkpoint/krs/unity2/spm_char_lang38_tc.model"
+checkpoint: "file:///large_experiments/seamless/ust/krs/fairseq2_checkpoints/seamlessM4T_v2_large.pt"
 num_units: 10000
 unit_langs:
   - arb

+ 2 - 2
src/seamless_communication/cards/seamless_expressivity.yaml

@@ -7,8 +7,8 @@
 name: seamless_expressivity
 base: unity_nllb-100
 model_arch: expressivity_v2
-char_tokenizer: "file://checkpoint/krs/unity2/spm_char_lang38_tc.model"
-checkpoint: "file://checkpoint/hygong/Expressivity/multilingual_models/m2m.clean.ecapa_tdnn2.dim512.all.all.lr5e-05.mk4k.config_t2_fbank_nosa_gcmvn_10k.rdrop0.ls0.2.uf3.wu5k.fp16.mem_fp16.seed1.dr0.1.ld0.2.mp0.3.cmp0.25.ma.ak8.as8.al1.ald0.0.dld0.0.ca.D24L.t2uE4L.t2uD4L.usesfilm.inj_dec.ngpu64/checkpoint_best_export.pt"
+char_tokenizer: "file:///checkpoint/krs/unity2/spm_char_lang38_tc.model"
+checkpoint: "file:///checkpoint/hygong/Expressivity/multilingual_models/m2m.clean.ecapa_tdnn2.dim512.all.all.lr5e-05.mk4k.config_t2_fbank_nosa_gcmvn_10k.rdrop0.ls0.2.uf3.wu5k.fp16.mem_fp16.seed1.dr0.1.ld0.2.mp0.3.cmp0.25.ma.ak8.as8.al1.ald0.0.dld0.0.ca.D24L.t2uE4L.t2uD4L.usesfilm.inj_dec.ngpu64/checkpoint_best_export.pt"
 num_units: 10000
 unit_langs:
   - arb

+ 1 - 1
src/seamless_communication/cards/seamless_streaming_monotonic_decoder.yaml

@@ -7,4 +7,4 @@
 name: seamless_streaming_monotonic_decoder
 model_type: monotonic_decoder
 model_arch: dense_1b
-checkpoint: "file://large_experiments/seamless/ust/krs/fairseq2_checkpoints/seamless_streaming_monotonic_decoder.pt"
+checkpoint: "file:///large_experiments/seamless/ust/krs/fairseq2_checkpoints/seamless_streaming_monotonic_decoder.pt"

+ 2 - 2
src/seamless_communication/cards/seamless_streaming_unity.yaml

@@ -7,8 +7,8 @@
 name: seamless_streaming_unity
 base: unity_nllb-100
 model_arch: base_v2
-char_tokenizer: "file://checkpoint/krs/unity2/spm_char_lang38_tc.model"
-checkpoint: "file://large_experiments/seamless/ust/krs/fairseq2_checkpoints/seamless_streaming_unity.pt"
+char_tokenizer: "file:///checkpoint/krs/unity2/spm_char_lang38_tc.model"
+checkpoint: "file:///large_experiments/seamless/ust/krs/fairseq2_checkpoints/seamless_streaming_unity.pt"
 num_units: 10000
 unit_langs:
   - arb

+ 1 - 1
src/seamless_communication/cards/unity_sans_decoder.yaml

@@ -7,4 +7,4 @@
 name: unity_sans_decoder
 base: unity_nllb-100
 model_arch: base_v2
-checkpoint: "file://large_experiments/seamless/ust/krs/fairseq2_checkpoints/unity_sans_decoder.pt"
+checkpoint: "file:///large_experiments/seamless/ust/krs/fairseq2_checkpoints/unity_sans_decoder.pt"

+ 1 - 1
src/seamless_communication/cards/vocoder_mel.yaml

@@ -7,4 +7,4 @@
 name: vocoder_mel
 model_type: vocoder_mel_hifigan
 model_arch: base_mel
-checkpoint: "file://large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt"
+checkpoint: "file:///large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt"

+ 1 - 1
src/seamless_communication/cards/vocoder_v2.yaml

@@ -7,7 +7,7 @@
 name: vocoder_v2
 model_type: vocoder_code_hifigan
 model_arch: base
-checkpoint: "file://large_experiments/seamless/ust/krs/fairseq2_checkpoints/vocoder_v2.pt"
+checkpoint: "file:///large_experiments/seamless/ust/krs/fairseq2_checkpoints/vocoder_v2.pt"
 model_config: {
   "lang_spkr_idx_map": {
       "multilingual": {

+ 5 - 8
src/seamless_communication/models/monotonic_decoder/__init__.py

@@ -11,17 +11,14 @@ from seamless_communication.models.monotonic_decoder.builder import (
     MonotonicDecoderConfig as MonotonicDecoderConfig,
 )
 from seamless_communication.models.monotonic_decoder.builder import (
-    monotonic_decoder_archs as monotonic_decoder_archs,
+    create_monotonic_decoder_model as create_monotonic_decoder_model,
 )
-from seamless_communication.models.monotonic_decoder.loader import (
-    load_monotonic_decoder_model as load_monotonic_decoder_model,
+from seamless_communication.models.monotonic_decoder.builder import (
+    monotonic_decoder_archs as monotonic_decoder_archs,
 )
 from seamless_communication.models.monotonic_decoder.loader import (
     load_monotonic_decoder_config as load_monotonic_decoder_config,
 )
-from seamless_communication.models.monotonic_decoder.builder import (
-    create_monotonic_decoder_model as create_monotonic_decoder_model,
-)
-from seamless_communication.models.monotonic_decoder.builder import (
-    monotonic_decoder_archs as monotonic_decoder_archs,
+from seamless_communication.models.monotonic_decoder.loader import (
+    load_monotonic_decoder_model as load_monotonic_decoder_model,
 )

+ 2 - 5
src/seamless_communication/models/monotonic_decoder/builder.py

@@ -26,9 +26,6 @@ from fairseq2.nn.transformer import (
 )
 from fairseq2.typing import DataType, Device
 
-from seamless_communication.models.monotonic_decoder.p_choose import (
-    PChooseLayer,
-)
 from seamless_communication.models.monotonic_decoder.model import MonotonicDecoderModel
 from seamless_communication.models.monotonic_decoder.monotonic_decoder import (
     MonotonicTransformerDecoder,
@@ -36,6 +33,7 @@ from seamless_communication.models.monotonic_decoder.monotonic_decoder import (
 from seamless_communication.models.monotonic_decoder.monotonic_decoder_layer import (
     MonotonicTransformerDecoderLayer,
 )
+from seamless_communication.models.monotonic_decoder.p_choose import PChooseLayer
 
 
 @dataclass
@@ -83,12 +81,11 @@ monotonic_decoder_archs = ArchitectureRegistry[MonotonicDecoderConfig](
     "monotonic_decoder"
 )
 
-monotonic_decoder_arch = monotonic_decoder_archs.marker
+monotonic_decoder_arch = monotonic_decoder_archs.decorator
 
 
 @monotonic_decoder_arch("dense_1b")
 def _dense_1b() -> MonotonicDecoderConfig:
-
     return MonotonicDecoderConfig(
         model_dim=1024,
         max_seq_len=4096,
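
The `marker` -> `decorator` rename here (and in the pretssel, ecapa_tdnn, and unity builders below) tracks a fairseq2 API change: the registry member that hands out the architecture-registration decorator is now named `decorator`. A minimal sketch of the pattern, with simplified names that are assumptions rather than fairseq2's actual implementation:

    from typing import Callable, Dict, Generic, TypeVar

    ConfigT = TypeVar("ConfigT")

    class ArchitectureRegistry(Generic[ConfigT]):
        """Minimal sketch: maps architecture names to config factories."""

        def __init__(self, model_type: str) -> None:
            self.model_type = model_type
            self._factories: Dict[str, Callable[[], ConfigT]] = {}

        def decorator(
            self, arch: str
        ) -> Callable[[Callable[[], ConfigT]], Callable[[], ConfigT]]:
            def register(factory: Callable[[], ConfigT]) -> Callable[[], ConfigT]:
                # Record the factory under its architecture name, then
                # return it unchanged so the decorated function stays usable.
                self._factories[arch] = factory
                return factory

            return register

        def get_config(self, arch: str) -> ConfigT:
            return self._factories[arch]()

Usage mirrors the diff: `monotonic_decoder_arch = monotonic_decoder_archs.decorator` binds the method, and `@monotonic_decoder_arch("dense_1b")` registers `_dense_1b` under that name.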

+ 66 - 80
src/seamless_communication/models/monotonic_decoder/loader.py

@@ -4,17 +4,12 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, Mapping, final
+from typing import Any, Mapping
 
 import torch
-
-from fairseq2.assets import (
-    asset_store,
-    download_manager,
-)
-from fairseq2.models.utils.checkpoint_loader import upgrade_fairseq_checkpoint
-from fairseq2.models.utils.model_loader import ModelConfigLoader, ModelLoader
-from fairseq2.typing import finaloverride
+from fairseq2.assets import asset_store, download_manager
+from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
 
 from seamless_communication.models.monotonic_decoder.builder import (
     MonotonicDecoderConfig,
@@ -24,83 +19,74 @@ from seamless_communication.models.monotonic_decoder.builder import (
 from seamless_communication.models.monotonic_decoder.model import MonotonicDecoderModel
 
 
-@final
-class MonotonicDecoderLoader(
-    ModelLoader[MonotonicDecoderModel, MonotonicDecoderConfig]
-):
-    """Loads Monotonic Decoder models."""
-
-    @finaloverride
-    def _convert_checkpoint(
-        self, checkpoint: Mapping[str, Any], config: MonotonicDecoderConfig
-    ) -> Mapping[str, Any]:
-        state_dict = checkpoint["model"]
-
-        # Check if we have a fairseq2 checkpoint.
-        if "text_decoder.layers.0.self_attn.k_proj.weight" in state_dict:
-            return checkpoint
-
-        key_map = self._fairseq_key_map()
-
-        # Convert to fairseq2.
-        checkpoint = upgrade_fairseq_checkpoint(checkpoint, key_map)
-
-        state_dict = checkpoint["model"]
-
-        embeds = state_dict["final_proj.weight"]
-
-        # fairseq had a bug that accidentally introduced a dummy token in the
-        # embedding table of NLLB-100. We just discard it.
-        if embeds.size(0) == 256103:  # means NLLB-100
-            embeds = embeds[:-1]
+def convert_monotonic_checkpoint(
+    checkpoint: Mapping[str, Any], config: MonotonicDecoderConfig
+) -> Mapping[str, Any]:
+    state_dict = checkpoint["model"]
 
-            state_dict["final_proj.weight"] = embeds
-
-        # fairseq checkpoints have duplicate embedding weights. Ensure that we
-        # use a single embedding table in fairseq2.
-        state_dict["text_decoder_frontend.embed.weight"] = embeds
+    # Check if we have a fairseq2 checkpoint.
+    if "text_decoder.layers.0.self_attn.k_proj.weight" in state_dict:
+        return checkpoint
 
-        # The embedding positions of the control symbols in fairseq's dict do
-        # not match the SentencePiece model of the tokenizer.
-        with torch.inference_mode():
-            # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
-            embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
+    key_map = {
+        # fmt: off
+        r"^decoder\.embed_tokens\.":                                            r"text_decoder_frontend.embed.",
+        r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":                   r"text_decoder.layers.\1.self_attn.output_proj.",
+        r"^decoder\.layers\.([0-9]+)\.self_attn\.":                             r"text_decoder.layers.\1.self_attn.",
+        r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":                  r"text_decoder.layers.\1.self_attn_layer_norm.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":                r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.energy_bias":               r"text_decoder.layers.\1.p_choose_layer.energy_bias",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.source_energy_layer\.":     r"text_decoder.layers.\1.p_choose_layer.k_energy_proj.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.target_energy_layer\.":     r"text_decoder.layers.\1.p_choose_layer.q_energy_proj.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":                          r"text_decoder.layers.\1.encoder_decoder_attn.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.":               r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+        r"^decoder\.layers\.([0-9]+)\.fc1\.":                                   r"text_decoder.layers.\1.ffn.inner_proj.",
+        r"^decoder\.layers\.([0-9]+)\.fc2\.":                                   r"text_decoder.layers.\1.ffn.output_proj.",
+        r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":                      r"text_decoder.layers.\1.ffn_layer_norm.",
+        r"^decoder\.layer_norm\.":                                              r"text_decoder.layer_norm.",
+        r"^decoder\.output_projection\.":                                       r"final_proj.",
+        # fmt: on
+    }
+
+    # Convert to fairseq2.
+    checkpoint = convert_fairseq_checkpoint(checkpoint, key_map)
+
+    state_dict = checkpoint["model"]
+
+    embeds = state_dict["final_proj.weight"]
+
+    # fairseq had a bug that accidentally introduced a dummy token in the
+    # embedding table of NLLB-100. We just discard it.
+    if embeds.size(0) == 256103:  # means NLLB-100
+        embeds = embeds[:-1]
+
+        state_dict["final_proj.weight"] = embeds
+
+    # fairseq checkpoints have duplicate embedding weights. Ensure that we
+    # use a single embedding table in fairseq2.
+    state_dict["text_decoder_frontend.embed.weight"] = embeds
+
+    # The embedding positions of the control symbols in fairseq's dict do
+    # not match the SentencePiece model of the tokenizer.
+    with torch.inference_mode():
+        # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
+        embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
+
+    return checkpoint
+
+
+load_monotonic_decoder_config = ConfigLoader[MonotonicDecoderConfig](
+    asset_store, monotonic_decoder_archs
+)
 
-        return checkpoint
 
-    @staticmethod
-    def _fairseq_key_map() -> Dict[str, str]:
-        return {
-            # fmt: off
-            # Text Decoder
-            r"^decoder\.embed_tokens\.":                                            r"text_decoder_frontend.embed.",
-            r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":                   r"text_decoder.layers.\1.self_attn.output_proj.",
-            r"^decoder\.layers\.([0-9]+)\.self_attn\.":                             r"text_decoder.layers.\1.self_attn.",
-            r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":                  r"text_decoder.layers.\1.self_attn_layer_norm.",
-            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":                r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
-            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.energy_bias":               r"text_decoder.layers.\1.p_choose_layer.energy_bias",
-            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.source_energy_layer\.":     r"text_decoder.layers.\1.p_choose_layer.k_energy_proj.",
-            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.target_energy_layer\.":     r"text_decoder.layers.\1.p_choose_layer.q_energy_proj.",
-            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":                          r"text_decoder.layers.\1.encoder_decoder_attn.",
-            r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.":               r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
-            r"^decoder\.layers\.([0-9]+)\.fc1\.":                                   r"text_decoder.layers.\1.ffn.inner_proj.",
-            r"^decoder\.layers\.([0-9]+)\.fc2\.":                                   r"text_decoder.layers.\1.ffn.output_proj.",
-            r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":                      r"text_decoder.layers.\1.ffn_layer_norm.",
-            r"^decoder\.layer_norm\.":                                              r"text_decoder.layer_norm.",
-            r"^decoder\.output_projection\.":                                       r"final_proj.",
-            # fmt: on
-        }
-
-
-load_monotonic_decoder_model = MonotonicDecoderLoader(
+load_monotonic_decoder_model = ModelLoader[
+    MonotonicDecoderModel, MonotonicDecoderConfig
+](
     asset_store,
     download_manager,
+    load_monotonic_decoder_config,
     create_monotonic_decoder_model,
-    monotonic_decoder_archs,
+    convert_monotonic_checkpoint,
     restrict_checkpoints=False,
 )
-
-
-load_monotonic_decoder_config = ModelConfigLoader[MonotonicDecoderConfig](
-    asset_store, monotonic_decoder_archs
-)
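
This file shows the commit's central refactoring: the `@final` `MonotonicDecoderLoader` subclass that overrode `_convert_checkpoint` is replaced by a free `convert_monotonic_checkpoint` function passed to generic `ConfigLoader`/`ModelLoader` instances. A minimal sketch of that dependency-injection shape, with simplified signatures (the real fairseq2 loaders also resolve asset cards, download checkpoints, and handle device/dtype placement):

    from typing import Any, Callable, Generic, Mapping, Optional, TypeVar

    ModelT = TypeVar("ModelT")
    ConfigT = TypeVar("ConfigT")

    class GenericModelLoader(Generic[ModelT, ConfigT]):
        """Sketch: all per-model behavior arrives as plain callables."""

        def __init__(
            self,
            load_config: Callable[[str], ConfigT],
            create_model: Callable[[ConfigT], ModelT],
            convert_checkpoint: Optional[
                Callable[[Mapping[str, Any], ConfigT], Mapping[str, Any]]
            ] = None,
        ) -> None:
            self._load_config = load_config
            self._create_model = create_model
            self._convert_checkpoint = convert_checkpoint

        def __call__(self, name: str, checkpoint: Mapping[str, Any]) -> ModelT:
            config = self._load_config(name)
            if self._convert_checkpoint is not None:
                checkpoint = self._convert_checkpoint(checkpoint, config)
            # A real loader would now materialize checkpoint["model"] into
            # the instantiated module, e.g. via load_state_dict().
            return self._create_model(config)

One generic class per concern thus replaces one subclass per model, which is why the pretssel and UnitY loader diffs that follow take the same shape.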

+ 5 - 7
src/seamless_communication/models/monotonic_decoder/model.py

@@ -4,17 +4,15 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from overrides import final as finaloverride
 from typing import Optional, Tuple, final
 
-
-from torch import Tensor
-from torch.nn import Module
-from fairseq2.nn.incremental_state import IncrementalStateBag
 from fairseq2.models.transformer.frontend import TransformerFrontend
-
-from fairseq2.nn.projection import Projection
+from fairseq2.nn.incremental_state import IncrementalStateBag
 from fairseq2.nn.padding import PaddingMask
+from fairseq2.nn.projection import Projection
+from overrides import final as finaloverride
+from torch import Tensor
+from torch.nn import Module
 
 from seamless_communication.models.monotonic_decoder.monotonic_decoder import (
     MonotonicTransformerDecoder,

+ 4 - 4
src/seamless_communication/models/monotonic_decoder/p_choose.py

@@ -5,13 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 from typing import Optional, final
-from torch import Tensor
-from torch.nn import AvgPool1d, Module, ModuleList, ReLU
-from torch.nn.parameter import Parameter
-import torch
 
+import torch
 from fairseq2.nn.projection import Linear
 from fairseq2.typing import DataType, Device, finaloverride
+from torch import Tensor
+from torch.nn import AvgPool1d, Module, ModuleList, ReLU
+from torch.nn.parameter import Parameter
 
 
 class EnergyProjection(Module):

+ 1 - 1
src/seamless_communication/models/pretssel/builder.py

@@ -97,7 +97,7 @@ class PretsselConfig:
 
 pretssel_archs = ArchitectureRegistry[PretsselConfig]("pretssel")
 
-pretssel_arch = pretssel_archs.marker
+pretssel_arch = pretssel_archs.decorator
 
 
 @pretssel_arch("base")

+ 1 - 1
src/seamless_communication/models/pretssel/ecapa_tdnn_builder.py

@@ -29,7 +29,7 @@ class EcapaTDNNConfig:
 
 ecapa_tdnn_archs = ArchitectureRegistry[EcapaTDNNConfig]("ecapa_tdnn")
 
-ecapa_tdnn_arch = ecapa_tdnn_archs.marker
+ecapa_tdnn_arch = ecapa_tdnn_archs.decorator
 
 
 @ecapa_tdnn_arch("base")

+ 79 - 81
src/seamless_communication/models/pretssel/loader.py

@@ -4,12 +4,11 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, Mapping, final
+from typing import Any, Dict, Mapping
 
 from fairseq2.assets import asset_store, download_manager
-from fairseq2.models.utils.checkpoint_loader import upgrade_fairseq_checkpoint
-from fairseq2.models.utils.model_loader import ModelLoader
-from overrides import override as finaloverride
+from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
 
 from seamless_communication.models.pretssel.builder import (
     PretsselConfig,
@@ -19,96 +18,95 @@ from seamless_communication.models.pretssel.builder import (
 from seamless_communication.models.pretssel.pretssel_model import PretsselModel
 
 
-@final
-class PretsselLoader(ModelLoader[PretsselModel, PretsselConfig]):
-    """Load PretsselModel."""
+def convert_pretssel_checkpoint(
+    checkpoint: Mapping[str, Any], config: PretsselConfig
+) -> Mapping[str, Any]:
+    state_dict = checkpoint["model"]
 
-    @finaloverride
-    def _convert_checkpoint(
-        self, checkpoint: Mapping[str, Any], config: PretsselConfig
-    ) -> Mapping[str, Any]:
-        state_dict = checkpoint["model"]
+    # Check if we have a fairseq2 checkpoint.
+    if "decoder_frontend.embed.weight" in state_dict:
+        return checkpoint
 
-        # Check if we have a fairseq2 checkpoint.
-        if "decoder_frontend.embed.weight" in state_dict:
-            return checkpoint
+    key_map = _fairseq_key_map(config)
 
-        key_map = self._fairseq_key_map(config)
+    checkpoint = convert_fairseq_checkpoint(checkpoint, key_map)
 
-        checkpoint = upgrade_fairseq_checkpoint(checkpoint, key_map)
+    state_dict = checkpoint["model"]
 
-        state_dict = checkpoint["model"]
+    keys_to_delete = []
 
-        keys_to_delete = []
+    keys_to_delete.extend(
+        [
+            "encoder.embed_positions._float_tensor",
+            "decoder.embed_positions._float_tensor",
+            "enc_emb_proj.weight",
+            "enc_emb_proj.bias",
+        ]
+    )
 
-        keys_to_delete.extend(
-            [
-                "encoder.embed_positions._float_tensor",
-                "decoder.embed_positions._float_tensor",
-                "enc_emb_proj.weight",
-                "enc_emb_proj.bias",
-            ]
-        )
+    keys_to_delete.extend(
+        [
+            key
+            for key in state_dict
+            if key.startswith("decoder.var_adaptor.duration_predictor")
+        ]
+    )
 
-        keys_to_delete.extend(
-            [
-                key
-                for key in state_dict
-                if key.startswith("decoder.var_adaptor.duration_predictor")
-            ]
-        )
+    for key in keys_to_delete:
+        if key in state_dict:
+            del state_dict[key]
 
-        for key in keys_to_delete:
-            if key in state_dict:
-                del state_dict[key]
+    return checkpoint
 
-        return checkpoint
 
-    @staticmethod
-    def _fairseq_key_map(config: PretsselConfig) -> Dict[str, str]:
-        key_map = {
-            # fmt: off
-            # encoder frontend
-            r"^prosody_encoder\.":                                        r"encoder_frontend.prosody_encoder.",
-            r"^encoder\.embed_tokens\.":                                  r"encoder_frontend.embed_tokens.",
-            r"^embed_lang\.":                                             r"encoder_frontend.embed_lang.",
-            r"^encoder\.pos_emb_alpha":                                   r"encoder_frontend.pos_emb_alpha",
-
-            # encoder
-            r"^encoder\.fft_layers\.([0-9]+)\.self_attn\.out_proj\.":     r"encoder.layers.\1.self_attn.output_proj.",
-            r"^encoder\.fft_layers\.([0-9]+)\.self_attn\.":               r"encoder.layers.\1.self_attn.",
-            r"^encoder\.fft_layers\.([0-9]+)\.layer_norm\.":              r"encoder.layers.\1.self_attn_layer_norm.",
-            r"^encoder\.fft_layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"encoder.layers.\1.conv1d.conv1.",
-            r"^encoder\.fft_layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"encoder.layers.\1.conv1d.conv2.",
-            r"^encoder\.fft_layers\.([0-9]+)\.ffn\.layer_norm\.":         r"encoder.layers.\1.conv1d_layer_norm.",
-            r"^encoder\.fft_layers\.([0-9]+)\.film\.":                    r"encoder.layers.\1.film.",
-
-            # decoder frontend
-            r"^decoder\.var_adaptor\.":                                   r"decoder_frontend.variance_adaptor.",
-            r"^decoder\.pos_emb_alpha":                                   r"decoder_frontend.pos_emb_alpha",
-
-            # decoder
-            r"^decoder\.fft_layers\.([0-9]+)\.self_attn\.out_proj\.":     r"decoder.layers.\1.self_attn.output_proj.",
-            r"^decoder\.fft_layers\.([0-9]+)\.self_attn\.":               r"decoder.layers.\1.self_attn.",
-            r"^decoder\.fft_layers\.([0-9]+)\.layer_norm\.":              r"decoder.layers.\1.self_attn_layer_norm.",
-            r"^decoder\.fft_layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"decoder.layers.\1.conv1d.conv1.",
-            r"^decoder\.fft_layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"decoder.layers.\1.conv1d.conv2.",
-            r"^decoder\.fft_layers\.([0-9]+)\.ffn\.layer_norm\.":         r"decoder.layers.\1.conv1d_layer_norm.",
-            r"^decoder\.fft_layers\.([0-9]+)\.film\.":                    r"decoder.layers.\1.film.",
-
-            # final_proj & postnet
-            r"^decoder\.out_proj\.":                                      r"final_proj.",
-            r"^decoder\.postnet\.":                                       r"postnet.",
-            # fmt: on
-        }
-
-        return key_map
-
-
-load_pretssel_model = PretsselLoader(
+def _fairseq_key_map(config: PretsselConfig) -> Dict[str, str]:
+    key_map = {
+        # fmt: off
+        # encoder frontend
+        r"^prosody_encoder\.":                                        r"encoder_frontend.prosody_encoder.",
+        r"^encoder\.embed_tokens\.":                                  r"encoder_frontend.embed_tokens.",
+        r"^embed_lang\.":                                             r"encoder_frontend.embed_lang.",
+        r"^encoder\.pos_emb_alpha":                                   r"encoder_frontend.pos_emb_alpha",
+
+        # encoder
+        r"^encoder\.fft_layers\.([0-9]+)\.self_attn\.out_proj\.":     r"encoder.layers.\1.self_attn.output_proj.",
+        r"^encoder\.fft_layers\.([0-9]+)\.self_attn\.":               r"encoder.layers.\1.self_attn.",
+        r"^encoder\.fft_layers\.([0-9]+)\.layer_norm\.":              r"encoder.layers.\1.self_attn_layer_norm.",
+        r"^encoder\.fft_layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"encoder.layers.\1.conv1d.conv1.",
+        r"^encoder\.fft_layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"encoder.layers.\1.conv1d.conv2.",
+        r"^encoder\.fft_layers\.([0-9]+)\.ffn\.layer_norm\.":         r"encoder.layers.\1.conv1d_layer_norm.",
+        r"^encoder\.fft_layers\.([0-9]+)\.film\.":                    r"encoder.layers.\1.film.",
+
+        # decoder frontend
+        r"^decoder\.var_adaptor\.":                                   r"decoder_frontend.variance_adaptor.",
+        r"^decoder\.pos_emb_alpha":                                   r"decoder_frontend.pos_emb_alpha",
+
+        # decoder
+        r"^decoder\.fft_layers\.([0-9]+)\.self_attn\.out_proj\.":     r"decoder.layers.\1.self_attn.output_proj.",
+        r"^decoder\.fft_layers\.([0-9]+)\.self_attn\.":               r"decoder.layers.\1.self_attn.",
+        r"^decoder\.fft_layers\.([0-9]+)\.layer_norm\.":              r"decoder.layers.\1.self_attn_layer_norm.",
+        r"^decoder\.fft_layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"decoder.layers.\1.conv1d.conv1.",
+        r"^decoder\.fft_layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"decoder.layers.\1.conv1d.conv2.",
+        r"^decoder\.fft_layers\.([0-9]+)\.ffn\.layer_norm\.":         r"decoder.layers.\1.conv1d_layer_norm.",
+        r"^decoder\.fft_layers\.([0-9]+)\.film\.":                    r"decoder.layers.\1.film.",
+
+        # final_proj & postnet
+        r"^decoder\.out_proj\.":                                      r"final_proj.",
+        r"^decoder\.postnet\.":                                       r"postnet.",
+        # fmt: on
+    }
+
+    return key_map
+
+
+load_pretssel_config = ConfigLoader[PretsselConfig](asset_store, pretssel_archs)
+
+
+load_pretssel_model = ModelLoader[PretsselModel, PretsselConfig](
     asset_store,
     download_manager,
+    load_pretssel_config,
     create_pretssel_model,
-    pretssel_archs,
+    convert_pretssel_checkpoint,
     restrict_checkpoints=False,
 )
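
The conversion functions drive everything through a regex key map: each entry rewrites a fairseq parameter-name prefix into its fairseq2 counterpart. A minimal sketch of how such a map can be applied to a state dict (illustrative only; fairseq2's `convert_fairseq_checkpoint` is the real implementation and does more, e.g. checkpoint metadata upgrades):

    import re
    from typing import Any, Dict, Mapping

    def remap_state_dict(
        state_dict: Mapping[str, Any], key_map: Mapping[str, str]
    ) -> Dict[str, Any]:
        """Rename each key by the first matching pattern; keys with no
        matching pattern pass through unchanged."""
        remapped: Dict[str, Any] = {}
        for key, value in state_dict.items():
            new_key = key
            for pattern, repl in key_map.items():
                new_key, count = re.subn(pattern, repl, key)
                if count:
                    break
            remapped[new_key] = value
        return remapped

For example, the entry `r"^decoder\.out_proj\."` -> `"final_proj."` turns `decoder.out_proj.weight` into `final_proj.weight`. Entry order matters under first-match semantics: more specific patterns (such as `self_attn.out_proj`) must precede their generic prefixes (`self_attn.`), which is why the maps in these loaders list them in that order.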

+ 0 - 1
src/seamless_communication/models/unity/__init__.py

@@ -36,7 +36,6 @@ from seamless_communication.models.unity.length_regulator import (
 from seamless_communication.models.unity.length_regulator import (
     VariancePredictor as VariancePredictor,
 )
-from seamless_communication.models.unity.loader import UnitYLoader as UnitYLoader
 from seamless_communication.models.unity.loader import (
     load_gcmvn_stats as load_gcmvn_stats,
 )

+ 1 - 2
src/seamless_communication/models/unity/builder.py

@@ -102,8 +102,7 @@ class UnitYConfig:
 
 unity_archs = ArchitectureRegistry[UnitYConfig]("unity")
 
-
-unity_arch = unity_archs.marker
+unity_arch = unity_archs.decorator
 
 
 @unity_arch("base")

+ 0 - 1
src/seamless_communication/models/unity/length_regulator.py

@@ -283,7 +283,6 @@ class VarianceAdaptor(Module):
         min_duration: int = 0,
         film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, PaddingMask]:
-
         if self.duration_predictor is not None:
             log_durations = self.duration_predictor(seqs, padding_mask, film_cond_emb)
             durations = torch.clamp(

+ 347 - 358
src/seamless_communication/models/unity/loader.py

@@ -4,16 +4,15 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, List, Mapping, Tuple, Union, final
+from typing import Any, Dict, List, Mapping, Tuple, Union
 
 import torch
 from fairseq2.assets import AssetStore, asset_store, download_manager
 from fairseq2.assets.card import AssetCard
 from fairseq2.models.nllb import NllbConfig
 from fairseq2.models.nllb.loader import NllbTokenizerLoader
-from fairseq2.models.utils.checkpoint_loader import upgrade_fairseq_checkpoint
-from fairseq2.models.utils.model_loader import ModelConfigLoader, ModelLoader
-from overrides import override as finaloverride
+from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
 
 from seamless_communication.models.unity.builder import (
     UnitYConfig,
@@ -25,394 +24,384 @@ from seamless_communication.models.unity.model import UnitYModel
 from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer
 
 
-@final
-class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
-    """Loads UnitY models."""
+def convert_unity_checkpoint(
+    checkpoint: Mapping[str, Any], config: UnitYConfig
+) -> Mapping[str, Any]:
+    state_dict = checkpoint["model"]
 
-    @finaloverride
-    def _convert_checkpoint(
-        self, checkpoint: Mapping[str, Any], config: UnitYConfig
-    ) -> Mapping[str, Any]:
-        state_dict = checkpoint["model"]
+    # Check if we have a fairseq2 checkpoint.
+    if "speech_encoder.inner.layers.0.self_attn_layer_norm.weight" in state_dict:
+        return checkpoint
 
-        # Check if we have a fairseq2 checkpoint.
-        if "speech_encoder.inner.layers.0.self_attn_layer_norm.weight" in state_dict:
-            return checkpoint
+    key_map = _fairseq_key_map(config)
+
+    checkpoint = convert_fairseq_checkpoint(checkpoint, key_map)
+
+    state_dict = checkpoint["model"]
+
+    keys_to_delete = []
+
+    # ExpressiveUnitY model (from multi_arch codebase)
+    if config.prosody_encoder_config is not None:
+        encoder_key = "s2t_model.encoder"
+        decoder_key = "s2t_model.decoder"
+        t2u_decoder_key = "t2s_model.decoder"
+    # X2T/S2T + T2U model.
+    elif config.t2u_config is not None:
+        encoder_key = "encoder"
+        decoder_key = "target_letter_decoder"
+        t2u_decoder_key = "decoder"
+    # X2T model.
+    elif config.use_text_encoder:
+        encoder_key = "speech_encoder"
+        decoder_key = "shared_decoder"
+    # S2T model.
+    else:
+        encoder_key = "encoder"
+        decoder_key = "decoder"
+
+    keys_to_delete.append(f"{decoder_key}.version")
+    keys_to_delete.append(f"{decoder_key}.embed_positions._float_tensor")
+
+    if config.use_text_encoder:
+        keys_to_delete.append("text_encoder.version")
+        keys_to_delete.append("text_encoder.embed_positions._float_tensor")
+
+    if not config.use_text_decoder:
+        text_decoder_keys = [key for key in state_dict if key.startswith(decoder_key)]
+        keys_to_delete.extend(text_decoder_keys)
+
+    # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
+    keys_to_delete.append(f"{encoder_key}.w2v_encoder.w2v_model.mask_emb")
+
+    if config.prosody_encoder_config is not None or config.t2u_config is not None:
+        keys_to_delete.append(
+            f"{t2u_decoder_key}.char_upsampler.embed_positions._float_tensor"
+        )
+        keys_to_delete.append(
+            f"{t2u_decoder_key}.char_upsampler.embed_tokens_char.weight"
+        )
 
-        key_map = self._fairseq_key_map(config)
+        # Delete AlignmentEncoder keys for inference.
+        alignment_encoder_keys = [
+            key
+            for key in state_dict
+            if key.startswith(f"{t2u_decoder_key}.alignment_encoder.")
+        ]
+        keys_to_delete.extend(alignment_encoder_keys)
 
-        checkpoint = upgrade_fairseq_checkpoint(checkpoint, key_map)
+    # Delete character-level projection for inference.
+    keys_to_delete.extend(
+        [
+            "decoder_target_letter_decoder.proj.weight",
+            "decoder_target_letter_decoder.proj.bias",
+        ]
+    )
 
-        state_dict = checkpoint["model"]
+    if config.prosody_encoder_config is not None:
+        keys_to_delete.extend(
+            [
+                f"{t2u_decoder_key}.embed_positions._float_tensor",
+                "t2s_model.global_proj_dec.weight",
+                "t2s_model.global_proj_dec.bias",
+                "t2s_model.decoder_target_letter_nllb_spm_decoder.encoder.proj.weight",
+                "t2s_model.decoder_target_letter_nllb_spm_decoder.encoder.proj.bias",
+            ]
+        )
 
-        keys_to_delete = []
+    for key in keys_to_delete:
+        if key in state_dict:
+            del state_dict[key]
 
-        # ExpressiveUnitY model (from multi_arch codebase)
-        if config.prosody_encoder_config is not None:
-            encoder_key = "s2t_model.encoder"
-            decoder_key = "s2t_model.decoder"
-            t2u_decoder_key = "t2s_model.decoder"
-        # X2T/S2T + T2U model.
-        elif config.t2u_config is not None:
-            encoder_key = "encoder"
-            decoder_key = "target_letter_decoder"
-            t2u_decoder_key = "decoder"
-        # X2T model.
-        elif config.use_text_encoder:
-            encoder_key = "speech_encoder"
-            decoder_key = "shared_decoder"
-        # S2T model.
-        else:
-            encoder_key = "encoder"
-            decoder_key = "decoder"
+    if config.use_text_decoder:
+        embeds = state_dict["final_proj.weight"]
 
-        keys_to_delete.append(f"{decoder_key}.version")
-        keys_to_delete.append(f"{decoder_key}.embed_positions._float_tensor")
+        # fairseq had a bug that accidentally introduced a dummy token in the
+        # embedding table of NLLB-100. We just discard it.
+        if (
+            isinstance(config.mt_model_config, NllbConfig) and embeds.size(0) == 256103
+        ):  # means NLLB-100
+            embeds = embeds[:-1]
 
-        if config.use_text_encoder:
-            keys_to_delete.append("text_encoder.version")
-            keys_to_delete.append("text_encoder.embed_positions._float_tensor")
+            state_dict["final_proj.weight"] = embeds
 
-        if not config.use_text_decoder:
-            text_decoder_keys = [
-                key for key in state_dict if key.startswith(decoder_key)
-            ]
-            keys_to_delete.extend(text_decoder_keys)
-
-        # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
-        keys_to_delete.append(f"{encoder_key}.w2v_encoder.w2v_model.mask_emb")
-
-        if config.prosody_encoder_config is not None or config.t2u_config is not None:
-            keys_to_delete.append(
-                f"{t2u_decoder_key}.char_upsampler.embed_positions._float_tensor"
-            )
-            keys_to_delete.append(
-                f"{t2u_decoder_key}.char_upsampler.embed_tokens_char.weight"
-            )
-
-            # Delete AlignmentEncoder keys for inference.
-            alignment_encoder_keys = [
-                key
-                for key in state_dict
-                if key.startswith(f"{t2u_decoder_key}.alignment_encoder.")
-            ]
-            keys_to_delete.extend(alignment_encoder_keys)
+        # fairseq checkpoints have duplicate embedding weights. Ensure that we
+        # use a single embedding table in fairseq2.
+        state_dict["text_decoder_frontend.embed.weight"] = embeds
 
-        # Delete character-level projection for inference.
-        keys_to_delete.extend(
-            [
-                "decoder_target_letter_decoder.proj.weight",
-                "decoder_target_letter_decoder.proj.bias",
-            ]
+        if config.use_text_encoder:
+            state_dict["text_encoder_frontend.embed.weight"] = embeds
+
+        # The embedding positions of the control symbols in fairseq's dict do
+        # not match the SentencePiece model of the tokenizer.
+        with torch.inference_mode():
+            # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
+            embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
+
+    char_embeds = state_dict.get("t2u_model.decoder_frontend.embed_char.weight", None)
+    if char_embeds is not None:
+        index_mapping = _get_char_index_mapping(config)
+        vocab_size = len(index_mapping)
+        char_embeds[torch.arange(vocab_size)] = char_embeds[index_mapping]
+
+    if config.t2u_config is not None:
+        # fairseq checkpoints have duplicate embedding weights. Ensure that we
+        # use a single embedding table in fairseq2.
+        embeds = state_dict["t2u_model.final_proj.weight"]
+
+        if "t2u_model.decoder_frontend.embed.weight" in state_dict:
+            state_dict["t2u_model.decoder_frontend.embed.weight"] = embeds
+
+    return checkpoint
+
+
+def _get_char_index_mapping(config: UnitYConfig) -> List[int]:
+    assert config.t2u_config is not None
+    assert config.t2u_config.nar_decoder_config is not None
+    char_tokenizer = load_unity_char_tokenizer(
+        config.t2u_config.nar_decoder_config.model_name_or_card
+    )
+    spm_order = [
+        char_tokenizer.model.index_to_token(i)
+        for i in range(char_tokenizer.model.vocabulary_size)
+    ][4:]
+    spm_to_dict_mapping = {
+        ch: idx
+        for (idx, ch) in zip(
+            range(4, char_tokenizer.model.vocabulary_size),
+            sorted(spm_order),
         )
+    }
+    model_to_dict_mapping = [0, 1, 2, 3] + [spm_to_dict_mapping[ch] for ch in spm_order]
+    return model_to_dict_mapping
+
+
+def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
+    # ExpressiveUnitY model (from multi_arch codebase)
+    if config.prosody_encoder_config is not None:
+        encoder_key = "s2t_model.encoder"
+        decoder_key = "s2t_model.decoder"
+        t2u_encoder_key = "t2s_model.encoder"
+        t2u_decoder_key = "t2s_model.decoder"
+        ecapa_tdnn_key = "global_prosody"
+    # X2T/S2T + T2U model.
+    elif config.t2u_config is not None:
+        encoder_key = "encoder"
+        decoder_key = "target_letter_decoder"
+        t2u_encoder_key = "synthesizer_encoder"
+        t2u_decoder_key = "decoder"
+    # X2T model.
+    elif config.use_text_encoder:
+        encoder_key = "speech_encoder"
+        decoder_key = "shared_decoder"
+    # S2T model.
+    else:
+        encoder_key = "encoder"
+        decoder_key = "decoder"
+
+    key_map = {
+        # fmt: off
 
-        if config.prosody_encoder_config is not None:
-            keys_to_delete.extend(
-                [
-                    f"{t2u_decoder_key}.embed_positions._float_tensor",
-                    "t2s_model.global_proj_dec.weight",
-                    "t2s_model.global_proj_dec.bias",
-                    "t2s_model.decoder_target_letter_nllb_spm_decoder.encoder.proj.weight",
-                    "t2s_model.decoder_target_letter_nllb_spm_decoder.encoder.proj.bias",
-                ]
-            )
-
-        for key in keys_to_delete:
-            if key in state_dict:
-                del state_dict[key]
-
-        if config.use_text_decoder:
-            embeds = state_dict["final_proj.weight"]
-
-            # fairseq had a bug that accidentally introduced a dummy token in the
-            # embedding table of NLLB-100. We just discard it.
-            if (
-                isinstance(config.mt_model_config, NllbConfig)
-                and embeds.size(0) == 256103
-            ):  # means NLLB-100
-                embeds = embeds[:-1]
-
-                state_dict["final_proj.weight"] = embeds
-
-            # fairseq checkpoints have duplicate embedding weights. Ensure that we
-            # use a single embedding table in fairseq2.
-            state_dict["text_decoder_frontend.embed.weight"] = embeds
-
-            if config.use_text_encoder:
-                state_dict["text_encoder_frontend.embed.weight"] = embeds
-
-            # The embedding positions of the control symbols in fairseq's dict do
-            # not match the SentencePiece model of the tokenizer.
-            with torch.inference_mode():
-                # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
-                embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
-
-        char_embeds = state_dict.get(
-            "t2u_model.decoder_frontend.embed_char.weight", None
+        # Speech Encoder
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.pos_conv\.0\.":                                    r"speech_encoder_frontend.pos_encoder.conv.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.layer_norm\.":                                              r"speech_encoder_frontend.post_extract_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.post_extract_proj\.":                                       r"speech_encoder_frontend.model_dim_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.0\.":             r"speech_encoder_frontend.feature_extractor.layers.\1.conv.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.":          r"speech_encoder_frontend.feature_extractor.layers.\1.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.":                    r"speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
+
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.inner.layers.\1.conv.batch_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.":     r"speech_encoder.inner.layers.\1.conv.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.inner.layers.\1.conv.depthwise_conv.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.inner.layers.\1.conv_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv1.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv2.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         r"speech_encoder.inner.layers.\1.ffn\2_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                r"speech_encoder.inner.layers.\1.ffn\2.inner_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                r"speech_encoder.inner.layers.\1.ffn\2.output_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":         r"speech_encoder.inner.layers.\1.self_attn_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.":          r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.":          r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.":          r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.":        r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.":   r"speech_encoder.inner.layers.\1.self_attn.sdpa.rel_k_embed.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":          r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.":        r"speech_encoder.inner.layers.\1.self_attn.sdpa.r_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.u_bias",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.v_bias",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.final_layer_norm\.":             r"speech_encoder.inner.layers.\1.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.":                                     r"speech_encoder.inner.layer_norm.",
+
+        # Speech Encoder Adaptor
+        fr"^{encoder_key}\.adaptor\.proj\.0\.": r"speech_encoder.proj1.",
+        fr"^{encoder_key}\.adaptor\.proj\.2\.": r"speech_encoder.proj2.",
+        fr"^{encoder_key}\.adaptor\.out_ln\.":  r"speech_encoder.layer_norm.",
+
+        # Text Encoder
+        r"^text_encoder\.embed_tokens\.":                              r"text_encoder_frontend.embed.",
+        r"^text_encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"text_encoder.layers.\1.self_attn.output_proj.",
+        r"^text_encoder\.layers\.([0-9]+)\.self_attn\.":               r"text_encoder.layers.\1.self_attn.",
+        r"^text_encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"text_encoder.layers.\1.self_attn_layer_norm.",
+        r"^text_encoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"text_encoder.layers.\1.encoder_decoder_attn.output_proj.",
+        r"^text_encoder\.layers\.([0-9]+)\.encoder_attn\.":            r"text_encoder.layers.\1.encoder_decoder_attn.",
+        r"^text_encoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_encoder.layers.\1.encoder_decoder_attn_layer_norm.",
+        r"^text_encoder\.layers\.([0-9]+)\.fc1\.":                     r"text_encoder.layers.\1.ffn.inner_proj.",
+        r"^text_encoder\.layers\.([0-9]+)\.fc2\.":                     r"text_encoder.layers.\1.ffn.output_proj.",
+        r"^text_encoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"text_encoder.layers.\1.ffn_layer_norm.",
+        r"^text_encoder\.layer_norm\.":                                r"text_encoder.layer_norm.",
+        # fmt: on
+    }
+
+    # In normal circumstances, we should never encounter a `LayerNorm` when
+    # `use_conformer` is `True`. Unfortunately, the w2v-BERT pretraining in
+    # fairseq was accidentally run with a pre-LN encoder, and ended up with
+    # a redundant `LayerNorm` right after the Conformer blocks. We mitigate
+    # that issue here by moving that `LayerNorm` to the adaptor block.
+    # fmt: off
+    if config.w2v2_encoder_config.use_conformer:
+        key_map.update(
+            {
+                fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner_layer_norm."
+            }
         )
-        if char_embeds is not None:
-            index_mapping = self._get_char_index_mapping(config)
-            vocab_size = len(index_mapping)
-            char_embeds[torch.arange(vocab_size)] = char_embeds[index_mapping]
-
-        if config.t2u_config is not None:
-            # fairseq checkpoints have duplicate embedding weights. Ensure that we
-            # use a single embedding table in fairseq2.
-            embeds = state_dict["t2u_model.final_proj.weight"]
-
-            if "t2u_model.decoder_frontend.embed.weight" in state_dict:
-                state_dict["t2u_model.decoder_frontend.embed.weight"] = embeds
-
-        return checkpoint
+    else:
+        key_map.update(
+            {
+                rf"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm."
+            }
+        )
+    # fmt: on
 
-    def _get_char_index_mapping(self, config: UnitYConfig) -> List[int]:
-        assert config.t2u_config is not None
-        assert config.t2u_config.nar_decoder_config is not None
-        char_tokenizer = load_unity_char_tokenizer(
-            config.t2u_config.nar_decoder_config.model_name_or_card
+    if config.use_conformer_adaptor:
+        key_map.update(
+            {
+                # fmt: off
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":          r"speech_encoder.adaptor_layers.\1.block.self_attn.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":                    r"speech_encoder.adaptor_layers.\1.block.self_attn.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.":         r"speech_encoder.adaptor_layers.\1.block.self_attn_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         r"speech_encoder.adaptor_layers.\1.block.ffn\2_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                r"speech_encoder.adaptor_layers.\1.block.ffn\2.inner_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                r"speech_encoder.adaptor_layers.\1.block.ffn\2.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.adaptor_layers.\1.block.conv.batch_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.adaptor_layers.\1.block.conv.depthwise_conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.adaptor_layers.\1.block.conv_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv1.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv2.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":             r"speech_encoder.adaptor_layers.\1.block.layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_ln\.":                      r"speech_encoder.adaptor_layers.\1.layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_pool\.1\.":                 r"speech_encoder.adaptor_layers.\1.conv.",
+                # fmt: on
+            }
+        )
+    else:
+        key_map.update(
+            {
+                # fmt: off
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_layer_norm\.":  r"speech_encoder.adaptor_layers.\1.residual_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_pool\.1\.":     r"speech_encoder.adaptor_layers.\1.residual_conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.attn_pool\.1\.":         r"speech_encoder.adaptor_layers.\1.self_attn_conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":  r"speech_encoder.adaptor_layers.\1.self_attn.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":            r"speech_encoder.adaptor_layers.\1.self_attn.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.adaptor_layers.\1.self_attn_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc1\.":                  r"speech_encoder.adaptor_layers.\1.ffn.inner_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc2\.":                  r"speech_encoder.adaptor_layers.\1.ffn.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":     r"speech_encoder.adaptor_layers.\1.ffn_layer_norm.",
+                # fmt: on
+            }
         )
-        spm_order = [
-            char_tokenizer.model.index_to_token(i)
-            for i in range(char_tokenizer.model.vocabulary_size)
-        ][4:]
-        spm_to_dict_mapping = {
-            ch: idx
-            for (idx, ch) in zip(
-                range(4, char_tokenizer.model.vocabulary_size),
-                sorted(spm_order),
-            )
-        }
-        model_to_dict_mapping = [0, 1, 2, 3] + [
-            spm_to_dict_mapping[ch] for ch in spm_order
-        ]
-        return model_to_dict_mapping
-
-    @staticmethod
-    def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
-        # ExpressiveUnitY model (from multi_arch codebase)
-        if config.prosody_encoder_config is not None:
-            encoder_key = "s2t_model.encoder"
-            decoder_key = "s2t_model.decoder"
-            t2u_encoder_key = "t2s_model.encoder"
-            t2u_decoder_key = "t2s_model.decoder"
-            ecapa_tdnn_key = "global_prosody"
-        # X2T/S2T + T2U model.
-        elif config.t2u_config is not None:
-            encoder_key = "encoder"
-            decoder_key = "target_letter_decoder"
-            t2u_encoder_key = "synthesizer_encoder"
-            t2u_decoder_key = "decoder"
-        # X2T model.
-        elif config.use_text_encoder:
-            encoder_key = "speech_encoder"
-            decoder_key = "shared_decoder"
-        # S2T model.
-        else:
-            encoder_key = "encoder"
-            decoder_key = "decoder"
 
-        key_map = {
+    key_map.update(
+        {
             # fmt: off
-
-            # Speech Encoder
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.pos_conv\.0\.":                                    r"speech_encoder_frontend.pos_encoder.conv.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.layer_norm\.":                                              r"speech_encoder_frontend.post_extract_layer_norm.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.post_extract_proj\.":                                       r"speech_encoder_frontend.model_dim_proj.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.0\.":             r"speech_encoder_frontend.feature_extractor.layers.\1.conv.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.":          r"speech_encoder_frontend.feature_extractor.layers.\1.layer_norm.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.":                    r"speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
-
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.inner.layers.\1.conv.batch_norm.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.":     r"speech_encoder.inner.layers.\1.conv.layer_norm.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.inner.layers.\1.conv.depthwise_conv.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.inner.layers.\1.conv_layer_norm.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv1.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv2.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         r"speech_encoder.inner.layers.\1.ffn\2_layer_norm.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                r"speech_encoder.inner.layers.\1.ffn\2.inner_proj.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                r"speech_encoder.inner.layers.\1.ffn\2.output_proj.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":         r"speech_encoder.inner.layers.\1.self_attn_layer_norm.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.":          r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.":          r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.":          r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.":        r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.":   r"speech_encoder.inner.layers.\1.self_attn.sdpa.rel_k_embed.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":          r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.":        r"speech_encoder.inner.layers.\1.self_attn.sdpa.r_proj.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.u_bias",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.v_bias",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.final_layer_norm\.":             r"speech_encoder.inner.layers.\1.layer_norm.",
-            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.":                                     r"speech_encoder.inner.layer_norm.",
-
-            # Speech Encoder Adaptor
-            fr"^{encoder_key}\.adaptor\.proj\.0\.": r"speech_encoder.proj1.",
-            fr"^{encoder_key}\.adaptor\.proj\.2\.": r"speech_encoder.proj2.",
-            fr"^{encoder_key}\.adaptor\.out_ln\.":  r"speech_encoder.layer_norm.",
-
-            # Text Encoder
-            r"^text_encoder\.embed_tokens\.":                              r"text_encoder_frontend.embed.",
-            r"^text_encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"text_encoder.layers.\1.self_attn.output_proj.",
-            r"^text_encoder\.layers\.([0-9]+)\.self_attn\.":               r"text_encoder.layers.\1.self_attn.",
-            r"^text_encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"text_encoder.layers.\1.self_attn_layer_norm.",
-            r"^text_encoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"text_encoder.layers.\1.encoder_decoder_attn.output_proj.",
-            r"^text_encoder\.layers\.([0-9]+)\.encoder_attn\.":            r"text_encoder.layers.\1.encoder_decoder_attn.",
-            r"^text_encoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_encoder.layers.\1.encoder_decoder_attn_layer_norm.",
-            r"^text_encoder\.layers\.([0-9]+)\.fc1\.":                     r"text_encoder.layers.\1.ffn.inner_proj.",
-            r"^text_encoder\.layers\.([0-9]+)\.fc2\.":                     r"text_encoder.layers.\1.ffn.output_proj.",
-            r"^text_encoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"text_encoder.layers.\1.ffn_layer_norm.",
-            r"^text_encoder\.layer_norm\.":                                r"text_encoder.layer_norm.",
+            # Text Decoder
+            fr"^{decoder_key}\.embed_tokens\.":                              r"text_decoder_frontend.embed.",
+            fr"^{decoder_key}\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"text_decoder.layers.\1.self_attn.output_proj.",
+            fr"^{decoder_key}\.layers\.([0-9]+)\.self_attn\.":               r"text_decoder.layers.\1.self_attn.",
+            fr"^{decoder_key}\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"text_decoder.layers.\1.self_attn_layer_norm.",
+            fr"^{decoder_key}\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
+            fr"^{decoder_key}\.layers\.([0-9]+)\.encoder_attn\.":            r"text_decoder.layers.\1.encoder_decoder_attn.",
+            fr"^{decoder_key}\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+            fr"^{decoder_key}\.layers\.([0-9]+)\.fc1\.":                     r"text_decoder.layers.\1.ffn.inner_proj.",
+            fr"^{decoder_key}\.layers\.([0-9]+)\.fc2\.":                     r"text_decoder.layers.\1.ffn.output_proj.",
+            fr"^{decoder_key}\.layers\.([0-9]+)\.final_layer_norm\.":        r"text_decoder.layers.\1.ffn_layer_norm.",
+            fr"^{decoder_key}\.layer_norm\.":                                r"text_decoder.layer_norm.",
+            fr"^{decoder_key}\.output_projection\.":                         r"final_proj.",
             # fmt: on
         }
+    )
+    # ExpressiveUnitY model (from multi_arch codebase)
+    if config.prosody_encoder_config is not None:
+        key_map.update(
+            {
+                # fmt: off
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.film\.":              r"t2u_model.decoder.layers.\1.film.",
+                fr"^{ecapa_tdnn_key}\.":                                       r"prosody_encoder_model.",
+                r"^t2s_model\.global_proj_enc\.":                             r"t2u_model.prosody_proj.",
+                # fmt: on
+            }
+        )
 
-        # In normal circumstances, we should never encounter a `LayerNorm` when
-        # `use_conformer` is `True`. Unfortunately, the w2v-BERT pretraining in
-        # fairseq was accidentally run with a pre-LN encoder, and ended up with
-        # a redundant `LayerNorm` right after the Conformer blocks. We mitigate
-        # that issue here by moving that `LayerNorm` to the adaptor block.
-        # fmt: off
-        if config.w2v2_encoder_config.use_conformer:
-            key_map.update(
-                {
-                    fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner_layer_norm."
-                }
-            )
-        else:
-            key_map.update(
-                {
-                    rf"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm."
-                }
-            )
-        # fmt: on
-
-        if config.use_conformer_adaptor:
-            key_map.update(
-                {
-                    # fmt: off
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":          r"speech_encoder.adaptor_layers.\1.block.self_attn.output_proj.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":                    r"speech_encoder.adaptor_layers.\1.block.self_attn.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.":         r"speech_encoder.adaptor_layers.\1.block.self_attn_layer_norm.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         r"speech_encoder.adaptor_layers.\1.block.ffn\2_layer_norm.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                r"speech_encoder.adaptor_layers.\1.block.ffn\2.inner_proj.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                r"speech_encoder.adaptor_layers.\1.block.ffn\2.output_proj.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.adaptor_layers.\1.block.conv.batch_norm.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.adaptor_layers.\1.block.conv.depthwise_conv.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.adaptor_layers.\1.block.conv_layer_norm.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv1.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv2.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":             r"speech_encoder.adaptor_layers.\1.block.layer_norm.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_ln\.":                      r"speech_encoder.adaptor_layers.\1.layer_norm.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_pool\.1\.":                 r"speech_encoder.adaptor_layers.\1.conv.",
-                    # fmt: on
-                }
-            )
-        else:
-            key_map.update(
-                {
-                    # fmt: off
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_layer_norm\.":  r"speech_encoder.adaptor_layers.\1.residual_layer_norm.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_pool\.1\.":     r"speech_encoder.adaptor_layers.\1.residual_conv.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.attn_pool\.1\.":         r"speech_encoder.adaptor_layers.\1.self_attn_conv.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":  r"speech_encoder.adaptor_layers.\1.self_attn.output_proj.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":            r"speech_encoder.adaptor_layers.\1.self_attn.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.adaptor_layers.\1.self_attn_layer_norm.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc1\.":                  r"speech_encoder.adaptor_layers.\1.ffn.inner_proj.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc2\.":                  r"speech_encoder.adaptor_layers.\1.ffn.output_proj.",
-                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":     r"speech_encoder.adaptor_layers.\1.ffn_layer_norm.",
-                    # fmt: on
-                }
-            )
-
+    # X2T/S2T + T2U model.
+    if config.t2u_config is not None:
         key_map.update(
             {
                 # fmt: off
-                # Text Decoder
-                fr"^{decoder_key}\.embed_tokens\.":                              r"text_decoder_frontend.embed.",
-                fr"^{decoder_key}\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"text_decoder.layers.\1.self_attn.output_proj.",
-                fr"^{decoder_key}\.layers\.([0-9]+)\.self_attn\.":               r"text_decoder.layers.\1.self_attn.",
-                fr"^{decoder_key}\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"text_decoder.layers.\1.self_attn_layer_norm.",
-                fr"^{decoder_key}\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
-                fr"^{decoder_key}\.layers\.([0-9]+)\.encoder_attn\.":            r"text_decoder.layers.\1.encoder_decoder_attn.",
-                fr"^{decoder_key}\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
-                fr"^{decoder_key}\.layers\.([0-9]+)\.fc1\.":                     r"text_decoder.layers.\1.ffn.inner_proj.",
-                fr"^{decoder_key}\.layers\.([0-9]+)\.fc2\.":                     r"text_decoder.layers.\1.ffn.output_proj.",
-                fr"^{decoder_key}\.layers\.([0-9]+)\.final_layer_norm\.":        r"text_decoder.layers.\1.ffn_layer_norm.",
-                fr"^{decoder_key}\.layer_norm\.":                                r"text_decoder.layer_norm.",
-                fr"^{decoder_key}\.output_projection\.":                         r"final_proj.",
+                # T2U Encoder
+                fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.encoder.layers.\1.self_attn.output_proj.",
+                fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.encoder.layers.\1.self_attn.",
+                fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.encoder.layers.\1.self_attn_layer_norm.",
+                fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.encoder.layers.\1.ffn.inner_proj.",
+                fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.encoder.layers.\1.ffn.output_proj.",
+                fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.encoder.layers.\1.ffn_layer_norm.",
+                fr"^{t2u_encoder_key}\.layer_norm\.":                                r"t2u_model.encoder.layer_norm.",
+
+                # T2U Decoder frontend
+                fr"^{t2u_decoder_key}\.embed_tokens_text\.":                           r"t2u_model.decoder_frontend.embed_char.",
+                fr"^{t2u_decoder_key}\.embed_tokens_unit\.":                           r"t2u_model.decoder_frontend.embed.",
+                fr"^{t2u_decoder_key}\.embed_tokens\.":                                r"t2u_model.decoder_frontend.embed.",
+                fr"^{t2u_decoder_key}\.var_adaptor\.duration_predictor\.":             r"t2u_model.decoder_frontend.variance_adaptor.duration_predictor.",
+                fr"^{t2u_decoder_key}\.dec_pos_emb_alpha":                             r"t2u_model.decoder_frontend.pos_emb_alpha",
+                fr"^{t2u_decoder_key}\.char_upsampler\.pos_emb_alpha":                 r"t2u_model.decoder_frontend.pos_emb_alpha_char",
+
+                # T2U Decoder
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.decoder.layers.\1.self_attn.output_proj.",
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.decoder.layers.\1.self_attn.",
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.layer_norm\.":              r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"t2u_model.decoder.layers.\1.encoder_decoder_attn.output_proj.",
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.encoder_attn\.":            r"t2u_model.decoder.layers.\1.encoder_decoder_attn.",
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.decoder.layers.\1.ffn.inner_proj.",
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.decoder.layers.\1.ffn.output_proj.",
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.decoder.layers.\1.ffn_layer_norm.",
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"t2u_model.decoder.layers.\1.conv1d.conv1.",
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"t2u_model.decoder.layers.\1.conv1d.conv2.",
+                fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.ffn\.layer_norm\.":         r"t2u_model.decoder.layers.\1.conv1d_layer_norm.",
+                fr"^{t2u_decoder_key}\.layer_norm\.":                                r"t2u_model.decoder.layer_norm.",
+                fr"^{t2u_decoder_key}\.output_projection\.":                         r"t2u_model.final_proj.",
                 # fmt: on
             }
         )
-        # ExpressiveUnitY model (from multi_arch codebase)
-        if config.prosody_encoder_config is not None:
-            key_map.update(
-                {
-                    # fmt: off
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.film\.":              r"t2u_model.decoder.layers.\1.film.",
-                    fr"^{ecapa_tdnn_key}\.":                                       r"prosody_encoder_model.",
-                    r"^t2s_model\.global_proj_enc\.":                             r"t2u_model.prosody_proj.",
-                    # fmt: on
-                }
-            )
-
-        # X2T/S2T + T2U model.
-        if config.t2u_config is not None:
-            key_map.update(
-                {
-                    # fmt: off
-                    # T2U Encoder
-                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.encoder.layers.\1.self_attn.output_proj.",
-                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.encoder.layers.\1.self_attn.",
-                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.encoder.layers.\1.self_attn_layer_norm.",
-                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.encoder.layers.\1.ffn.inner_proj.",
-                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.encoder.layers.\1.ffn.output_proj.",
-                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.encoder.layers.\1.ffn_layer_norm.",
-                    fr"^{t2u_encoder_key}\.layer_norm\.":                                r"t2u_model.encoder.layer_norm.",
-
-                    # T2U Decoder frontend
-                    fr"^{t2u_decoder_key}\.embed_tokens_text\.":                           r"t2u_model.decoder_frontend.embed_char.",
-                    fr"^{t2u_decoder_key}\.embed_tokens_unit\.":                           r"t2u_model.decoder_frontend.embed.",
-                    fr"^{t2u_decoder_key}\.embed_tokens\.":                                r"t2u_model.decoder_frontend.embed.",
-                    fr"^{t2u_decoder_key}\.var_adaptor\.duration_predictor\.":             r"t2u_model.decoder_frontend.variance_adaptor.duration_predictor.",
-                    fr"^{t2u_decoder_key}\.dec_pos_emb_alpha":                             r"t2u_model.decoder_frontend.pos_emb_alpha",
-                    fr"^{t2u_decoder_key}\.char_upsampler\.pos_emb_alpha":                 r"t2u_model.decoder_frontend.pos_emb_alpha_char",
-
-                    # T2U Decoder
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.decoder.layers.\1.self_attn.output_proj.",
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.decoder.layers.\1.self_attn.",
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.layer_norm\.":              r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"t2u_model.decoder.layers.\1.encoder_decoder_attn.output_proj.",
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.encoder_attn\.":            r"t2u_model.decoder.layers.\1.encoder_decoder_attn.",
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn_layer_norm.",
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.decoder.layers.\1.ffn.inner_proj.",
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.decoder.layers.\1.ffn.output_proj.",
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.decoder.layers.\1.ffn_layer_norm.",
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"t2u_model.decoder.layers.\1.conv1d.conv1.",
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"t2u_model.decoder.layers.\1.conv1d.conv2.",
-                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.ffn\.layer_norm\.":         r"t2u_model.decoder.layers.\1.conv1d_layer_norm.",
-                    fr"^{t2u_decoder_key}\.layer_norm\.":                                r"t2u_model.decoder.layer_norm.",
-                    fr"^{t2u_decoder_key}\.output_projection\.":                         r"t2u_model.final_proj.",
-                    # fmt: on
-                }
-            )
-
-        return key_map
-
-
-load_unity_model = UnitYLoader(
+
+    return key_map
+
+
+load_unity_config = ConfigLoader[UnitYConfig](asset_store, unity_archs)
+
+
+load_unity_model = ModelLoader[UnitYModel, UnitYConfig](
     asset_store,
     download_manager,
+    load_unity_config,
     create_unity_model,
-    unity_archs,
+    convert_unity_checkpoint,
     restrict_checkpoints=False,
 )
 
 
-load_unity_config = ModelConfigLoader[UnitYConfig](asset_store, unity_archs)
-
-
 load_unity_text_tokenizer = NllbTokenizerLoader(asset_store, download_manager)
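
Reviewer note: the unity loaders are now instances of the generic fairseq2 loaders rather than `ModelLoader` subclasses. A minimal usage sketch, assuming the asset cards shipped in this repo are installed; the device/dtype values are examples only:

    import torch

    from seamless_communication.models.unity import load_unity_config, load_unity_model

    # ConfigLoader resolves the card's `model_arch` entry into a UnitYConfig.
    config = load_unity_config("seamlessM4T_v2_large")

    # ModelLoader fetches the checkpoint, builds the module with
    # create_unity_model(config), and rewrites the fairseq parameter names
    # via convert_unity_checkpoint before loading the state dict.
    model = load_unity_model(
        "seamlessM4T_v2_large", device=torch.device("cpu"), dtype=torch.float32
    )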
 
 

+ 1 - 1
src/seamless_communication/models/unity/t2u_builder.py

@@ -134,7 +134,7 @@ class UnitYT2UConfig:
 unity_t2u_archs = ArchitectureRegistry[UnitYT2UConfig]("unity_t2u")
 
 
-unity_t2u_arch = unity_t2u_archs.marker
+unity_t2u_arch = unity_t2u_archs.decorator
 
 
 @unity_t2u_arch("base")

+ 0 - 4
src/seamless_communication/models/vocoder/__init__.py

@@ -15,10 +15,6 @@ from seamless_communication.models.vocoder.codehifigan import (
     CodeGenerator as CodeGenerator,
 )
 from seamless_communication.models.vocoder.hifigan import Generator as Generator
-from seamless_communication.models.vocoder.loader import (
-    MelVocoderLoader as MelVocoderLoader,
-)
-from seamless_communication.models.vocoder.loader import VocoderLoader as VocoderLoader
 from seamless_communication.models.vocoder.loader import (
     load_mel_vocoder_model as load_mel_vocoder_model,
 )
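
With the loader classes deleted, downstream code should import the load functions that this __init__ keeps re-exporting, e.g. (assuming `load_vocoder_model` remains re-exported further down this file):

    from seamless_communication.models.vocoder import (
        load_mel_vocoder_model,
        load_vocoder_model,
    )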

+ 4 - 4
src/seamless_communication/models/vocoder/builder.py

@@ -32,13 +32,12 @@ class VocoderConfig:
     num_langs: int
     spkr_embedding_dim: int
     num_spkrs: int
-    lang_spkr_idx_map: dict
+    lang_spkr_idx_map: Dict
 
 
 vocoder_archs = ArchitectureRegistry[VocoderConfig]("vocoder_code_hifigan")
 
-
-vocoder_arch = vocoder_archs.marker
+vocoder_arch = vocoder_archs.decorator
 
 
 @vocoder_arch("base")
@@ -139,7 +138,8 @@ def create_vocoder_model(
 
 
 mel_vocoder_archs = ArchitectureRegistry[VocoderConfig]("vocoder_mel_hifigan")
-mel_vocoder_arch = mel_vocoder_archs.marker
+
+mel_vocoder_arch = mel_vocoder_archs.decorator
 
 
 @mel_vocoder_arch("base_mel")

+ 31 - 35
src/seamless_communication/models/vocoder/loader.py

@@ -3,11 +3,11 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any, Mapping, final
+
+from typing import Any, Mapping
 
 from fairseq2.assets import asset_store, download_manager
-from fairseq2.models.utils.model_loader import ModelLoader
-from overrides import override as finaloverride
+from fairseq2.models.utils import ConfigLoader, ModelLoader
 
 from seamless_communication.models.vocoder.builder import (
     VocoderConfig,
@@ -20,44 +20,40 @@ from seamless_communication.models.vocoder.melhifigan import MelGenerator
 from seamless_communication.models.vocoder.vocoder import Vocoder
 
 
-@final
-class VocoderLoader(ModelLoader[Vocoder, VocoderConfig]):
-    """Loads Vocoder models."""
-
-    @finaloverride
-    def _convert_checkpoint(
-        self, checkpoint: Mapping[str, Any], config: VocoderConfig
-    ) -> Mapping[str, Any]:
-        if (
-            "model" in checkpoint
-            and "code_generator.resblocks.0.convs1.0.weight_g" in checkpoint["model"]
-        ):
-            return checkpoint
-
-        old_state_dict = checkpoint["generator"]
-        new_state_dict = {}
-        for key in old_state_dict:
-            new_key = f"code_generator.{key}"
-            new_state_dict[new_key] = old_state_dict[key]
-        checkpoint["model"] = new_state_dict
-        del checkpoint["generator"]  # type: ignore
+def convert_vocoder_checkpoint(
+    checkpoint: Mapping[str, Any], config: VocoderConfig
+) -> Mapping[str, Any]:
+    if (
+        "model" in checkpoint
+        and "code_generator.resblocks.0.convs1.0.weight_g" in checkpoint["model"]
+    ):
         return checkpoint
 
+    old_state_dict = checkpoint["generator"]
+    new_state_dict = {}
+    for key in old_state_dict:
+        new_key = f"code_generator.{key}"
+        new_state_dict[new_key] = old_state_dict[key]
+    checkpoint["model"] = new_state_dict
+    del checkpoint["generator"]  # type: ignore
+    return checkpoint
+
+
+load_vocoder_config = ConfigLoader[VocoderConfig](asset_store, vocoder_archs)
 
-load_vocoder_model = VocoderLoader(
-    asset_store, download_manager, create_vocoder_model, vocoder_archs
+
+load_vocoder_model = ModelLoader[Vocoder, VocoderConfig](
+    asset_store,
+    download_manager,
+    load_vocoder_config,
+    create_vocoder_model,
+    convert_vocoder_checkpoint,
 )
 
 
-@final
-class MelVocoderLoader(ModelLoader[MelGenerator, VocoderConfig]):
-    @finaloverride
-    def _convert_checkpoint(
-        self, checkpoint: Mapping[str, Any], config: VocoderConfig
-    ) -> Mapping[str, Any]:
-        return checkpoint
+load_mel_vocoder_config = ConfigLoader[VocoderConfig](asset_store, mel_vocoder_archs)
 
 
-load_mel_vocoder_model = MelVocoderLoader(
-    asset_store, download_manager, create_mel_vocoder_model, mel_vocoder_archs
+load_mel_vocoder_model = ModelLoader[MelGenerator, VocoderConfig](
+    asset_store, download_manager, load_mel_vocoder_config, create_mel_vocoder_model
 )
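
Because convert_vocoder_checkpoint is now a free function, its key renaming is easy to exercise in isolation. A toy round trip (the tensor value is fake, and the config argument is unused by this converter, so None is passed purely for illustration):

    ckpt = {"generator": {"resblocks.0.convs1.0.weight_g": 0.0}}

    out = convert_vocoder_checkpoint(ckpt, config=None)  # type: ignore[arg-type]

    assert "generator" not in out
    assert out["model"] == {"code_generator.resblocks.0.convs1.0.weight_g": 0.0}

Note that load_mel_vocoder_model passes no converter, so mel checkpoints are expected to already be in the fairseq2 layout.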

+ 2 - 2
src/seamless_communication/models/vocoder/vocoder.py

@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import torch
 import torch.nn as nn
@@ -14,7 +14,7 @@ from seamless_communication.models.vocoder.codehifigan import CodeGenerator
 
 
 class Vocoder(nn.Module):
-    def __init__(self, code_generator: CodeGenerator, lang_spkr_idx_map: dict):
+    def __init__(self, code_generator: CodeGenerator, lang_spkr_idx_map: Dict):
         super(Vocoder, self).__init__()
         self.code_generator = code_generator
         self.lang_spkr_idx_map = lang_spkr_idx_map

+ 1 - 1
src/seamless_communication/models/wav2vec2_chunk/builder.py

@@ -62,7 +62,7 @@ wav2vec2_chunk_archs = ArchitectureRegistry[Wav2Vec2ChunkEncoderConfig](
     "wav2vec2_chunk"
 )
 
-wav2vec2_chunk_arch = wav2vec2_chunk_archs.marker
+wav2vec2_chunk_arch = wav2vec2_chunk_archs.decorator
 
 
 @wav2vec2_chunk_arch("600m")