Can Balioglu 1 year ago
parent commit 283f74250f
51 changed files with 389 additions and 374 deletions
  1. setup.py (+1, -1)
  2. src/seamless_communication/cards/conformer_shaw.yaml (+1, -1)
  3. src/seamless_communication/cards/nar_t2u_aligner.yaml (+1, -1)
  4. src/seamless_communication/cards/seamless_streaming_monotonic_decoder.yaml (+1, -1)
  5. src/seamless_communication/cards/unity_nllb-100.yaml (+1, -1)
  6. src/seamless_communication/cards/unity_nllb-200.yaml (+1, -1)
  7. src/seamless_communication/cards/vocoder_36langs.yaml (+1, -1)
  8. src/seamless_communication/cards/vocoder_pretssel.yaml (+1, -1)
  9. src/seamless_communication/cards/vocoder_pretssel_16khz.yaml (+1, -1)
  10. src/seamless_communication/cards/vocoder_v2.yaml (+1, -1)
  11. src/seamless_communication/cards/xlsr2_1b_v2.yaml (+1, -1)
  12. src/seamless_communication/cli/m4t/audio_to_units/audio_to_units.py (+1, -1)
  13. src/seamless_communication/cli/m4t/evaluate/evaluate.py (+3, -4)
  14. src/seamless_communication/cli/m4t/finetune/dataloader.py (+1, -1)
  15. src/seamless_communication/cli/m4t/finetune/trainer.py (+8, -8)
  16. src/seamless_communication/cli/toxicity/asr_etox.py (+1, -1)
  17. src/seamless_communication/inference/generator.py (+5, -3)
  18. src/seamless_communication/inference/translator.py (+4, -4)
  19. src/seamless_communication/models/aligner/alignment_extractor.py (+4, -5)
  20. src/seamless_communication/models/aligner/builder.py (+8, -7)
  21. src/seamless_communication/models/aligner/loader.py (+9, -14)
  22. src/seamless_communication/models/aligner/model.py (+29, -10)
  23. src/seamless_communication/models/conformer_shaw/builder.py (+34, -19)
  24. src/seamless_communication/models/conformer_shaw/loader.py (+12, -12)
  25. src/seamless_communication/models/generator/builder.py (+2, -3)
  26. src/seamless_communication/models/generator/ecapa_tdnn_builder.py (+2, -2)
  27. src/seamless_communication/models/generator/loader.py (+6, -12)
  28. src/seamless_communication/models/generator/streamable.py (+1, -1)
  29. src/seamless_communication/models/generator/vocoder.py (+7, -3)
  30. src/seamless_communication/models/monotonic_decoder/builder.py (+2, -4)
  31. src/seamless_communication/models/monotonic_decoder/loader.py (+9, -17)
  32. src/seamless_communication/models/monotonic_decoder/model.py (+10, -7)
  33. src/seamless_communication/models/monotonic_decoder/monotonic_decoder.py (+2, -2)
  34. src/seamless_communication/models/monotonic_decoder/monotonic_decoder_layer.py (+2, -2)
  35. src/seamless_communication/models/monotonic_decoder/p_choose.py (+2, -2)
  36. src/seamless_communication/models/pretssel/ecapa_tdnn_builder.py (+2, -2)
  37. src/seamless_communication/models/tokenizer.py (+12, -29)
  38. src/seamless_communication/models/unit_extractor/wav2vec2_layer_output.py (+2, -0)
  39. src/seamless_communication/models/unity/builder.py (+62, -28)
  40. src/seamless_communication/models/unity/char_tokenizer.py (+8, -33)
  41. src/seamless_communication/models/unity/fft_decoder.py (+2, -2)
  42. src/seamless_communication/models/unity/fft_decoder_layer.py (+3, -3)
  43. src/seamless_communication/models/unity/loader.py (+78, -76)
  44. src/seamless_communication/models/unity/model.py (+13, -8)
  45. src/seamless_communication/models/unity/nar_decoder_frontend.py (+2, -2)
  46. src/seamless_communication/models/unity/t2u_builder.py (+5, -6)
  47. src/seamless_communication/models/vocoder/builder.py (+2, -2)
  48. src/seamless_communication/models/vocoder/loader.py (+9, -13)
  49. src/seamless_communication/models/vocoder/vocoder.py (+6, -5)
  50. src/seamless_communication/toxicity/etox_bad_word_checker.py (+1, -2)
  51. src/seamless_communication/toxicity/mintox.py (+7, -8)

+ 1 - 1
setup.py

@@ -22,7 +22,7 @@ setup(
     license="Creative Commons",
     install_requires=[
         "datasets",
-        "fairseq2==0.2.*",
+#        "fairseq2==0.2.*",
         "fire",
         "librosa",
         "openai-whisper",

+ 1 - 1
src/seamless_communication/cards/conformer_shaw.yaml

@@ -5,6 +5,6 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: conformer_shaw
-model_type: wav2vec2
+model_family: conformer_shaw
 model_arch: conformer_shaw_600m
 checkpoint: "https://huggingface.co/facebook/conformer-shaw/resolve/main/conformer_shaw.pt"

+ 1 - 1
src/seamless_communication/cards/nar_t2u_aligner.yaml

@@ -6,7 +6,7 @@
 
 name: nar_t2u_aligner
 char_tokenizer: "https://huggingface.co/facebook/seamless-streaming/resolve/main/spm_char_lang38_tc.model"
-model_type: unity2_aligner
+model_family: unity2_aligner
 model_arch: nar_t2u_aligner
 checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/unity2_aligner.pt"
 num_units: 10000

+ 1 - 1
src/seamless_communication/cards/seamless_streaming_monotonic_decoder.yaml

@@ -5,6 +5,6 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: seamless_streaming_monotonic_decoder
-model_type: monotonic_decoder
+model_family: monotonic_decoder
 model_arch: dense_1b
 checkpoint: "https://huggingface.co/facebook/seamless-streaming/resolve/main/seamless_streaming_monotonic_decoder.pt"

+ 1 - 1
src/seamless_communication/cards/unity_nllb-100.yaml

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: unity_nllb-100
-model_type: unity
+model_family: unity
 tokenizer: "https://huggingface.co/facebook/seamless-m4t-large/resolve/main/tokenizer.model"
 default_lang: eng
 langs:

+ 1 - 1
src/seamless_communication/cards/unity_nllb-200.yaml

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: unity_nllb-200
-model_type: unity
+model_family: unity
 tokenizer: "https://huggingface.co/facebook/seamless-m4t-medium/resolve/main/tokenizer.model"
 default_lang: eng
 langs:

+ 1 - 1
src/seamless_communication/cards/vocoder_36langs.yaml

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: vocoder_36langs
-model_type: vocoder_code_hifigan
+model_family: vocoder_code_hifigan
 model_arch: base
 checkpoint: "https://huggingface.co/facebook/seamless-m4t-vocoder/resolve/main/vocoder_36langs.pt"
 model_config: {

+ 1 - 1
src/seamless_communication/cards/vocoder_pretssel.yaml

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: vocoder_pretssel
-model_type: vocoder_pretssel
+model_family: vocoder_pretssel
 model_arch: 24khz
 checkpoint: "https://github.com/facebookresearch/seamless_communication;gated=true"
 sample_rate: 24000

+ 1 - 1
src/seamless_communication/cards/vocoder_pretssel_16khz.yaml

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: vocoder_pretssel_16khz
-model_type: vocoder_pretssel
+model_family: vocoder_pretssel
 model_arch: 16khz
 checkpoint: "https://github.com/facebookresearch/seamless_communication;gated=true"
 sample_rate: 16000

+ 1 - 1
src/seamless_communication/cards/vocoder_v2.yaml

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: vocoder_v2
-model_type: vocoder_code_hifigan
+model_family: vocoder_code_hifigan
 model_arch: base
 checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/vocoder_v2.pt"
 model_config: {

+ 1 - 1
src/seamless_communication/cards/xlsr2_1b_v2.yaml

@@ -5,6 +5,6 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: xlsr2_1b_v2
-model_type: wav2vec2
+model_family: wav2vec2
 model_arch: xlsr2_1b_v2
 checkpoint: "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/xlsr2_1b_v2.pt"

+ 1 - 1
src/seamless_communication/cli/m4t/audio_to_units/audio_to_units.py

@@ -14,7 +14,7 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description="Convert raw audio to units (and optionally audio) using UnitExtractor."
     )

+ 3 - 4
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -19,7 +19,6 @@ import torchaudio
 from fairseq2.data import Collater, DataPipeline, FileMapper
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
-from fairseq2.data.typing import StringLike
 from fairseq2.typing import DataType, Device
 from torch import Tensor
 from tqdm import tqdm
@@ -181,10 +180,10 @@ def build_data_pipeline(
 
 def adjust_output_for_corrupted_inputs(
     valid_sequences: Tensor,
-    text_output: List[StringLike],
+    text_output: List[str],
     speech_output: Optional[BatchedSpeechOutput],
-) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
-    adjusted_text_output: List[StringLike] = []
+) -> Tuple[List[str], Optional[BatchedSpeechOutput]]:
+    adjusted_text_output: List[str] = []
     adjusted_speech_output: Optional[BatchedSpeechOutput] = None
 
     if speech_output is not None:

+ 1 - 1
src/seamless_communication/cli/m4t/finetune/dataloader.py

@@ -83,7 +83,7 @@ class BatchingConfig:
     """Select between fp16/fp32 for float tensors """
 
 
-def worker_init_fn(worker_id):
+def worker_init_fn(worker_id) -> None:
     np.random.seed(np.random.get_state()[1][0] + worker_id)
 
 

+ 8 - 8
src/seamless_communication/cli/m4t/finetune/trainer.py

@@ -116,12 +116,12 @@ class UnitYFinetuneWrapper(nn.Module):
                 unit_encoder_out,
                 unit_encoder_padding_mask,
             ) = self.model.t2u_model.encode(
-                text_decoder_output=text_decoder_out,
-                text_decoder_padding_mask=text_decoder_padding_mask,
+                text_decoder_out,
+                text_decoder_padding_mask,
             )
             seqs = batch.text_to_units.prev_output_tokens.to(self.device)
             seq_lens = batch.text_to_units.target_lengths.to(self.device)
-            unit_decoder_out, _ = self.model.t2u_model.decode(
+            unit_decoder_out = self.model.t2u_model.decode(
                 seqs=seqs,
                 padding_mask=PaddingMask(seq_lens, seqs.size(1)),
                 encoder_output=unit_encoder_out,
@@ -156,7 +156,7 @@ class CalcLoss:
             text_logits.device
         )
         s2t_loss = SequenceModelOutput(
-            logits=text_logits, vocab_info=self.s2t_vocab_info
+            text_logits, self.s2t_vocab_info.pad_idx
         ).compute_loss(
             targets=batch.speech_to_text.target_tokens.to(text_logits.device),
             ignore_prefix_size=1,
@@ -167,7 +167,7 @@ class CalcLoss:
         assert batch.text_to_units.target_lengths is not None
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
         s2u_loss = SequenceModelOutput(
-            logits=unit_logits, vocab_info=self.t2u_vocab_info
+            unit_logits, self.t2u_vocab_info.pad_idx
         ).compute_loss(
             targets=batch.text_to_units.target_tokens.to(unit_logits.device),
             ignore_prefix_size=1,
@@ -314,7 +314,7 @@ class UnitYFinetune:
         eval_loss = loss_hist.reduce()
         self._update_eval_stats(eval_loss)
 
-    def _train_step_log(self):
+    def _train_step_log(self) -> None:
         """Log train stats"""
         if (self.update_idx + 1) % self.params.log_steps == 0:
             avg_loss = self.train_loss_hist.reduce()
@@ -340,7 +340,7 @@ class UnitYFinetune:
         self.train_loss_hist.update(1, loss.item())
         self._train_step_log()
 
-    def _save_model(self):
+    def _save_model(self) -> None:
         logger.info("Saving model")
         if dist_utils.is_main_process():
             state_dict = {
@@ -351,7 +351,7 @@ class UnitYFinetune:
         if dist_utils.is_dist_initialized():
             dist.barrier()
 
-    def run(self):
+    def run(self) -> None:
         logger.info("Start finetuning")
         self._reset_stats()
         self._eval_model()

+ 1 - 1
src/seamless_communication/cli/toxicity/asr_etox.py

@@ -207,7 +207,7 @@ def build_data_pipeline(
 
     pipeline_builder = read_text(data_file, rtrim=True).skip(1).map(split_tsv)
 
-    map_file = FileMapper(root_dir=audio_root_dir, cached_fd_count=10)
+    map_file = FileMapper(root_dir=Path(audio_root_dir), cached_fd_count=10)
 
     pipeline_builder.map(
         map_file,

+ 5 - 3
src/seamless_communication/inference/generator.py

@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from typing import List, Optional, Tuple
 
 import torch
-from fairseq2.data import SequenceData, StringLike
+from fairseq2.data import SequenceData
 from fairseq2.data.text import TextTokenizer
 from fairseq2.generation import (
     BeamSearchSeq2SeqGenerator,
@@ -137,6 +137,7 @@ class UnitYGenerator:
             decoder_frontend=model.text_decoder_frontend,
             decoder=model.text_decoder,
             final_proj=model.final_proj,
+            max_target_seq_len=model.max_target_seq_len,
             target_vocab_info=model.target_vocab_info,
         )
 
@@ -169,6 +170,7 @@ class UnitYGenerator:
                 decoder_frontend=model.text_decoder_frontend,
                 decoder=model.text_decoder,
                 final_proj=model.final_proj,
+                max_target_seq_len=model.max_target_seq_len,
                 target_vocab_info=model.target_vocab_info,
             )
             generator = BeamSearchSeq2SeqGenerator(
@@ -234,7 +236,7 @@ class UnitYGenerator:
         ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
-    ) -> Tuple[List[StringLike], Optional[Tensor]]:
+    ) -> Tuple[List[str], Optional[Tensor]]:
         """
         :param source_seqs:
             The source sequences to use for generation. *Shape:* :math:`(N,S,*)`,
@@ -346,7 +348,7 @@ class UnitYGenerator:
             unit_seqs = t2u_model_output.logits.argmax(dim=2)
             # Apply the padding mask to the generated units.
             unit_seqs = apply_padding_mask(
-                unit_seqs, decoder_padding_mask, t2u_model_output.vocab_info.pad_idx
+                unit_seqs, decoder_padding_mask, t2u_model_output.pad_idx
             )
 
         # Convert to speech units.

+ 4 - 4
src/seamless_communication/inference/translator.py

@@ -13,7 +13,7 @@ import torch
 import torch.nn as nn
 from fairseq2.assets import asset_store
 from fairseq2.assets.card import AssetCard
-from fairseq2.data import Collater, SequenceData, StringLike
+from fairseq2.data import Collater, SequenceData
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.text import TextTokenizer
 from fairseq2.memory import MemoryBlock
@@ -169,7 +169,7 @@ class Translator(nn.Module):
         unit_generation_ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
-    ) -> Tuple[List[StringLike], Optional[Tensor]]:
+    ) -> Tuple[List[str], Optional[Tensor]]:
         # We disregard unit generations opts for the NAR T2U decoder.
         if output_modality != Modality.SPEECH or isinstance(
             model.t2u_model, UnitYNART2UModel
@@ -226,8 +226,8 @@ class Translator(nn.Module):
         unit_generation_ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
-        src_text: Optional[StringLike] = None,
-    ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
+        src_text: Optional[str] = None,
+    ) -> Tuple[List[str], Optional[BatchedSpeechOutput]]:
         """
         The main method used to perform inference on all tasks.
 

+ 4 - 5
src/seamless_communication/models/aligner/alignment_extractor.py

@@ -12,7 +12,6 @@ import torch
 import torch.nn as nn
 import torchaudio
 from fairseq2.typing import DataType, Device
-from fairseq2.data.typing import StringLike
 from torch import Tensor
 
 from seamless_communication.models.aligner.loader import load_unity2_alignment_model
@@ -82,7 +81,7 @@ class AlignmentExtractor(nn.Module):
             audio = audio.mean(0)
         assert (
             audio.ndim == 1
-        ), f"After channel averaging audio shape expected to be [Time] i.e. mono audio"
+        ), "After channel averaging audio shape expected to be [Time] i.e. mono audio"
         audio = audio.to(self.device, self.dtype)
 
         return audio
@@ -101,7 +100,7 @@ class AlignmentExtractor(nn.Module):
         text: str,
         plot: bool = False,
         add_trailing_silence: bool = False,
-    ) -> Tuple[Tensor, Tensor, List[StringLike]]:
+    ) -> Tuple[Tensor, Tensor, List[str]]:
         if isinstance(audio, Tensor) and not torch.is_floating_point(audio):
             # we got units as audio arg
             units = audio
@@ -137,11 +136,11 @@ class AlignmentExtractor(nn.Module):
 
         return alignment_durations, tokenized_text_ids, tokenized_text_tokens
 
-    def detokenize_text(self, tokenized_text_ids: Tensor) -> StringLike:
+    def detokenize_text(self, tokenized_text_ids: Tensor) -> str:
         return self.alignment_model.alignment_frontend.decode_text(tokenized_text_ids)
 
     def plot_alignment(
-        self, audio: Tensor, text_tokens: List[StringLike], durations: Tensor
+        self, audio: Tensor, text_tokens: List[str], durations: Tensor
     ) -> None:
         if not matplotlib_available:
             raise RuntimeError(

+ 8 - 7
src/seamless_communication/models/aligner/builder.py

@@ -10,9 +10,9 @@ from typing import Optional, Union
 import torch
 from fairseq2.assets.card import AssetCard
 from fairseq2.data.vocabulary_info import VocabularyInfo
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.nn.embedding import StandardEmbedding, init_scaled_embedding
-from fairseq2.typing import DataType, Device
+from fairseq2.typing import CPU, DataType, Device
 
 from seamless_communication.models.aligner.model import (
     UnitY2AlignmentEncoder,
@@ -56,7 +56,7 @@ class UnitY2AlignmentConfig:
     alignment_frontend_config: UnitY2AlignmentFrontendConfig
 
 
-aligner_archs = ArchitectureRegistry[UnitY2AlignmentConfig]("unity2_aligner")
+aligner_archs = ModelArchitectureRegistry[UnitY2AlignmentConfig]()
 
 aligner_arch = aligner_archs.decorator
 
@@ -90,14 +90,14 @@ def _aligner_nar_t2u() -> UnitY2AlignmentConfig:
 class UnitY2AlignmentBuilder:
     config: UnitY2AlignmentConfig
     device: Optional[Device]
-    dtype: DataType
+    dtype: Optional[DataType]
 
     def __init__(
         self,
         config: UnitY2AlignmentConfig,
         *,
         device: Optional[Device] = None,
-        dtype: DataType = torch.float32,
+        dtype: Optional[DataType] = torch.float32,
     ) -> None:
         """
         :param config:
@@ -155,7 +155,8 @@ class UnitY2AlignmentBuilder:
             dropout=cfg.dropout,
             temperature=cfg.temperature,
             reduction_factor=cfg.reduction_factor,
-            dtype=self.dtype,
+            device=self.device or CPU,
+            dtype=self.dtype or torch.float32,
         )
         alignment_encoder.training = training
 
@@ -165,7 +166,7 @@ class UnitY2AlignmentBuilder:
 def create_unity2_alignment_model(
     config: UnitY2AlignmentConfig,
     device: Optional[Device] = None,
-    dtype: DataType = torch.float32,
+    dtype: Optional[DataType] = torch.float32,
 ) -> UnitY2AlignmentModel:
     """Create a UnitY model.
 

+ 9 - 14
src/seamless_communication/models/aligner/loader.py

@@ -4,24 +4,23 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, List, Mapping
+from typing import Any, List, Dict
 
 import torch
-from fairseq2.assets import asset_store, download_manager
-from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models import setup_model_family
 
 from seamless_communication.models.aligner.builder import (
     UnitY2AlignmentConfig,
     aligner_archs,
     create_unity2_alignment_model,
 )
-from seamless_communication.models.aligner.model import UnitY2AlignmentModel
+from seamless_communication.models.aligner.model import UNITY2_ALIGNER_FAMILY
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
 
 
 def convert_unity2_aligner_checkpoint(
-    checkpoint: Mapping[str, Any], config: UnitY2AlignmentConfig
-) -> Mapping[str, Any]:
+    checkpoint: Dict[str, Any], config: UnitY2AlignmentConfig
+) -> Dict[str, Any]:
     if (
         "model" in checkpoint
         and "alignment_encoder.t_conv.1.weight" in checkpoint["model"]
@@ -74,15 +73,11 @@ def _get_char_index_mapping(config: UnitY2AlignmentConfig) -> List[int]:
     return model_to_dict_mapping
 
 
-load_unity2_alignment_config = ConfigLoader[UnitY2AlignmentConfig](
-    asset_store, aligner_archs
-)
-
-load_unity2_alignment_model = ModelLoader[UnitY2AlignmentModel, UnitY2AlignmentConfig](
-    asset_store,
-    download_manager,
-    load_unity2_alignment_config,
+load_unity2_alignment_model, load_unity2_alignment_config = setup_model_family(
+    UNITY2_ALIGNER_FAMILY,
+    UnitY2AlignmentConfig,
     create_unity2_alignment_model,
+    aligner_archs,
     convert_unity2_aligner_checkpoint,
     restrict_checkpoints=False,
 )

+ 29 - 10
src/seamless_communication/models/aligner/model.py

@@ -4,25 +4,27 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, List, Tuple, Union
+from typing import Any, Final, List, Tuple, Union
 
 import numpy as np
 import numpy.typing as npt
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from fairseq2.data import CString
+from fairseq2.models import Model
 from fairseq2.nn.embedding import StandardEmbedding
 from fairseq2.nn.padding import to_padding_mask
-from fairseq2.typing import DataType
+from fairseq2.typing import DataType, Device
 from torch import Tensor
 from torch.nn import Module
 
 from seamless_communication.models.unity.char_tokenizer import CharTokenizer
 from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer
 
+UNITY2_ALIGNER_FAMILY: Final = "unity2_aligner"
 
-class UnitY2AlignmentFrontend(Module):
+
+class UnitY2AlignmentFrontend(nn.Module):
     def __init__(
         self,
         embed_text: StandardEmbedding,
@@ -53,7 +55,7 @@ class UnitY2AlignmentFrontend(Module):
 
     def tokenize_text_to_tokens(
         self, text: str, add_trailing_silence: bool = False
-    ) -> List[Union[CString, str]]:
+    ) -> List[str]:
         tokenized = self.encode_text.encode_as_tokens(text)
         if add_trailing_silence:
             tokenized = tokenized + [tokenized[0]]
@@ -90,6 +92,7 @@ class UnitY2AlignmentEncoder(Module):
         dropout: float,
         temperature: float,
         reduction_factor: int,
+        device: Device,
         dtype: DataType,
     ):
         super().__init__()
@@ -101,7 +104,12 @@ class UnitY2AlignmentEncoder(Module):
             if i < text_layers - 1:
                 layers.append(
                     nn.Conv1d(
-                        embed_dim, embed_dim, kernel_size=3, padding=1, dtype=dtype
+                        embed_dim,
+                        embed_dim,
+                        kernel_size=3,
+                        padding=1,
+                        device=device,
+                        dtype=dtype,
                     )
                 )
                 layers.append(nn.ReLU())
@@ -109,7 +117,12 @@ class UnitY2AlignmentEncoder(Module):
             else:
                 layers.append(
                     nn.Conv1d(
-                        embed_dim, embed_dim, kernel_size=1, padding=0, dtype=dtype
+                        embed_dim,
+                        embed_dim,
+                        kernel_size=1,
+                        padding=0,
+                        device=device,
+                        dtype=dtype,
                     )
                 )
                 layers.append(nn.Dropout(p=dropout))
@@ -122,7 +135,12 @@ class UnitY2AlignmentEncoder(Module):
             if i < feat_layers - 1:
                 layers.append(
                     nn.Conv1d(
-                        input_dim, embed_dim, kernel_size=3, padding=1, dtype=dtype
+                        input_dim,
+                        embed_dim,
+                        kernel_size=3,
+                        padding=1,
+                        device=device,
+                        dtype=dtype,
                     )
                 )
                 layers.append(nn.ReLU())
@@ -135,6 +153,7 @@ class UnitY2AlignmentEncoder(Module):
                         kernel_size=1,
                         padding=0,
                         stride=reduction_factor,
+                        device=device,
                         dtype=dtype,
                     )
                 )
@@ -277,7 +296,7 @@ def viterbi_decode(
     return durations
 
 
-class UnitY2AlignmentModel(Module):
+class UnitY2AlignmentModel(Model):
     alignment_encoder: UnitY2AlignmentEncoder
     alignment_frontend: UnitY2AlignmentFrontend
 
@@ -286,7 +305,7 @@ class UnitY2AlignmentModel(Module):
         alignment_frontend: UnitY2AlignmentFrontend,
         alignment_encoder: UnitY2AlignmentEncoder,
     ):
-        super().__init__()
+        super().__init__(UNITY2_ALIGNER_FAMILY)
         self.alignment_frontend = alignment_frontend
         self.alignment_encoder = alignment_encoder
 

+ 34 - 19
src/seamless_communication/models/conformer_shaw/builder.py

@@ -4,13 +4,13 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from dataclasses import asdict, dataclass
-from typing import Optional
+from dataclasses import asdict, dataclass, field
+from typing import Final, Optional
 
 from fairseq2.models.conformer import ConformerConvolution
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.models.w2vbert import w2vbert_archs
-from fairseq2.models.wav2vec2.builder import (
+from fairseq2.models.wav2vec2 import (
     Wav2Vec2Builder,
     Wav2Vec2Config,
     Wav2Vec2EncoderBuilder,
@@ -21,15 +21,17 @@ from fairseq2.models.wav2vec2.model import Wav2Vec2Model
 from fairseq2.nn.transformer import SDPA, ShawRelativePositionSDPA, create_default_sdpa
 from fairseq2.typing import DataType, Device
 
+CONFORMER_SHAW_FAMILY: Final = "conformer_shaw"
+
 
 @dataclass
 class ShawRelativePositionSDPAConfig:
     """Holds the configuration of the :class:ShawRelativePositionSDPA module."""
 
-    max_left_rel_pos: int
+    max_left_rel_pos: int = 64
     """The left clipping value for relative positions."""
 
-    max_right_rel_pos: Optional[int]
+    max_right_rel_pos: Optional[int] = 8
     """The right clipping value for relative positions."""
 
     use_rel_pos_values: bool = False
@@ -40,18 +42,23 @@ class ShawRelativePositionSDPAConfig:
 class ConformerShawEncoderConfig(Wav2Vec2EncoderConfig):
     """Holds the configuration of a conformer shaw encoder."""
 
-    shaw_rel_pos_sdpa_config: Optional[ShawRelativePositionSDPAConfig]
+    shaw_rel_pos_sdpa_config: Optional[ShawRelativePositionSDPAConfig] = None
     """The parameters for ShawRelativePositionSDPA."""
 
 
-conformer_shaw_archs = ArchitectureRegistry[ConformerShawEncoderConfig](
-    "conformer_shaw"
-)
+@dataclass
+class ConformerShawConfig(Wav2Vec2Config):
+    """Holds the configuration of a conformer shaw model."""
+
+    encoder_config: ConformerShawEncoderConfig = field(
+        default_factory=ConformerShawEncoderConfig
+    )
 
-conformer_shaw_arch = conformer_shaw_archs.decorator
 
+conformer_shaw_archs = ModelArchitectureRegistry[ConformerShawConfig]()
+
+conformer_shaw_arch = conformer_shaw_archs.decorator
 
-@conformer_shaw_arch("600m")
 def _conformer_shaw_600m_encoder() -> ConformerShawEncoderConfig:
     w2vbert_config = w2vbert_archs.get_config("600m")
     w2v2_encoder_config = w2vbert_config.w2v2_config.encoder_config
@@ -68,18 +75,20 @@ def _conformer_shaw_600m_encoder() -> ConformerShawEncoderConfig:
     return conformer_shaw_encoder_config
 
 
-@wav2vec2_arch("conformer_shaw_600m")
-def _conformer_shaw_600m() -> Wav2Vec2Config:
+@conformer_shaw_arch("conformer_shaw_600m")
+def _conformer_shaw_600m() -> ConformerShawConfig:
     encoder_config = _conformer_shaw_600m_encoder()
 
-    return Wav2Vec2Config(
+    return ConformerShawConfig(
         encoder_config,
         final_dim=768,
         final_proj_bias=True,
         temporal_mask_span_len=10,
         max_temporal_mask_prob=0.65,
+        min_num_temporal_mask_spans=2,
         spatial_mask_span_len=10,
         max_spatial_mask_prob=0.0,
+        min_num_spatial_mask_spans=2,
         quantized_dim=768,
         num_codebooks=2,
         num_codebook_entries=320,
@@ -101,6 +110,8 @@ class ConformerShawEncoderBuilder(Wav2Vec2EncoderBuilder):
     """
 
     config: ConformerShawEncoderConfig
+    device: Optional[Device]
+    dtype: Optional[DataType]
 
     def __init__(
         self,
@@ -119,11 +130,15 @@ class ConformerShawEncoderBuilder(Wav2Vec2EncoderBuilder):
         """
         super().__init__(config, device=device, dtype=dtype)
 
+        self.config = config
+
         assert self.config.use_conformer, "This architecture only supports a Conformer."
         assert (
             self.config.pos_encoder_type == "shaw_relative"
         ), "This architecture only supports ShawRelativePositionSDPA."
 
+        self.device, self.dtype = device, dtype
+
     def build_sdpa(self) -> SDPA:
         if self.config.shaw_rel_pos_sdpa_config is None:
             raise ValueError(
@@ -157,7 +172,7 @@ class ConformerShawEncoderBuilder(Wav2Vec2EncoderBuilder):
 
 
 def create_conformer_shaw_model(
-    config: Wav2Vec2Config,
+    config: ConformerShawConfig,
     *,
     device: Optional[Device] = None,
     dtype: Optional[DataType] = None,
@@ -171,12 +186,12 @@ def create_conformer_shaw_model(
     :param dtype:
         The data type of module parameters and buffers.
     """
-    assert isinstance(config.encoder_config, ConformerShawEncoderConfig)
-
     encoder_builder = ConformerShawEncoderBuilder(
         config.encoder_config, device=device, dtype=dtype
     )
 
-    builder = Wav2Vec2Builder(config, encoder_builder, device=device, dtype=dtype)
+    builder = Wav2Vec2Builder(
+        CONFORMER_SHAW_FAMILY, config, encoder_builder, device=device, dtype=dtype
+    )
 
     return builder.build_model()

+ 12 - 12
src/seamless_communication/models/conformer_shaw/loader.py

@@ -4,25 +4,25 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, Mapping
+from typing import Any, Dict
 
 import torch
 
-from fairseq2.assets import asset_store, download_manager
-from fairseq2.models.utils import ModelLoader
+from fairseq2.models import setup_model_family
 from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
-from fairseq2.models.wav2vec2.builder import Wav2Vec2Config
-from fairseq2.models.wav2vec2.loader import load_wav2vec2_config
-from fairseq2.models.wav2vec2.model import Wav2Vec2Model
+from fairseq2.models.wav2vec2 import Wav2Vec2Model
 
 from seamless_communication.models.conformer_shaw.builder import (
+    CONFORMER_SHAW_FAMILY,
+    ConformerShawConfig,
+    conformer_shaw_archs,
     create_conformer_shaw_model,
 )
 
 
 def convert_conformer_shaw_checkpoint(
-    checkpoint: Mapping[str, Any], config: Wav2Vec2Config
-) -> Mapping[str, Any]:
+    checkpoint: Dict[str, Any], config: ConformerShawConfig
+) -> Dict[str, Any]:
     """Convert a fairseq conformer shaw checkpoint to fairseq2."""
     state_dict = checkpoint["model"]
 
@@ -73,10 +73,10 @@ def convert_conformer_shaw_checkpoint(
     return convert_fairseq_checkpoint(checkpoint, key_map)
 
 
-load_conformer_shaw_model = ModelLoader[Wav2Vec2Model, Wav2Vec2Config](
-    asset_store,
-    download_manager,
-    load_wav2vec2_config,
+load_conformer_shaw_model, load_conformer_shaw_config = setup_model_family(
+    CONFORMER_SHAW_FAMILY,
+    ConformerShawConfig,
     create_conformer_shaw_model,
+    conformer_shaw_archs,
     convert_conformer_shaw_checkpoint,
 )

+ 2 - 3
src/seamless_communication/models/generator/builder.py

@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from typing import Any, Dict, List, Literal, Optional, Tuple
 
 from fairseq2.data import VocabularyInfo
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.nn.embedding import StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
 from fairseq2.nn.projection import Linear
@@ -110,8 +110,7 @@ class VocoderConfig:
     gcmvn_stats: Dict[str, List]  # type: ignore[type-arg]
 
 
-vocoder_archs = ArchitectureRegistry[VocoderConfig]("vocoder_pretssel")
-
+vocoder_archs = ModelArchitectureRegistry[VocoderConfig]()
 
 vocoder_arch = vocoder_archs.decorator
 

+ 2 - 2
src/seamless_communication/models/generator/ecapa_tdnn_builder.py

@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from typing import List, Optional
 
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.typing import DataType, Device
 
 from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN
@@ -27,7 +27,7 @@ class EcapaTDNNConfig:
     input_dim: int
 
 
-ecapa_tdnn_archs = ArchitectureRegistry[EcapaTDNNConfig]("ecapa_tdnn")
+ecapa_tdnn_archs = ModelArchitectureRegistry[EcapaTDNNConfig]()
 
 ecapa_tdnn_arch = ecapa_tdnn_archs.decorator
 

+ 6 - 12
src/seamless_communication/models/generator/loader.py

@@ -5,25 +5,19 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 
-from typing import Any, Mapping
-
-from fairseq2.assets import asset_store, download_manager
-from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models import setup_model_family
 
+from seamless_communication.models.generator.vocoder import PRETSSEL_VOCODER_FAMILY
 from seamless_communication.models.generator.builder import (
     VocoderConfig,
     create_vocoder_model,
     vocoder_archs,
 )
-from seamless_communication.models.generator.vocoder import PretsselVocoder
-
-load_pretssel_vocoder_config = ConfigLoader[VocoderConfig](asset_store, vocoder_archs)
-
 
-load_pretssel_vocoder_model = ModelLoader[PretsselVocoder, VocoderConfig](
-    asset_store,
-    download_manager,
-    load_pretssel_vocoder_config,
+load_pretssel_vocoder_model, load_pretssel_vocoder_config = setup_model_family(
+    PRETSSEL_VOCODER_FAMILY,
+    VocoderConfig,
     create_vocoder_model,
+    vocoder_archs,
     restrict_checkpoints=False,
 )

+ 1 - 1
src/seamless_communication/models/generator/streamable.py

@@ -6,7 +6,7 @@
 
 import math
 import warnings
-from typing import Any, Dict, List, Literal, Optional, Tuple, TypeVar
+from typing import Any, Dict, List, Literal, Optional, Tuple
 
 import torch
 from fairseq2.typing import DataType, Device

+ 7 - 3
src/seamless_communication/models/generator/vocoder.py

@@ -4,10 +4,11 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, Final, List, Literal, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
+from fairseq2.models import Model
 from fairseq2.nn.embedding import Embedding, StandardEmbedding
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.position_encoder import PositionEncoder
@@ -44,6 +45,9 @@ from .streamable import (
     StreamableResnetBlock,
 )
 
+
+PRETSSEL_VOCODER_FAMILY: Final = "vocoder_pretssel"
+
 ELU_PARAMS: Dict[str, Any] = {"alpha": 1.0}
 
 
@@ -162,7 +166,7 @@ class PretsselDecoderFrontend(Module):
         return seqs, padding_mask
 
 
-class PretsselVocoder(Module):
+class PretsselVocoder(Model):
     """The expressivity-preserving vocoder"""
 
     encoder_frontend: PretsselEncoderFrontend
@@ -212,7 +216,7 @@ class PretsselVocoder(Module):
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
     ):
-        super().__init__()
+        super().__init__(PRETSSEL_VOCODER_FAMILY)
         self.encoder_frontend = encoder_frontend
         self.encoder = encoder
         self.decoder_frontend = decoder_frontend

+ 2 - 4
src/seamless_communication/models/monotonic_decoder/builder.py

@@ -12,7 +12,7 @@ from fairseq2.models.transformer import (
     TransformerEmbeddingFrontend,
     TransformerFrontend,
 )
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
 from fairseq2.nn.projection import TiedProjection
@@ -77,9 +77,7 @@ class MonotonicDecoderConfig:
     in the PChooseLayer."""
 
 
-monotonic_decoder_archs = ArchitectureRegistry[MonotonicDecoderConfig](
-    "monotonic_decoder"
-)
+monotonic_decoder_archs = ModelArchitectureRegistry[MonotonicDecoderConfig]()
 
 monotonic_decoder_arch = monotonic_decoder_archs.decorator
 

+ 9 - 17
src/seamless_communication/models/monotonic_decoder/loader.py

@@ -4,11 +4,10 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, Mapping
+from typing import Any, Dict
 
 import torch
-from fairseq2.assets import asset_store, download_manager
-from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models import setup_model_family
 from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
 
 from seamless_communication.models.monotonic_decoder.builder import (
@@ -16,12 +15,12 @@ from seamless_communication.models.monotonic_decoder.builder import (
     create_monotonic_decoder_model,
     monotonic_decoder_archs,
 )
-from seamless_communication.models.monotonic_decoder.model import MonotonicDecoderModel
+from seamless_communication.models.monotonic_decoder.model import MONOTONIC_DECODER_FAMILY
 
 
 def convert_monotonic_checkpoint(
-    checkpoint: Mapping[str, Any], config: MonotonicDecoderConfig
-) -> Mapping[str, Any]:
+    checkpoint: Dict[str, Any], config: MonotonicDecoderConfig
+) -> Dict[str, Any]:
     state_dict = checkpoint["model"]
 
     # Check if we have a fairseq2 checkpoint.
@@ -75,18 +74,11 @@ def convert_monotonic_checkpoint(
     return checkpoint
 
 
-load_monotonic_decoder_config = ConfigLoader[MonotonicDecoderConfig](
-    asset_store, monotonic_decoder_archs
-)
-
-
-load_monotonic_decoder_model = ModelLoader[
-    MonotonicDecoderModel, MonotonicDecoderConfig
-](
-    asset_store,
-    download_manager,
-    load_monotonic_decoder_config,
+load_monotonic_decoder_model, load_monotonic_decoder_config = setup_model_family(
+    MONOTONIC_DECODER_FAMILY,
+    MonotonicDecoderConfig,
     create_monotonic_decoder_model,
+    monotonic_decoder_archs,
     convert_monotonic_checkpoint,
     restrict_checkpoints=False,
 )

+ 10 - 7
src/seamless_communication/models/monotonic_decoder/model.py

@@ -4,23 +4,26 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Optional, Tuple, final
+from typing import Final, Optional, Tuple, final
 
+from fairseq2.models import Model
 from fairseq2.models.transformer.frontend import TransformerFrontend
 from fairseq2.nn.incremental_state import IncrementalStateBag
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.projection import Projection
-from overrides import final as finaloverride
+from overrides import final as override
 from torch import Tensor
-from torch.nn import Module
 
 from seamless_communication.models.monotonic_decoder.monotonic_decoder import (
     MonotonicTransformerDecoder,
 )
 
 
+MONOTONIC_DECODER_FAMILY: Final = "monotonic_decoder"
+
+
 @final
-class MonotonicDecoderModel(Module):
+class MonotonicDecoderModel(Model):
     text_decoder_frontend: TransformerFrontend
     text_decoder: MonotonicTransformerDecoder
     final_proj: Projection
@@ -31,13 +34,13 @@ class MonotonicDecoderModel(Module):
         text_decoder: MonotonicTransformerDecoder,
         final_proj: Projection,
     ) -> None:
-        super().__init__()
+        super().__init__(MONOTONIC_DECODER_FAMILY)
 
         self.text_decoder_frontend = text_decoder_frontend
         self.text_decoder = text_decoder
         self.final_proj = final_proj
 
-    @finaloverride
+    @override
     def decode(
         self,
         seqs: Tensor,
@@ -59,7 +62,7 @@ class MonotonicDecoderModel(Module):
             state_bag=state_bag,
         )
 
-    @finaloverride
+    @override
     def project(self, decoder_output: Tensor) -> Tensor:
         logits = self.final_proj(decoder_output)
 

+ 2 - 2
src/seamless_communication/models/monotonic_decoder/monotonic_decoder.py

@@ -16,7 +16,7 @@ from fairseq2.nn.transformer import (
     CausalAttentionMaskFactory,
     create_standard_layer_norm,
 )
-from fairseq2.typing import DataType, Device, finaloverride
+from fairseq2.typing import DataType, Device, override
 from torch import Tensor
 from torch.nn import Module
 
@@ -62,7 +62,7 @@ class MonotonicTransformerDecoder(Module):
             self.model_dim, device=device, dtype=dtype
         )
 
-    @finaloverride
+    @override
     def forward(
         self,
         seqs: Tensor,

+ 2 - 2
src/seamless_communication/models/monotonic_decoder/monotonic_decoder_layer.py

@@ -15,7 +15,7 @@ from fairseq2.nn.transformer import (
     MultiheadAttention,
     create_standard_layer_norm,
 )
-from fairseq2.typing import DataType, Device, finaloverride
+from fairseq2.typing import DataType, Device, override
 from torch import Tensor
 from torch.nn import Dropout, Module
 
@@ -104,7 +104,7 @@ class MonotonicTransformerDecoderLayer(Module):
         else:
             self.register_module("ffn_dropout", None)
 
-    @finaloverride
+    @override
     def forward(
         self,
         seqs: Tensor,

+ 2 - 2
src/seamless_communication/models/monotonic_decoder/p_choose.py

@@ -8,7 +8,7 @@ from typing import Optional, final
 
 import torch
 from fairseq2.nn.projection import Linear
-from fairseq2.typing import DataType, Device, finaloverride
+from fairseq2.typing import DataType, Device, override
 from torch import Tensor
 from torch.nn import AvgPool1d, Module, ModuleList, ReLU
 from torch.nn.parameter import Parameter
@@ -116,7 +116,7 @@ class PChooseLayer(Module):
             ceil_mode=True,
         )
 
-    @finaloverride
+    @override
     def forward(self, seqs: Tensor, keys: Tensor) -> Tensor:
         q = self.q_energy_proj(seqs)
 

+ 2 - 2
src/seamless_communication/models/pretssel/ecapa_tdnn_builder.py

@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from typing import List, Optional
 
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.typing import DataType, Device
 
 from seamless_communication.models.pretssel.ecapa_tdnn import ECAPA_TDNN
@@ -27,7 +27,7 @@ class EcapaTDNNConfig:
     input_dim: int
 
 
-ecapa_tdnn_archs = ArchitectureRegistry[EcapaTDNNConfig]("ecapa_tdnn")
+ecapa_tdnn_archs = ModelArchitectureRegistry[EcapaTDNNConfig]()
 
 ecapa_tdnn_arch = ecapa_tdnn_archs.decorator
 

+ 12 - 29
src/seamless_communication/models/tokenizer.py

@@ -4,32 +4,26 @@
 # This source code is licensed under the BSD-style license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
+from pathlib import Path
 from typing import Optional, Sequence, Set, final
 
 from fairseq2.data.text import (
-    SentencePieceDecoder,
+    SentencePieceTokenizer,
     SentencePieceEncoder,
-    SentencePieceModel,
-    TextTokenDecoder,
-    TextTokenEncoder,
-    TextTokenizer,
-    vocab_info_from_sentencepiece,
 )
-from fairseq2.data.typing import PathLike
-from fairseq2.typing import Device, finaloverride
+from fairseq2.typing import Device, override
 
 
 @final
-class SPMTokenizer(TextTokenizer):
+class SPMTokenizer(SentencePieceTokenizer):
     """Represents standard SPM-based tokenizer used in MT tasks"""
 
-    model: SentencePieceModel
     langs: Set[str]
     prepend_target_langtok_to_target: bool
 
     def __init__(
         self,
-        pathname: PathLike,
+        path: Path,
         langs: Sequence[str],
         prepend_target_langtok_to_target: bool = True,
     ) -> None:
@@ -41,20 +35,19 @@ class SPMTokenizer(TextTokenizer):
         :param default_lang:
             The fall-back language if no language is specified.
         """
-        self.langs = set(langs)
-        self.prepend_target_langtok_to_target = prepend_target_langtok_to_target
-
         # Each language is represented by a `__lang__` control symbol.
         control_symbols = [self._lang_tok_to_internal(lang) for lang in sorted(langs)]
-        self.model = SentencePieceModel(pathname, control_symbols)
-        vocab_info = vocab_info_from_sentencepiece(self.model)
-        super().__init__(vocab_info)
+
+        super().__init__(path, control_symbols)
+
+        self.langs = set(langs)
+        self.prepend_target_langtok_to_target = prepend_target_langtok_to_target
 
     @classmethod
     def _lang_tok_to_internal(cls, lang: str) -> str:
         return f"__{lang}__"
 
-    @finaloverride
+    @override
     def create_encoder(
         self,
         *,
@@ -63,7 +56,7 @@ class SPMTokenizer(TextTokenizer):
         mode: Optional[str] = None,
         device: Optional[Device] = None,
         pin_memory: bool = False,
-    ) -> TextTokenEncoder:
+    ) -> SentencePieceEncoder:
         """Create a token encoder.
 
         :param task:
@@ -110,13 +103,3 @@ class SPMTokenizer(TextTokenizer):
             device=device,
             pin_memory=pin_memory,
         )
-
-    @finaloverride
-    def create_raw_encoder(
-        self, *, device: Optional[Device] = None, pin_memory: bool = False
-    ) -> TextTokenEncoder:
-        return SentencePieceEncoder(self.model, device=device, pin_memory=pin_memory)
-
-    @finaloverride
-    def create_decoder(self) -> TextTokenDecoder:
-        return SentencePieceDecoder(self.model)

+ 2 - 0
src/seamless_communication/models/unit_extractor/wav2vec2_layer_output.py

@@ -63,8 +63,10 @@ def _xlsr2_1b_v2() -> Wav2Vec2Config:
         final_proj_bias=True,
         temporal_mask_span_len=10,
         max_temporal_mask_prob=0.65,
+        min_num_temporal_mask_spans=2,
         spatial_mask_span_len=10,
         max_spatial_mask_prob=0.0,
+        min_num_spatial_mask_spans=2,
         quantized_dim=1024,
         num_codebooks=2,
         num_codebook_entries=320,

+ 62 - 28
src/seamless_communication/models/unity/builder.py

@@ -7,9 +7,10 @@
 from dataclasses import dataclass
 from typing import Optional, Union
 
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.conformer import ConformerBlock, ConformerConvolution
 from fairseq2.models.nllb import NllbBuilder, NllbConfig, nllb_archs
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.models.w2vbert import w2vbert_archs
 from fairseq2.models.wav2vec2 import Wav2Vec2EncoderBuilder, Wav2Vec2EncoderConfig
 from fairseq2.nn.projection import TiedProjection
@@ -36,7 +37,7 @@ from seamless_communication.models.unity.adaptor_block import (
     UnitYEncoderAdaptor,
     UnitYTransformerAdaptorLayer,
 )
-from seamless_communication.models.unity.model import UnitYModel
+from seamless_communication.models.unity.model import UNITY_FAMILY, UnitYModel
 from seamless_communication.models.unity.t2u_builder import (
     UnitYNART2UBuilder,
     UnitYT2UBuilder,
@@ -100,7 +101,7 @@ class UnitYConfig:
     """The dropout probability in Transformer layers of the adaptor block."""
 
 
-unity_archs = ArchitectureRegistry[UnitYConfig]("unity")
+unity_archs = ModelArchitectureRegistry[UnitYConfig]()
 
 unity_arch = unity_archs.decorator
 
@@ -111,7 +112,15 @@ def _base() -> UnitYConfig:
 
     mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
 
-    mt_model_config.vocab_info.size = 256102  # NLLB-100
+    vocab_info = mt_model_config.vocab_info
+
+    mt_model_config.vocab_info = VocabularyInfo(
+        size=256102,  # NLLB-100
+        unk_idx=vocab_info.unk_idx,
+        bos_idx=vocab_info.bos_idx,
+        eos_idx=vocab_info.eos_idx,
+        pad_idx=vocab_info.pad_idx,
+    )
 
     t2u_config = unity_t2u_archs.get_config("base")
 
@@ -139,7 +148,15 @@ def _medium() -> UnitYConfig:
 
     mt_model_config: NllbConfig = nllb_archs.get_config("dense_600m")
 
-    mt_model_config.vocab_info.size = 256206  # NLLB-200
+    vocab_info = mt_model_config.vocab_info
+
+    mt_model_config.vocab_info = VocabularyInfo(
+        size=256206,  # NLLB-200
+        unk_idx=vocab_info.unk_idx,
+        bos_idx=vocab_info.bos_idx,
+        eos_idx=vocab_info.eos_idx,
+        pad_idx=vocab_info.pad_idx,
+    )
 
     t2u_config = unity_t2u_archs.get_config("medium")
 
@@ -163,11 +180,19 @@ def _medium() -> UnitYConfig:
 
 @unity_arch("base_v2")
 def _base_v2() -> UnitYConfig:
-    conformer_shaw_encoder_config = conformer_shaw_archs.get_config("600m")
+    conformer_shaw_config = conformer_shaw_archs.get_config("conformer_shaw_600m")
 
     mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
 
-    mt_model_config.vocab_info.size = 256102  # NLLB-100
+    vocab_info = mt_model_config.vocab_info
+
+    mt_model_config.vocab_info = VocabularyInfo(
+        size=256102,  # NLLB-100
+        unk_idx=vocab_info.unk_idx,
+        bos_idx=vocab_info.bos_idx,
+        eos_idx=vocab_info.eos_idx,
+        pad_idx=vocab_info.pad_idx,
+    )
 
     mt_model_config.max_seq_len = 4096
 
@@ -175,7 +200,7 @@ def _base_v2() -> UnitYConfig:
 
     return UnitYConfig(
         model_dim=1024,
-        w2v2_encoder_config=conformer_shaw_encoder_config,
+        w2v2_encoder_config=conformer_shaw_config.encoder_config,
         mt_model_config=mt_model_config,
         t2u_config=t2u_config,
         prosody_encoder_config=None,
@@ -193,11 +218,19 @@ def _base_v2() -> UnitYConfig:
 
 @unity_arch("expressivity_v2")
 def _expressivity_v2() -> UnitYConfig:
-    conformer_shaw_encoder_config = conformer_shaw_archs.get_config("600m")
+    conformer_shaw_config = conformer_shaw_archs.get_config("conformer_shaw_600m")
 
     mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
 
-    mt_model_config.vocab_info.size = 256102  # NLLB-100
+    vocab_info = mt_model_config.vocab_info
+
+    mt_model_config.vocab_info = VocabularyInfo(
+        size=256102,  # NLLB-100
+        unk_idx=vocab_info.unk_idx,
+        bos_idx=vocab_info.bos_idx,
+        eos_idx=vocab_info.eos_idx,
+        pad_idx=vocab_info.pad_idx,
+    )
 
     mt_model_config.max_seq_len = 10000
 
@@ -207,7 +240,7 @@ def _expressivity_v2() -> UnitYConfig:
 
     return UnitYConfig(
         model_dim=1024,
-        w2v2_encoder_config=conformer_shaw_encoder_config,
+        w2v2_encoder_config=conformer_shaw_config.encoder_config,
         mt_model_config=mt_model_config,
         t2u_config=t2u_config,
         prosody_encoder_config=prosody_encoder_config,
@@ -263,19 +296,19 @@ class UnitYBuilder:
         :param dtype:
             The data type of module parameters and buffers.
         """
-        if w2v2_encoder_builder.config.model_dim != config.model_dim:
+        if config.w2v2_encoder_config.model_dim != config.model_dim:
             raise ValueError(
-                f"`model_dim` and `model_dim` of `w2v2_encoder_builder.config` must be equal, but are {config.model_dim} and {w2v2_encoder_builder.config.model_dim} instead."
+                f"`config.model_dim` and `config.w2v2_encoder_config.model_dim` must be equal, but are {config.model_dim} and {config.w2v2_encoder_config.model_dim} instead."
             )
 
-        if mt_model_builder.config.model_dim != config.model_dim:
+        if config.mt_model_config.model_dim != config.model_dim:
             raise ValueError(
-                f"`model_dim` and `model_dim` of `mt_model_builder.config` must be equal, but are {config.model_dim} and {mt_model_builder.config.model_dim} instead."
+                f"`config.model_dim` and `config.mt_model_config.model_dim` must be equal, but are {config.model_dim} and {config.mt_model_config.model_dim} instead."
             )
 
-        if t2u_builder is not None and t2u_builder.config.model_dim != config.model_dim:
+        if config.t2u_config is not None and config.t2u_config.model_dim != config.model_dim:
             raise ValueError(
-                f"`model_dim` and `model_dim` of `t2u_builder.config` must be equal, but are {config.model_dim} and {t2u_builder.config.model_dim} instead."
+                f"`config.model_dim` and `config.t2u_config.model_dim` must be equal, but are {config.model_dim} and {config.t2u_config.model_dim} instead."
             )
 
         self.config = config
@@ -337,6 +370,7 @@ class UnitYBuilder:
             text_decoder,
             final_proj,
             t2u_model,
+            self.config.mt_model_config.max_seq_len or 0,
             self.config.mt_model_config.vocab_info,
             prosody_encoder_model,
         )
@@ -367,12 +401,12 @@ class UnitYBuilder:
     def build_adaptor_layer(self, idx: int) -> TransformerEncoderLayer:
         """Build a Transformer-based encoder adaptor layer."""
         self_attn = self.build_adaptor_attention(
-            self.w2v2_encoder_builder.config.num_encoder_attn_heads
+            self.config.w2v2_encoder_config.num_encoder_attn_heads
         )
 
         ffn = StandardFeedForwardNetwork(
             self.config.model_dim,
-            self.w2v2_encoder_builder.config.ffn_inner_dim,
+            self.config.w2v2_encoder_config.ffn_inner_dim,
             inner_activation=GELU() if self.config.use_gelu else ReLU(),
             bias=True,
             device=self.device,
@@ -396,12 +430,12 @@ class UnitYBuilder:
         # Empirically shown that, in adaptor layers, vanilla MHA performs better
         # than MHA with relative positional encoding.
         self_attn = self.build_adaptor_attention(
-            self.w2v2_encoder_builder.config.num_encoder_attn_heads
+            self.config.w2v2_encoder_config.num_encoder_attn_heads
         )
 
         conv = ConformerConvolution(
-            self.w2v2_encoder_builder.config.model_dim,
-            self.w2v2_encoder_builder.config.depthwise_conv_kernel_size,
+            self.config.w2v2_encoder_config.model_dim,
+            self.config.w2v2_encoder_config.depthwise_conv_kernel_size,
             device=self.device,
             dtype=self.dtype,
         )
@@ -446,13 +480,13 @@ class NllbWithGELUBuilder(NllbBuilder):
     @override
     def build_ffn(self) -> FeedForwardNetwork:
         return StandardFeedForwardNetwork(
-            self.config.model_dim,
-            self.config.ffn_inner_dim,
+            self._config.model_dim,
+            self._config.ffn_inner_dim,
             bias=True,
             inner_activation=GELU(),
             norm_order=TransformerNormOrder.PRE,
-            device=self.device,
-            dtype=self.dtype,
+            device=self._device,
+            dtype=self._dtype,
         )
 
 
@@ -497,11 +531,11 @@ def create_unity_model(
 
     if config.use_gelu:
         mt_model_builder: NllbBuilder = NllbWithGELUBuilder(
-            config.mt_model_config, device=device, dtype=dtype
+            UNITY_FAMILY, config.mt_model_config, device=device, dtype=dtype
         )
     else:
         mt_model_builder = NllbBuilder(
-            config.mt_model_config, device=device, dtype=dtype
+            UNITY_FAMILY, config.mt_model_config, device=device, dtype=dtype
         )
 
     unity_builder = UnitYBuilder(
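
Note: the field-by-field `VocabularyInfo` rebuilds above suggest the type became immutable; assuming it is a (frozen) dataclass, `dataclasses.replace` is a compact equivalent. The index values below are placeholders, not NLLB's real special-token indices:

from dataclasses import replace

from fairseq2.data import VocabularyInfo

vocab_info = VocabularyInfo(
    size=256206, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1  # placeholder indices
)
# Override only `size`, keeping the special token indices intact.
vocab_info = replace(vocab_info, size=256102)  # NLLB-100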

+ 8 - 33
src/seamless_communication/models/unity/char_tokenizer.py

@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
+from pathlib import Path
 from typing import Optional, Union, final
 
 from fairseq2.assets import (
@@ -14,36 +15,24 @@ from fairseq2.assets import (
 )
 from fairseq2.assets.card import AssetCard
 from fairseq2.data.text import (
-    SentencePieceDecoder,
+    SentencePieceTokenizer,
     SentencePieceEncoder,
-    SentencePieceModel,
-    TextTokenDecoder,
-    TextTokenEncoder,
-    TextTokenizer,
-    vocab_info_from_sentencepiece,
 )
-from fairseq2.data.typing import PathLike
-from fairseq2.typing import Device, finaloverride
+from fairseq2.typing import Device, override
 
 
 @final
-class CharTokenizer(TextTokenizer):
+class CharTokenizer(SentencePieceTokenizer):
     """A character-level tokenizer used during non-autoregressive T2U decoding."""
 
-    model: SentencePieceModel
-
-    def __init__(self, pathname: PathLike) -> None:
+    def __init__(self, path: Path) -> None:
         """
         :param pathname:
             The pathname of the SentencePiece model file.
         """
-        self.model = SentencePieceModel(pathname)
-
-        vocab_info = vocab_info_from_sentencepiece(self.model)
-
-        super().__init__(vocab_info)
+        super().__init__(path)
 
-    @finaloverride
+    @override
     def create_encoder(
         self,
         task: Optional[str] = None,
@@ -51,24 +40,10 @@ class CharTokenizer(TextTokenizer):
         mode: Optional[str] = None,
         device: Optional[Device] = None,
         pin_memory: bool = False,
-    ) -> TextTokenEncoder:
+    ) -> SentencePieceEncoder:
         """Creates a character level encoder."""
-        return SentencePieceEncoder(
-            self.model,
-            device=device,
-            pin_memory=pin_memory,
-        )
-
-    @finaloverride
-    def create_raw_encoder(
-        self, *, device: Optional[Device] = None, pin_memory: bool = False
-    ) -> TextTokenEncoder:
         return SentencePieceEncoder(self.model, device=device, pin_memory=pin_memory)
 
-    @finaloverride
-    def create_decoder(self) -> TextTokenDecoder:
-        return SentencePieceDecoder(self.model)
-
 
 class UnitYCharTokenizerLoader:
     """Loads character-level tokenizers of UnitY models."""

+ 2 - 2
src/seamless_communication/models/unity/fft_decoder.py

@@ -10,7 +10,7 @@ from fairseq2.nn.module_list import ModuleList
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.transformer import TransformerNormOrder, create_standard_layer_norm
-from fairseq2.typing import DataType, Device, finaloverride
+from fairseq2.typing import DataType, Device, override
 from torch import Tensor
 from torch.nn import Module
 
@@ -61,7 +61,7 @@ class FeedForwardTransformer(Module):
 
         self.norm_order = norm_order
 
-    @finaloverride
+    @override
     def forward(
         self,
         seqs: Tensor,

+ 3 - 3
src/seamless_communication/models/unity/fft_decoder_layer.py

@@ -9,7 +9,7 @@ from typing import Optional, Tuple, final
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
 from fairseq2.nn.transformer import MultiheadAttention, create_standard_layer_norm
-from fairseq2.typing import DataType, Device, finaloverride
+from fairseq2.typing import DataType, Device, override
 from torch import Tensor
 from torch.nn import Conv1d, Dropout, Module, ReLU
 
@@ -71,7 +71,7 @@ class Conv1dBlock(Module):
             dtype=dtype,
         )
 
-    @finaloverride
+    @override
     def forward(self, seqs: Tensor, padding_mask: Optional[PaddingMask]) -> Tensor:
         # Ensure that we do not leak padded positions in the convolution layer.
         seqs = apply_padding_mask(seqs, padding_mask)
@@ -173,7 +173,7 @@ class FeedForwardTransformerLayer(Module):
         else:
             self.register_module("film", None)
 
-    @finaloverride
+    @override
     def forward(
         self,
         seqs: Tensor,

+ 78 - 76
src/seamless_communication/models/unity/loader.py

@@ -4,14 +4,14 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, List, Mapping, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
-from fairseq2.assets import AssetStore, asset_store, download_manager
+from fairseq2.assets import AssetStore, asset_store
 from fairseq2.assets.card import AssetCard, AssetCardFieldNotFoundError
-from fairseq2.models.nllb import NllbConfig
-from fairseq2.models.nllb.loader import NllbTokenizerLoader
-from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.data.text import register_text_tokenizer
+from fairseq2.models import setup_model_family
+from fairseq2.models.nllb import NllbConfig, load_nllb_tokenizer
 from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
 
 from seamless_communication.models.unity.builder import (
@@ -20,13 +20,13 @@ from seamless_communication.models.unity.builder import (
     unity_archs,
 )
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
-from seamless_communication.models.unity.model import UnitYModel
+from seamless_communication.models.unity.model import UNITY_FAMILY
 from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer
 
 
 def convert_unity_checkpoint(
-    checkpoint: Mapping[str, Any], config: UnitYConfig
-) -> Mapping[str, Any]:
+    checkpoint: Dict[str, Any], config: UnitYConfig
+) -> Dict[str, Any]:
     state_dict = checkpoint["model"]
 
     # Check if we have a fairseq2 checkpoint.
@@ -39,7 +39,11 @@ def convert_unity_checkpoint(
 
     state_dict = checkpoint["model"]
 
-    keys_to_delete = []
+    keys_to_delete = [
+        "speech_encoder_frontend.pos_encoder.conv.bias",
+        "speech_encoder_frontend.pos_encoder.conv.weight_g",
+        "speech_encoder_frontend.pos_encoder.conv.weight_v",
+    ]
 
     # ExpressiveUnitY model (from multi_arch codebase)
     if config.prosody_encoder_config is not None:
@@ -203,42 +207,42 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
         # fmt: off
 
         # Speech Encoder
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.pos_conv\.0\.":                                    r"speech_encoder_frontend.pos_encoder.conv.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.layer_norm\.":                                              r"speech_encoder_frontend.post_extract_layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.post_extract_proj\.":                                       r"speech_encoder_frontend.model_dim_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.0\.":             r"speech_encoder_frontend.feature_extractor.layers.\1.conv.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.":          r"speech_encoder_frontend.feature_extractor.layers.\1.layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.":                    r"speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
-
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.inner.layers.\1.conv.batch_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.":     r"speech_encoder.inner.layers.\1.conv.layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.inner.layers.\1.conv.depthwise_conv.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.inner.layers.\1.conv_layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv1.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv2.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         r"speech_encoder.inner.layers.\1.ffn\2_layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                r"speech_encoder.inner.layers.\1.ffn\2.inner_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                r"speech_encoder.inner.layers.\1.ffn\2.output_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":         r"speech_encoder.inner.layers.\1.self_attn_layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.":          r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.":          r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.":          r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.":        r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.":   r"speech_encoder.inner.layers.\1.self_attn.sdpa.rel_k_embed.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":          r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.":        r"speech_encoder.inner.layers.\1.self_attn.sdpa.r_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.u_bias",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.v_bias",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.final_layer_norm\.":             r"speech_encoder.inner.layers.\1.layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.":                                     r"speech_encoder.inner.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.pos_conv\.0\.":                                    "speech_encoder_frontend.pos_encoder.conv.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.layer_norm\.":                                              "speech_encoder_frontend.post_extract_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.post_extract_proj\.":                                       "speech_encoder_frontend.model_dim_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.0\.":             "speech_encoder_frontend.feature_extractor.layers.\\1.conv.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.":          "speech_encoder_frontend.feature_extractor.layers.\\1.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.":                    "speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
+
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      "speech_encoder.inner.layers.\\1.conv.batch_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.":     "speech_encoder.inner.layers.\\1.conv.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  "speech_encoder.inner.layers.\\1.conv.depthwise_conv.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      "speech_encoder.inner.layers.\\1.conv_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": "speech_encoder.inner.layers.\\1.conv.pointwise_conv1.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": "speech_encoder.inner.layers.\\1.conv.pointwise_conv2.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         "speech_encoder.inner.layers.\\1.ffn\\2_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                "speech_encoder.inner.layers.\\1.ffn\\2.inner_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                "speech_encoder.inner.layers.\\1.ffn\\2.output_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":         "speech_encoder.inner.layers.\\1.self_attn_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.":          "speech_encoder.inner.layers.\\1.self_attn.q_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.":          "speech_encoder.inner.layers.\\1.self_attn.k_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.":          "speech_encoder.inner.layers.\\1.self_attn.v_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.":        "speech_encoder.inner.layers.\\1.self_attn.output_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.":            "speech_encoder.inner.layers.\\1.self_attn.q_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.":            "speech_encoder.inner.layers.\\1.self_attn.k_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.":            "speech_encoder.inner.layers.\\1.self_attn.v_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.":   "speech_encoder.inner.layers.\\1.self_attn.sdpa.rel_k_embed.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":          "speech_encoder.inner.layers.\\1.self_attn.output_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.":        "speech_encoder.inner.layers.\\1.self_attn.sdpa.r_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u":          "speech_encoder.inner.layers.\\1.self_attn.sdpa.u_bias",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v":          "speech_encoder.inner.layers.\\1.self_attn.sdpa.v_bias",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.final_layer_norm\.":             "speech_encoder.inner.layers.\\1.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.":                                     "speech_encoder.inner.layer_norm.",
 
         # Speech Encoder Adaptor
-        fr"^{encoder_key}\.adaptor\.proj\.0\.": r"speech_encoder.proj1.",
-        fr"^{encoder_key}\.adaptor\.proj\.2\.": r"speech_encoder.proj2.",
-        fr"^{encoder_key}\.adaptor\.out_ln\.":  r"speech_encoder.layer_norm.",
+        fr"^{encoder_key}\.adaptor\.proj\.0\.": "speech_encoder.proj1.",
+        fr"^{encoder_key}\.adaptor\.proj\.2\.": "speech_encoder.proj2.",
+        fr"^{encoder_key}\.adaptor\.out_ln\.":  "speech_encoder.layer_norm.",
 
         # Text Encoder
         r"^text_encoder\.embed_tokens\.":                              r"text_encoder_frontend.embed.",
@@ -264,13 +268,13 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
     if config.w2v2_encoder_config.use_conformer:
         key_map.update(
             {
-                fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner_layer_norm."
+                fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": "speech_encoder.inner_layer_norm."
             }
         )
     else:
         key_map.update(
             {
-                rf"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm."
+                rf"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": "speech_encoder.inner.layer_norm."
             }
         )
     # fmt: on
@@ -279,20 +283,20 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
         key_map.update(
             {
                 # fmt: off
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":          r"speech_encoder.adaptor_layers.\1.block.self_attn.output_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":                    r"speech_encoder.adaptor_layers.\1.block.self_attn.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.":         r"speech_encoder.adaptor_layers.\1.block.self_attn_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         r"speech_encoder.adaptor_layers.\1.block.ffn\2_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                r"speech_encoder.adaptor_layers.\1.block.ffn\2.inner_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                r"speech_encoder.adaptor_layers.\1.block.ffn\2.output_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.adaptor_layers.\1.block.conv.batch_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.adaptor_layers.\1.block.conv.depthwise_conv.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.adaptor_layers.\1.block.conv_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv1.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv2.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":             r"speech_encoder.adaptor_layers.\1.block.layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_ln\.":                      r"speech_encoder.adaptor_layers.\1.layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_pool\.1\.":                 r"speech_encoder.adaptor_layers.\1.conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":          "speech_encoder.adaptor_layers.\\1.block.self_attn.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":                    "speech_encoder.adaptor_layers.\\1.block.self_attn.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.":         "speech_encoder.adaptor_layers.\\1.block.self_attn_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         "speech_encoder.adaptor_layers.\\1.block.ffn\\2_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                "speech_encoder.adaptor_layers.\\1.block.ffn\\2.inner_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                "speech_encoder.adaptor_layers.\\1.block.ffn\\2.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      "speech_encoder.adaptor_layers.\\1.block.conv.batch_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  "speech_encoder.adaptor_layers.\\1.block.conv.depthwise_conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      "speech_encoder.adaptor_layers.\\1.block.conv_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": "speech_encoder.adaptor_layers.\\1.block.conv.pointwise_conv1.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": "speech_encoder.adaptor_layers.\\1.block.conv.pointwise_conv2.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":             "speech_encoder.adaptor_layers.\\1.block.layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_ln\.":                      "speech_encoder.adaptor_layers.\\1.layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_pool\.1\.":                 "speech_encoder.adaptor_layers.\\1.conv.",
                 # fmt: on
             }
         )
@@ -300,15 +304,15 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
         key_map.update(
             {
                 # fmt: off
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_layer_norm\.":  r"speech_encoder.adaptor_layers.\1.residual_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_pool\.1\.":     r"speech_encoder.adaptor_layers.\1.residual_conv.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.attn_pool\.1\.":         r"speech_encoder.adaptor_layers.\1.self_attn_conv.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":  r"speech_encoder.adaptor_layers.\1.self_attn.output_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":            r"speech_encoder.adaptor_layers.\1.self_attn.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.adaptor_layers.\1.self_attn_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc1\.":                  r"speech_encoder.adaptor_layers.\1.ffn.inner_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc2\.":                  r"speech_encoder.adaptor_layers.\1.ffn.output_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":     r"speech_encoder.adaptor_layers.\1.ffn_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_layer_norm\.":  "speech_encoder.adaptor_layers.\\1.residual_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_pool\.1\.":     "speech_encoder.adaptor_layers.\\1.residual_conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.attn_pool\.1\.":         "speech_encoder.adaptor_layers.\\1.self_attn_conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":  "speech_encoder.adaptor_layers.\\1.self_attn.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":            "speech_encoder.adaptor_layers.\\1.self_attn.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": "speech_encoder.adaptor_layers.\\1.self_attn_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc1\.":                  "speech_encoder.adaptor_layers.\\1.ffn.inner_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc2\.":                  "speech_encoder.adaptor_layers.\\1.ffn.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":     "speech_encoder.adaptor_layers.\\1.ffn_layer_norm.",
                 # fmt: on
             }
         )
@@ -389,20 +393,18 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
     return key_map
 
 
-load_unity_config = ConfigLoader[UnitYConfig](asset_store, unity_archs)
-
-
-load_unity_model = ModelLoader[UnitYModel, UnitYConfig](
-    asset_store,
-    download_manager,
-    load_unity_config,
+load_unity_model, load_unity_config = setup_model_family(
+    UNITY_FAMILY,
+    UnitYConfig,
     create_unity_model,
+    unity_archs,
     convert_unity_checkpoint,
     restrict_checkpoints=False,
 )
 
+load_unity_text_tokenizer = load_nllb_tokenizer
 
-load_unity_text_tokenizer = NllbTokenizerLoader(asset_store, download_manager)
+register_text_tokenizer(UNITY_FAMILY, load_unity_text_tokenizer)
 
 
 class UnitYUnitTokenizerLoader:
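
The key-map rewrite above is purely notational: as a replacement string, the raw literal r"...\1..." and the escaped literal "...\\1..." are the same characters, so re.sub behaves identically. A quick self-contained check with a hypothetical checkpoint key:

    import re

    assert r"layers.\1.conv." == "layers.\\1.conv."  # identical strings

    old_key = "encoder.w2v_encoder.w2v_model.encoder.layers.3.conv_module.batch_norm.weight"
    new_key = re.sub(
        r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.",
        "speech_encoder.inner.layers.\\1.conv.batch_norm.",
        old_key,
    )
    assert new_key == "speech_encoder.inner.layers.3.conv.batch_norm.weight"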

+ 13 - 8
src/seamless_communication/models/unity/model.py

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union, final
+from typing import Final, Optional, Tuple, Union, final
 
 from fairseq2.data import VocabularyInfo
 from fairseq2.models.encoder_decoder import EncoderDecoderModel
@@ -23,6 +23,8 @@ from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN
 from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
 from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
 
+UNITY_FAMILY: Final = "unity"
+
 
 @final
 class UnitYModel(EncoderDecoderModel):
@@ -55,13 +57,14 @@ class UnitYModel(EncoderDecoderModel):
         text_decoder: Optional[TransformerDecoder],
         final_proj: Optional[Projection],
         t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None],
+        max_target_seq_len: int,
         target_vocab_info: VocabularyInfo,
         prosody_encoder_model: Optional[ECAPA_TDNN] = None,
         input_modality: str = "speech",
     ) -> None:
         model_dim = speech_encoder.model_dim
 
-        super().__init__(model_dim, target_vocab_info)
+        super().__init__(UNITY_FAMILY, model_dim, max_target_seq_len, target_vocab_info)
 
         self.input_modality = input_modality
 
@@ -190,7 +193,7 @@ class UnitYModel(EncoderDecoderModel):
 
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.target_vocab_info)
+        return SequenceModelOutput(logits, self.target_vocab_info.pad_idx)
 
 
 @final
@@ -209,11 +212,12 @@ class UnitYX2TModel(EncoderDecoderModel):
         decoder_frontend: TransformerFrontend,
         decoder: TransformerDecoder,
         final_proj: Projection,
+        max_target_seq_len: int,
         target_vocab_info: VocabularyInfo,
     ) -> None:
         model_dim = encoder.model_dim
 
-        super().__init__(model_dim, target_vocab_info)
+        super().__init__(UNITY_FAMILY, model_dim, max_target_seq_len, target_vocab_info)
 
         self.encoder_frontend = encoder_frontend
         self.encoder = encoder
@@ -257,7 +261,7 @@ class UnitYX2TModel(EncoderDecoderModel):
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.target_vocab_info)
+        return SequenceModelOutput(logits, self.target_vocab_info.pad_idx)
 
 
 @final
@@ -276,9 +280,10 @@ class UnitYT2UModel(EncoderDecoderModel):
         decoder_frontend: TransformerFrontend,
         decoder: TransformerDecoder,
         final_proj: Projection,
+        max_target_seq_len: int,
         target_vocab_info: VocabularyInfo,
     ) -> None:
-        super().__init__(decoder.model_dim, target_vocab_info)
+        super().__init__(UNITY_FAMILY, decoder.model_dim, max_target_seq_len, target_vocab_info)
 
         if encoder is not None:
             self.encoder = encoder
@@ -324,7 +329,7 @@ class UnitYT2UModel(EncoderDecoderModel):
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.target_vocab_info)
+        return SequenceModelOutput(logits, self.target_vocab_info.pad_idx)
 
 
 @final
@@ -438,7 +443,7 @@ class UnitYNART2UModel(Module):
     def project(self, decoder_output: Tensor) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.target_vocab_info)
+        return SequenceModelOutput(logits, self.target_vocab_info.pad_idx)
 
 
 @dataclass
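
The same two changes repeat across every variant above: the superclass now receives the family name plus max_target_seq_len, and SequenceModelOutput takes the pad index rather than the whole VocabularyInfo. A hedged sketch of the shared pattern, using a hypothetical subclass with components elided and the imports already present in this module:

    class MyS2TModel(EncoderDecoderModel):  # hypothetical subclass
        def __init__(self, encoder, decoder, final_proj, max_target_seq_len, target_vocab_info):
            super().__init__(UNITY_FAMILY, encoder.model_dim, max_target_seq_len, target_vocab_info)
            self.final_proj = final_proj
            # ... remaining components elided ...

        def project(self, decoder_output):
            logits = self.final_proj(decoder_output)
            # The pad index, not the whole VocabularyInfo, under the new signature.
            return SequenceModelOutput(logits, self.target_vocab_info.pad_idx)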

+ 2 - 2
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -15,7 +15,7 @@ from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.position_encoder import PositionEncoder
 from fairseq2.nn.transformer import create_standard_layer_norm
-from fairseq2.typing import DataType, Device, finaloverride
+from fairseq2.typing import DataType, Device, override
 from torch import Tensor
 from torch.nn import Dropout, Module, Parameter
 
@@ -296,7 +296,7 @@ class NARDecoderFrontend(Module):
 
         return seqs
 
-    @finaloverride
+    @override
     def forward(
         self,
         encoder_output: Tensor,

+ 5 - 6
src/seamless_communication/models/unity/t2u_builder.py

@@ -6,15 +6,14 @@
 from dataclasses import dataclass
 from typing import Literal, Optional, Union
 
-from fairseq2.assets import asset_store, download_manager
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import VocabularyInfo
-from fairseq2.models.nllb.loader import NllbTokenizerLoader
+from fairseq2.models.nllb import load_nllb_tokenizer
 from fairseq2.models.transformer import (
     TransformerEmbeddingFrontend,
     TransformerFrontend,
 )
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
 from fairseq2.nn.projection import Linear, Projection, TiedProjection
@@ -131,8 +130,7 @@ class UnitYT2UConfig:
     """The dimensionality of prosody encoder (e.g. ECAPA_TDNN) output"""
 
 
-unity_t2u_archs = ArchitectureRegistry[UnitYT2UConfig]("unity_t2u")
-
+unity_t2u_archs = ModelArchitectureRegistry[UnitYT2UConfig]()
 
 unity_t2u_arch = unity_t2u_archs.decorator
 
@@ -329,6 +327,7 @@ class UnitYT2UBuilder:
             decoder_frontend,
             decoder,
             final_proj,
+            self.config.unit_max_seq_len,
             self.config.target_vocab_info,
         )
 
@@ -598,7 +597,7 @@ class UnitYNART2UBuilder:
             self.config.nar_decoder_frontend_config
         )
 
-        nllb_tokenizer = NllbTokenizerLoader(asset_store, download_manager)(
+        nllb_tokenizer = load_nllb_tokenizer(
             self.config.nar_decoder_config.model_name_or_card
         )
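
The tokenizer loader is now a plain function taking a single name-or-card argument, per the call site above. A hedged one-liner; the card name is hypothetical:

    from fairseq2.models.nllb import load_nllb_tokenizer

    tokenizer = load_nllb_tokenizer("nllb-200_dense_distill_600m")  # hypothetical card name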
 

+ 2 - 2
src/seamless_communication/models/vocoder/builder.py

@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
 
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.typing import DataType, Device
 
 from seamless_communication.models.vocoder.codehifigan import CodeGenerator
@@ -34,7 +34,7 @@ class VocoderConfig:
     lang_spkr_idx_map: Dict[str, Any]
 
 
-vocoder_archs = ArchitectureRegistry[VocoderConfig]("vocoder_code_hifigan")
+vocoder_archs = ModelArchitectureRegistry[VocoderConfig]()
 
 vocoder_arch = vocoder_archs.decorator
 

+ 9 - 13
src/seamless_communication/models/vocoder/loader.py

@@ -4,22 +4,21 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, Mapping
+from typing import Any, Dict
 
-from fairseq2.assets import asset_store, download_manager
-from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models import setup_model_family
 
 from seamless_communication.models.vocoder.builder import (
     VocoderConfig,
     create_vocoder_model,
     vocoder_archs,
 )
-from seamless_communication.models.vocoder.vocoder import Vocoder
+from seamless_communication.models.vocoder.vocoder import VOCODER_CODE_HIFIGAN_FAMILY
 
 
 def convert_vocoder_checkpoint(
-    checkpoint: Mapping[str, Any], config: VocoderConfig
-) -> Mapping[str, Any]:
+    checkpoint: Dict[str, Any], config: VocoderConfig
+) -> Dict[str, Any]:
     if (
         "model" in checkpoint
         and "code_generator.resblocks.0.convs1.0.weight_g" in checkpoint["model"]
@@ -36,13 +35,10 @@ def convert_vocoder_checkpoint(
     return checkpoint
 
 
-load_vocoder_config = ConfigLoader[VocoderConfig](asset_store, vocoder_archs)
-
-
-load_vocoder_model = ModelLoader[Vocoder, VocoderConfig](
-    asset_store,
-    download_manager,
-    load_vocoder_config,
+load_vocoder_model, load_vocoder_config = setup_model_family(
+    VOCODER_CODE_HIFIGAN_FAMILY,
+    VocoderConfig,
     create_vocoder_model,
+    vocoder_archs,
     convert_vocoder_checkpoint,
 )
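
setup_model_family collapses the former ConfigLoader/ModelLoader pair into a single call that returns both loaders, mirroring the UnitY loader above. A hedged usage sketch; the card name and any extra keyword arguments are assumptions, not shown in this diff:

    config = load_vocoder_config("vocoder_code_hifigan")  # assumed card name
    vocoder = load_vocoder_model("vocoder_code_hifigan")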

+ 6 - 5
src/seamless_communication/models/vocoder/vocoder.py

@@ -4,21 +4,22 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, Optional, List, Union
+from typing import Any, Dict, Final, List, Optional, Union
 import torch
 from torch import Tensor
-from torch.nn import Module
+from fairseq2.models import Model
 
 from seamless_communication.models.vocoder.codehifigan import CodeGenerator
 
+VOCODER_CODE_HIFIGAN_FAMILY: Final = "vocoder_code_hifigan"
+
 
-class Vocoder(Module):
+class Vocoder(Model):
     def __init__(
         self,
         code_generator: CodeGenerator,
         lang_spkr_idx_map: Dict[str, Any],
     ):
-        super().__init__()
+        super().__init__(VOCODER_CODE_HIFIGAN_FAMILY)
         self.code_generator = code_generator
         self.lang_spkr_idx_map = lang_spkr_idx_map
 
@@ -29,7 +30,7 @@ class Vocoder(Module):
         spkr_list: Union[Optional[List[int]], int] = None,
         dur_prediction: bool = True,
     ) -> Tensor:
-        # TODO: Do we need this backward compatibility, or just update all calling sites? 
+        # TODO: Do we need this backward compatibility, or just update all calling sites?
         if len(units.shape) == 1:
             units = units.unsqueeze(0) # add batch dim
         if isinstance(lang_list, str):
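
forward() normalizes un-batched inputs itself, as the unsqueeze/isinstance checks above show. A hedged call sketch; the language code and speaker index are placeholders:

    import torch

    units = torch.tensor([12, 44, 44, 7])  # 1-D; forward() adds the batch dim
    # wav = vocoder(units, lang_list="eng", spkr_list=0)  # "eng"/0 are placeholders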

+ 1 - 2
src/seamless_communication/toxicity/etox_bad_word_checker.py

@@ -16,7 +16,6 @@ from fairseq2.assets import (
     asset_store as base_asset_store,
     download_manager as base_download_manager,
 )
-from fairseq2.data import StringLike
 from fairseq2.data.text import SentencePieceEncoder, SentencePieceModel
 
 
@@ -116,7 +115,7 @@ class ETOXBadWordChecker:
 
     @staticmethod
     def _contains_tokens(
-        text_tokens: List[StringLike], word_tokens: List[StringLike]
+        text_tokens: List[str], word_tokens: List[str]
     ) -> bool:
         for i in range(len(text_tokens) - len(word_tokens) + 1):
             for j in range(len(word_tokens)):
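
_contains_tokens is a sliding-window exact match over token lists. An equivalent standalone sketch using slice comparison; the names are illustrative:

    from typing import List

    def contains_tokens(text_tokens: List[str], word_tokens: List[str]) -> bool:
        # Slide a window of len(word_tokens) across text_tokens; any exact match wins.
        n = len(word_tokens)
        return any(
            text_tokens[i : i + n] == word_tokens
            for i in range(len(text_tokens) - n + 1)
        )

    assert contains_tokens(["a", "very", "bad", "word"], ["bad", "word"])
    assert not contains_tokens(["all", "good"], ["bad"])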

+ 7 - 8
src/seamless_communication/toxicity/mintox.py

@@ -18,7 +18,6 @@ from seamless_communication.toxicity.etox_bad_word_checker import (
 )
 from fairseq2.generation import BannedSequenceProcessor
 from fairseq2.data.text.text_tokenizer import TextTokenizer
-from fairseq2.data.typing import StringLike
 from fairseq2.typing import Device
 from fairseq2.data import SequenceData
 from fairseq2.nn.padding import get_seqs_and_padding_mask
@@ -32,8 +31,8 @@ logger = logging.getLogger(__name__)
 
 
 def _extract_bad_words_with_batch_indices(
-    source_texts: List[StringLike],
-    target_texts: List[StringLike],
+    source_texts: List[str],
+    target_texts: List[str],
     source_lang: str,
     target_lang: str,
     bad_word_checker: ETOXBadWordChecker,
@@ -54,9 +53,9 @@ def _extract_bad_words_with_batch_indices(
 
 
 def _replace_with_new_text_output_in_batch(
-    original_texts: List[StringLike],
+    original_texts: List[str],
     indices_with_toxicity: List[int],
-    new_texts: List[StringLike],
+    new_texts: List[str],
 ) -> None:
     new_idx = 0
     # indices_with_toxicity is a small list, so using a list should be fast enough.
@@ -100,8 +99,8 @@ def mintox_pipeline(
     model_input: SequenceData,
     input_modality: "Modality",
     output_modality: "Modality",
-    src_texts: List[StringLike],
-    original_texts: List[StringLike],
+    src_texts: List[str],
+    original_texts: List[str],
     original_units: Optional[Tensor] = None,
     unit_generation_ngram_filtering: bool = False,
     text_generation_opts: Optional[SequenceGeneratorOptions] = None,
@@ -109,7 +108,7 @@ def mintox_pipeline(
-    bad_word_checker: ETOXBadWordChecker = None,
+    bad_word_checker: Optional[ETOXBadWordChecker] = None,
     duration_factor: float = 1.0,
     prosody_encoder_input: Optional[SequenceData] = None,
-) -> Tuple[List[StringLike], Optional[Tensor]]:
+) -> Tuple[List[str], Optional[Tensor]]:
     """MinTox: Mitigation at INference time of added TOXicity."""
     from seamless_communication.inference.translator import Modality, Translator