Migrate to fairseq2 v0.3

Can Balioglu 1 year ago
parent commit 283f74250f
51 changed files with 389 additions and 374 deletions
  1. setup.py (+1 -1)
  2. src/seamless_communication/cards/conformer_shaw.yaml (+1 -1)
  3. src/seamless_communication/cards/nar_t2u_aligner.yaml (+1 -1)
  4. src/seamless_communication/cards/seamless_streaming_monotonic_decoder.yaml (+1 -1)
  5. src/seamless_communication/cards/unity_nllb-100.yaml (+1 -1)
  6. src/seamless_communication/cards/unity_nllb-200.yaml (+1 -1)
  7. src/seamless_communication/cards/vocoder_36langs.yaml (+1 -1)
  8. src/seamless_communication/cards/vocoder_pretssel.yaml (+1 -1)
  9. src/seamless_communication/cards/vocoder_pretssel_16khz.yaml (+1 -1)
  10. src/seamless_communication/cards/vocoder_v2.yaml (+1 -1)
  11. src/seamless_communication/cards/xlsr2_1b_v2.yaml (+1 -1)
  12. src/seamless_communication/cli/m4t/audio_to_units/audio_to_units.py (+1 -1)
  13. src/seamless_communication/cli/m4t/evaluate/evaluate.py (+3 -4)
  14. src/seamless_communication/cli/m4t/finetune/dataloader.py (+1 -1)
  15. src/seamless_communication/cli/m4t/finetune/trainer.py (+8 -8)
  16. src/seamless_communication/cli/toxicity/asr_etox.py (+1 -1)
  17. src/seamless_communication/inference/generator.py (+5 -3)
  18. src/seamless_communication/inference/translator.py (+4 -4)
  19. src/seamless_communication/models/aligner/alignment_extractor.py (+4 -5)
  20. src/seamless_communication/models/aligner/builder.py (+8 -7)
  21. src/seamless_communication/models/aligner/loader.py (+9 -14)
  22. src/seamless_communication/models/aligner/model.py (+29 -10)
  23. src/seamless_communication/models/conformer_shaw/builder.py (+34 -19)
  24. src/seamless_communication/models/conformer_shaw/loader.py (+12 -12)
  25. src/seamless_communication/models/generator/builder.py (+2 -3)
  26. src/seamless_communication/models/generator/ecapa_tdnn_builder.py (+2 -2)
  27. src/seamless_communication/models/generator/loader.py (+6 -12)
  28. src/seamless_communication/models/generator/streamable.py (+1 -1)
  29. src/seamless_communication/models/generator/vocoder.py (+7 -3)
  30. src/seamless_communication/models/monotonic_decoder/builder.py (+2 -4)
  31. src/seamless_communication/models/monotonic_decoder/loader.py (+9 -17)
  32. src/seamless_communication/models/monotonic_decoder/model.py (+10 -7)
  33. src/seamless_communication/models/monotonic_decoder/monotonic_decoder.py (+2 -2)
  34. src/seamless_communication/models/monotonic_decoder/monotonic_decoder_layer.py (+2 -2)
  35. src/seamless_communication/models/monotonic_decoder/p_choose.py (+2 -2)
  36. src/seamless_communication/models/pretssel/ecapa_tdnn_builder.py (+2 -2)
  37. src/seamless_communication/models/tokenizer.py (+12 -29)
  38. src/seamless_communication/models/unit_extractor/wav2vec2_layer_output.py (+2 -0)
  39. src/seamless_communication/models/unity/builder.py (+62 -28)
  40. src/seamless_communication/models/unity/char_tokenizer.py (+8 -33)
  41. src/seamless_communication/models/unity/fft_decoder.py (+2 -2)
  42. src/seamless_communication/models/unity/fft_decoder_layer.py (+3 -3)
  43. src/seamless_communication/models/unity/loader.py (+78 -76)
  44. src/seamless_communication/models/unity/model.py (+13 -8)
  45. src/seamless_communication/models/unity/nar_decoder_frontend.py (+2 -2)
  46. src/seamless_communication/models/unity/t2u_builder.py (+5 -6)
  47. src/seamless_communication/models/vocoder/builder.py (+2 -2)
  48. src/seamless_communication/models/vocoder/loader.py (+9 -13)
  49. src/seamless_communication/models/vocoder/vocoder.py (+6 -5)
  50. src/seamless_communication/toxicity/etox_bad_word_checker.py (+1 -2)
  51. src/seamless_communication/toxicity/mintox.py (+7 -8)

+ 1 - 1
setup.py

@@ -22,7 +22,7 @@ setup(
     license="Creative Commons",
     install_requires=[
         "datasets",
-        "fairseq2==0.2.*",
+#        "fairseq2==0.2.*",
         "fire",
         "librosa",
         "openai-whisper",

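Note that the fairseq2 requirement is commented out rather than bumped to 0.3, so compatibility is no longer enforced at install time. A minimal runtime guard, as a hedged sketch (it assumes fairseq2 exposes a standard __version__ attribute, which this diff does not show):

import fairseq2

# Hypothetical guard, not part of this commit: fail fast if the installed
# fairseq2 predates the v0.3 API that this migration targets.
if not fairseq2.__version__.startswith("0.3"):
    raise RuntimeError(f"expected fairseq2 v0.3, found {fairseq2.__version__}")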
+ 1 - 1
src/seamless_communication/cards/conformer_shaw.yaml

@@ -5,6 +5,6 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: conformer_shaw
-model_type: wav2vec2
+model_family: conformer_shaw
 model_arch: conformer_shaw_600m
 checkpoint: "https://huggingface.co/facebook/conformer-shaw/resolve/main/conformer_shaw.pt"

+ 1 - 1
src/seamless_communication/cards/nar_t2u_aligner.yaml

@@ -6,7 +6,7 @@
 
 name: nar_t2u_aligner
 char_tokenizer: "https://huggingface.co/facebook/seamless-streaming/resolve/main/spm_char_lang38_tc.model"
-model_type: unity2_aligner
+model_family: unity2_aligner
 model_arch: nar_t2u_aligner
 checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/unity2_aligner.pt"
 num_units: 10000

+ 1 - 1
src/seamless_communication/cards/seamless_streaming_monotonic_decoder.yaml

@@ -5,6 +5,6 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: seamless_streaming_monotonic_decoder
-model_type: monotonic_decoder
+model_family: monotonic_decoder
 model_arch: dense_1b
 checkpoint: "https://huggingface.co/facebook/seamless-streaming/resolve/main/seamless_streaming_monotonic_decoder.pt"

+ 1 - 1
src/seamless_communication/cards/unity_nllb-100.yaml

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: unity_nllb-100
-model_type: unity
+model_family: unity
 tokenizer: "https://huggingface.co/facebook/seamless-m4t-large/resolve/main/tokenizer.model"
 default_lang: eng
 langs:

+ 1 - 1
src/seamless_communication/cards/unity_nllb-200.yaml

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: unity_nllb-200
-model_type: unity
+model_family: unity
 tokenizer: "https://huggingface.co/facebook/seamless-m4t-medium/resolve/main/tokenizer.model"
 default_lang: eng
 langs:

+ 1 - 1
src/seamless_communication/cards/vocoder_36langs.yaml

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: vocoder_36langs
-model_type: vocoder_code_hifigan
+model_family: vocoder_code_hifigan
 model_arch: base
 checkpoint: "https://huggingface.co/facebook/seamless-m4t-vocoder/resolve/main/vocoder_36langs.pt"
 model_config: {

+ 1 - 1
src/seamless_communication/cards/vocoder_pretssel.yaml

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: vocoder_pretssel
-model_type: vocoder_pretssel
+model_family: vocoder_pretssel
 model_arch: 24khz
 checkpoint: "https://github.com/facebookresearch/seamless_communication;gated=true"
 sample_rate: 24000

+ 1 - 1
src/seamless_communication/cards/vocoder_pretssel_16khz.yaml

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: vocoder_pretssel_16khz
-model_type: vocoder_pretssel
+model_family: vocoder_pretssel
 model_arch: 16khz
 checkpoint: "https://github.com/facebookresearch/seamless_communication;gated=true"
 sample_rate: 16000

+ 1 - 1
src/seamless_communication/cards/vocoder_v2.yaml

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: vocoder_v2
-model_type: vocoder_code_hifigan
+model_family: vocoder_code_hifigan
 model_arch: base
 checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/vocoder_v2.pt"
 model_config: {

+ 1 - 1
src/seamless_communication/cards/xlsr2_1b_v2.yaml

@@ -5,6 +5,6 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 name: xlsr2_1b_v2
-model_type: wav2vec2
+model_family: wav2vec2
 model_arch: xlsr2_1b_v2
 checkpoint: "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/xlsr2_1b_v2.pt"

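All eleven asset cards above make the same one-line change: the model_type key becomes model_family, and its value now names the model family registered by the loaders later in this commit. A hedged sketch of reading the renamed field (it assumes fairseq2's asset-card accessors, retrieve_card and field(...).as_, keep their v0.2 shape):

from fairseq2.assets import asset_store

# After this commit the card carries "model_family" instead of "model_type".
card = asset_store.retrieve_card("conformer_shaw")
family = card.field("model_family").as_(str)  # -> "conformer_shaw"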
+ 1 - 1
src/seamless_communication/cli/m4t/audio_to_units/audio_to_units.py

@@ -14,7 +14,7 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description="Convert raw audio to units (and optionally audio) using UnitExtractor."
     )

+ 3 - 4
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -19,7 +19,6 @@ import torchaudio
 from fairseq2.data import Collater, DataPipeline, FileMapper
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
-from fairseq2.data.typing import StringLike
 from fairseq2.typing import DataType, Device
 from torch import Tensor
 from tqdm import tqdm
@@ -181,10 +180,10 @@ def build_data_pipeline(
 
 def adjust_output_for_corrupted_inputs(
     valid_sequences: Tensor,
-    text_output: List[StringLike],
+    text_output: List[str],
     speech_output: Optional[BatchedSpeechOutput],
-) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
-    adjusted_text_output: List[StringLike] = []
+) -> Tuple[List[str], Optional[BatchedSpeechOutput]]:
+    adjusted_text_output: List[str] = []
     adjusted_speech_output: Optional[BatchedSpeechOutput] = None
 
     if speech_output is not None:

+ 1 - 1
src/seamless_communication/cli/m4t/finetune/dataloader.py

@@ -83,7 +83,7 @@ class BatchingConfig:
     """Select between fp16/fp32 for float tensors """
 
 
-def worker_init_fn(worker_id):
+def worker_init_fn(worker_id) -> None:
     np.random.seed(np.random.get_state()[1][0] + worker_id)
 
 

+ 8 - 8
src/seamless_communication/cli/m4t/finetune/trainer.py

@@ -116,12 +116,12 @@ class UnitYFinetuneWrapper(nn.Module):
                 unit_encoder_out,
                 unit_encoder_padding_mask,
             ) = self.model.t2u_model.encode(
-                text_decoder_output=text_decoder_out,
-                text_decoder_padding_mask=text_decoder_padding_mask,
+                text_decoder_out,
+                text_decoder_padding_mask,
             )
             seqs = batch.text_to_units.prev_output_tokens.to(self.device)
             seq_lens = batch.text_to_units.target_lengths.to(self.device)
-            unit_decoder_out, _ = self.model.t2u_model.decode(
+            unit_decoder_out = self.model.t2u_model.decode(
                 seqs=seqs,
                 padding_mask=PaddingMask(seq_lens, seqs.size(1)),
                 encoder_output=unit_encoder_out,
@@ -156,7 +156,7 @@ class CalcLoss:
             text_logits.device
         )
         s2t_loss = SequenceModelOutput(
-            logits=text_logits, vocab_info=self.s2t_vocab_info
+            text_logits, self.s2t_vocab_info.pad_idx
         ).compute_loss(
             targets=batch.speech_to_text.target_tokens.to(text_logits.device),
             ignore_prefix_size=1,
@@ -167,7 +167,7 @@ class CalcLoss:
         assert batch.text_to_units.target_lengths is not None
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
         s2u_loss = SequenceModelOutput(
-            logits=unit_logits, vocab_info=self.t2u_vocab_info
+            unit_logits, self.t2u_vocab_info.pad_idx
        ).compute_loss(
             targets=batch.text_to_units.target_tokens.to(unit_logits.device),
             ignore_prefix_size=1,
@@ -314,7 +314,7 @@ class UnitYFinetune:
         eval_loss = loss_hist.reduce()
         self._update_eval_stats(eval_loss)
 
-    def _train_step_log(self):
+    def _train_step_log(self) -> None:
         """Log train stats"""
         if (self.update_idx + 1) % self.params.log_steps == 0:
             avg_loss = self.train_loss_hist.reduce()
@@ -340,7 +340,7 @@ class UnitYFinetune:
         self.train_loss_hist.update(1, loss.item())
         self._train_step_log()
 
-    def _save_model(self):
+    def _save_model(self) -> None:
         logger.info("Saving model")
         if dist_utils.is_main_process():
             state_dict = {
@@ -351,7 +351,7 @@ class UnitYFinetune:
         if dist_utils.is_dist_initialized():
             dist.barrier()
 
-    def run(self):
+    def run(self) -> None:
         logger.info("Start finetuning")
         self._reset_stats()
         self._eval_model()

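Both loss sites above switch from passing a VocabularyInfo to passing the padding index itself, positionally. The recurring call shape, as a hedged sketch (it assumes v0.3's SequenceModelOutput constructor is (logits, pad_idx), as the diff implies; logits, vocab_info, and targets stand in for the trainer's locals):

from fairseq2.models.sequence import SequenceModelOutput

# v0.2: SequenceModelOutput(logits=logits, vocab_info=vocab_info)
output = SequenceModelOutput(logits, vocab_info.pad_idx)
loss = output.compute_loss(targets=targets, ignore_prefix_size=1)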
+ 1 - 1
src/seamless_communication/cli/toxicity/asr_etox.py

@@ -207,7 +207,7 @@ def build_data_pipeline(
 
     pipeline_builder = read_text(data_file, rtrim=True).skip(1).map(split_tsv)
 
-    map_file = FileMapper(root_dir=audio_root_dir, cached_fd_count=10)
+    map_file = FileMapper(root_dir=Path(audio_root_dir), cached_fd_count=10)
 
     pipeline_builder.map(
         map_file,

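The FileMapper change above reflects a stricter v0.3 signature: root_dir is a pathlib.Path rather than a string. A hedged sketch (the directory path is illustrative):

from pathlib import Path

from fairseq2.data import FileMapper

# v0.2 accepted a str here; v0.3 expects a Path.
map_file = FileMapper(root_dir=Path("/data/audio"), cached_fd_count=10)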
+ 5 - 3
src/seamless_communication/inference/generator.py

@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from typing import List, Optional, Tuple
 
 import torch
-from fairseq2.data import SequenceData, StringLike
+from fairseq2.data import SequenceData
 from fairseq2.data.text import TextTokenizer
 from fairseq2.generation import (
     BeamSearchSeq2SeqGenerator,
@@ -137,6 +137,7 @@ class UnitYGenerator:
             decoder_frontend=model.text_decoder_frontend,
             decoder=model.text_decoder,
             final_proj=model.final_proj,
+            max_target_seq_len=model.max_target_seq_len,
             target_vocab_info=model.target_vocab_info,
         )
 
@@ -169,6 +170,7 @@ class UnitYGenerator:
                 decoder_frontend=model.text_decoder_frontend,
                 decoder=model.text_decoder,
                 final_proj=model.final_proj,
+                max_target_seq_len=model.max_target_seq_len,
                 target_vocab_info=model.target_vocab_info,
             )
             generator = BeamSearchSeq2SeqGenerator(
@@ -234,7 +236,7 @@ class UnitYGenerator:
         ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
-    ) -> Tuple[List[StringLike], Optional[Tensor]]:
+    ) -> Tuple[List[str], Optional[Tensor]]:
         """
         :param source_seqs:
             The source sequences to use for generation. *Shape:* :math:`(N,S,*)`,
@@ -346,7 +348,7 @@ class UnitYGenerator:
             unit_seqs = t2u_model_output.logits.argmax(dim=2)
             # Apply the padding mask to the generated units.
             unit_seqs = apply_padding_mask(
-                unit_seqs, decoder_padding_mask, t2u_model_output.vocab_info.pad_idx
+                unit_seqs, decoder_padding_mask, t2u_model_output.pad_idx
             )
 
         # Convert to speech units.

+ 4 - 4
src/seamless_communication/inference/translator.py

@@ -13,7 +13,7 @@ import torch
 import torch.nn as nn
 from fairseq2.assets import asset_store
 from fairseq2.assets.card import AssetCard
-from fairseq2.data import Collater, SequenceData, StringLike
+from fairseq2.data import Collater, SequenceData
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.text import TextTokenizer
 from fairseq2.memory import MemoryBlock
@@ -169,7 +169,7 @@ class Translator(nn.Module):
         unit_generation_ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
-    ) -> Tuple[List[StringLike], Optional[Tensor]]:
+    ) -> Tuple[List[str], Optional[Tensor]]:
         # We disregard unit generations opts for the NAR T2U decoder.
         if output_modality != Modality.SPEECH or isinstance(
             model.t2u_model, UnitYNART2UModel
@@ -226,8 +226,8 @@ class Translator(nn.Module):
         unit_generation_ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
-        src_text: Optional[StringLike] = None,
-    ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
+        src_text: Optional[str] = None,
+    ) -> Tuple[List[str], Optional[BatchedSpeechOutput]]:
         """
         The main method used to perform inference on all tasks.
 

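With StringLike removed from fairseq2, Translator.predict now returns plain Python strings. A hedged usage sketch (card names and the abbreviated predict signature follow this repo's README conventions; device and dtype are illustrative):

import torch

from seamless_communication.inference import Translator

translator = Translator(
    "seamlessM4T_v2_large", "vocoder_v2", torch.device("cpu"), dtype=torch.float32
)
text_output, speech_output = translator.predict(
    "Hello, world!", "t2tt", tgt_lang="fra", src_lang="eng"
)
print(text_output[0])  # a plain str in v0.3; previously a StringLike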
+ 4 - 5
src/seamless_communication/models/aligner/alignment_extractor.py

@@ -12,7 +12,6 @@ import torch
 import torch.nn as nn
 import torchaudio
 from fairseq2.typing import DataType, Device
-from fairseq2.data.typing import StringLike
 from torch import Tensor
 
 from seamless_communication.models.aligner.loader import load_unity2_alignment_model
@@ -82,7 +81,7 @@ class AlignmentExtractor(nn.Module):
             audio = audio.mean(0)
         assert (
             audio.ndim == 1
-        ), f"After channel averaging audio shape expected to be [Time] i.e. mono audio"
+        ), "After channel averaging audio shape expected to be [Time] i.e. mono audio"
         audio = audio.to(self.device, self.dtype)
 
         return audio
@@ -101,7 +100,7 @@ class AlignmentExtractor(nn.Module):
         text: str,
         plot: bool = False,
         add_trailing_silence: bool = False,
-    ) -> Tuple[Tensor, Tensor, List[StringLike]]:
+    ) -> Tuple[Tensor, Tensor, List[str]]:
         if isinstance(audio, Tensor) and not torch.is_floating_point(audio):
             # we got units as audio arg
             units = audio
@@ -137,11 +136,11 @@ class AlignmentExtractor(nn.Module):
 
         return alignment_durations, tokenized_text_ids, tokenized_text_tokens
 
-    def detokenize_text(self, tokenized_text_ids: Tensor) -> StringLike:
+    def detokenize_text(self, tokenized_text_ids: Tensor) -> str:
         return self.alignment_model.alignment_frontend.decode_text(tokenized_text_ids)
 
     def plot_alignment(
-        self, audio: Tensor, text_tokens: List[StringLike], durations: Tensor
+        self, audio: Tensor, text_tokens: List[str], durations: Tensor
     ) -> None:
         if not matplotlib_available:
             raise RuntimeError(

+ 8 - 7
src/seamless_communication/models/aligner/builder.py

@@ -10,9 +10,9 @@ from typing import Optional, Union
 import torch
 from fairseq2.assets.card import AssetCard
 from fairseq2.data.vocabulary_info import VocabularyInfo
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.nn.embedding import StandardEmbedding, init_scaled_embedding
-from fairseq2.typing import DataType, Device
+from fairseq2.typing import CPU, DataType, Device
 
 from seamless_communication.models.aligner.model import (
     UnitY2AlignmentEncoder,
@@ -56,7 +56,7 @@ class UnitY2AlignmentConfig:
     alignment_frontend_config: UnitY2AlignmentFrontendConfig
 
 
-aligner_archs = ArchitectureRegistry[UnitY2AlignmentConfig]("unity2_aligner")
+aligner_archs = ModelArchitectureRegistry[UnitY2AlignmentConfig]()
 
 aligner_arch = aligner_archs.decorator
 
@@ -90,14 +90,14 @@ def _aligner_nar_t2u() -> UnitY2AlignmentConfig:
 class UnitY2AlignmentBuilder:
     config: UnitY2AlignmentConfig
     device: Optional[Device]
-    dtype: DataType
+    dtype: Optional[DataType]
 
     def __init__(
         self,
         config: UnitY2AlignmentConfig,
         *,
         device: Optional[Device] = None,
-        dtype: DataType = torch.float32,
+        dtype: Optional[DataType] = torch.float32,
     ) -> None:
         """
         :param config:
@@ -155,7 +155,8 @@ class UnitY2AlignmentBuilder:
             dropout=cfg.dropout,
             temperature=cfg.temperature,
             reduction_factor=cfg.reduction_factor,
-            dtype=self.dtype,
+            device=self.device or CPU,
+            dtype=self.dtype or torch.float32,
         )
         alignment_encoder.training = training
 
@@ -165,7 +166,7 @@ class UnitY2AlignmentBuilder:
 def create_unity2_alignment_model(
     config: UnitY2AlignmentConfig,
     device: Optional[Device] = None,
-    dtype: DataType = torch.float32,
+    dtype: Optional[DataType] = torch.float32,
 ) -> UnitY2AlignmentModel:
     """Create a UnitY model.
 

+ 9 - 14
src/seamless_communication/models/aligner/loader.py

@@ -4,24 +4,23 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, List, Mapping
+from typing import Any, List, Dict
 
 import torch
-from fairseq2.assets import asset_store, download_manager
-from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models import setup_model_family
 
 from seamless_communication.models.aligner.builder import (
     UnitY2AlignmentConfig,
     aligner_archs,
     create_unity2_alignment_model,
 )
-from seamless_communication.models.aligner.model import UnitY2AlignmentModel
+from seamless_communication.models.aligner.model import UNITY2_ALIGNER_FAMILY
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
 
 
 def convert_unity2_aligner_checkpoint(
-    checkpoint: Mapping[str, Any], config: UnitY2AlignmentConfig
-) -> Mapping[str, Any]:
+    checkpoint: Dict[str, Any], config: UnitY2AlignmentConfig
+) -> Dict[str, Any]:
     if (
         "model" in checkpoint
         and "alignment_encoder.t_conv.1.weight" in checkpoint["model"]
@@ -74,15 +73,11 @@ def _get_char_index_mapping(config: UnitY2AlignmentConfig) -> List[int]:
     return model_to_dict_mapping
 
 
-load_unity2_alignment_config = ConfigLoader[UnitY2AlignmentConfig](
-    asset_store, aligner_archs
-)
-
-load_unity2_alignment_model = ModelLoader[UnitY2AlignmentModel, UnitY2AlignmentConfig](
-    asset_store,
-    download_manager,
-    load_unity2_alignment_config,
+load_unity2_alignment_model, load_unity2_alignment_config = setup_model_family(
+    UNITY2_ALIGNER_FAMILY,
+    UnitY2AlignmentConfig,
     create_unity2_alignment_model,
+    aligner_archs,
     convert_unity2_aligner_checkpoint,
     restrict_checkpoints=False,
 )

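This loader rewrite is the commit's core pattern, repeated for every model below: the v0.2 ConfigLoader/ModelLoader pair collapses into a single setup_model_family call that returns the model loader and the config loader together. A hedged, self-contained sketch with hypothetical names (MY_FAMILY, MyConfig, and friends are placeholders; only setup_model_family and its argument order are taken from the diff):

from dataclasses import dataclass
from typing import Any, Dict, Optional

from fairseq2.models import setup_model_family
from fairseq2.models.architecture_registry import ModelArchitectureRegistry
from fairseq2.typing import DataType, Device
from torch.nn import Module

MY_FAMILY = "my_family"  # hypothetical family name constant


@dataclass
class MyConfig:  # hypothetical config dataclass
    model_dim: int = 512


my_archs = ModelArchitectureRegistry[MyConfig]()


def create_my_model(
    config: MyConfig,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> Module:
    ...  # build the model from `config`


def convert_my_checkpoint(
    checkpoint: Dict[str, Any], config: MyConfig
) -> Dict[str, Any]:
    return checkpoint  # optional fairseq -> fairseq2 key remapping


load_my_model, load_my_config = setup_model_family(
    MY_FAMILY,
    MyConfig,
    create_my_model,
    my_archs,
    convert_my_checkpoint,
    restrict_checkpoints=False,
)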
+ 29 - 10
src/seamless_communication/models/aligner/model.py

@@ -4,25 +4,27 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, List, Tuple, Union
+from typing import Any, Final, List, Tuple, Union
 
 import numpy as np
 import numpy.typing as npt
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from fairseq2.data import CString
+from fairseq2.models import Model
 from fairseq2.nn.embedding import StandardEmbedding
 from fairseq2.nn.padding import to_padding_mask
-from fairseq2.typing import DataType
+from fairseq2.typing import DataType, Device
 from torch import Tensor
 from torch.nn import Module
 
 from seamless_communication.models.unity.char_tokenizer import CharTokenizer
 from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer
 
+UNITY2_ALIGNER_FAMILY: Final = "unity2_aligner"
 
-class UnitY2AlignmentFrontend(Module):
+
+class UnitY2AlignmentFrontend(nn.Module):
     def __init__(
         self,
         embed_text: StandardEmbedding,
@@ -53,7 +55,7 @@ class UnitY2AlignmentFrontend(Module):
 
     def tokenize_text_to_tokens(
         self, text: str, add_trailing_silence: bool = False
-    ) -> List[Union[CString, str]]:
+    ) -> List[str]:
         tokenized = self.encode_text.encode_as_tokens(text)
         if add_trailing_silence:
             tokenized = tokenized + [tokenized[0]]
@@ -90,6 +92,7 @@ class UnitY2AlignmentEncoder(Module):
         dropout: float,
         temperature: float,
         reduction_factor: int,
+        device: Device,
         dtype: DataType,
     ):
         super().__init__()
@@ -101,7 +104,12 @@ class UnitY2AlignmentEncoder(Module):
             if i < text_layers - 1:
                 layers.append(
                     nn.Conv1d(
-                        embed_dim, embed_dim, kernel_size=3, padding=1, dtype=dtype
+                        embed_dim,
+                        embed_dim,
+                        kernel_size=3,
+                        padding=1,
+                        device=device,
+                        dtype=dtype,
                     )
                 )
                 layers.append(nn.ReLU())
@@ -109,7 +117,12 @@ class UnitY2AlignmentEncoder(Module):
             else:
                 layers.append(
                     nn.Conv1d(
-                        embed_dim, embed_dim, kernel_size=1, padding=0, dtype=dtype
+                        embed_dim,
+                        embed_dim,
+                        kernel_size=1,
+                        padding=0,
+                        device=device,
+                        dtype=dtype,
                     )
                 )
                 layers.append(nn.Dropout(p=dropout))
@@ -122,7 +135,12 @@ class UnitY2AlignmentEncoder(Module):
             if i < feat_layers - 1:
                 layers.append(
                     nn.Conv1d(
-                        input_dim, embed_dim, kernel_size=3, padding=1, dtype=dtype
+                        input_dim,
+                        embed_dim,
+                        kernel_size=3,
+                        padding=1,
+                        device=device,
+                        dtype=dtype,
                     )
                 )
                 layers.append(nn.ReLU())
@@ -135,6 +153,7 @@ class UnitY2AlignmentEncoder(Module):
                         kernel_size=1,
                         padding=0,
                         stride=reduction_factor,
+                        device=device,
                         dtype=dtype,
                     )
                 )
@@ -277,7 +296,7 @@ def viterbi_decode(
     return durations
 
 
-class UnitY2AlignmentModel(Module):
+class UnitY2AlignmentModel(Model):
     alignment_encoder: UnitY2AlignmentEncoder
     alignment_frontend: UnitY2AlignmentFrontend
 
@@ -286,7 +305,7 @@ class UnitY2AlignmentModel(Module):
         alignment_frontend: UnitY2AlignmentFrontend,
         alignment_encoder: UnitY2AlignmentEncoder,
     ):
-        super().__init__()
+        super().__init__(UNITY2_ALIGNER_FAMILY)
         self.alignment_frontend = alignment_frontend
         self.alignment_encoder = alignment_encoder
 

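The model-class changes above follow one rule: subclass fairseq2's new Model base instead of torch.nn.Module, and hand the family name to super().__init__(). A hedged sketch (it assumes Model is itself an nn.Module subclass whose constructor takes the family string, as the usage in this diff implies):

from fairseq2.models import Model
from torch import Tensor

MY_FAMILY = "my_family"  # hypothetical; the real values are "unity2_aligner" etc.


class MyModel(Model):
    def __init__(self) -> None:
        super().__init__(MY_FAMILY)  # v0.2: plain nn.Module, super().__init__()

    def forward(self, seqs: Tensor) -> Tensor:
        return seqs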
+ 34 - 19
src/seamless_communication/models/conformer_shaw/builder.py

@@ -4,13 +4,13 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from dataclasses import asdict, dataclass
-from typing import Optional
+from dataclasses import asdict, dataclass, field
+from typing import Final, Optional
 
 from fairseq2.models.conformer import ConformerConvolution
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.models.w2vbert import w2vbert_archs
-from fairseq2.models.wav2vec2.builder import (
+from fairseq2.models.wav2vec2 import (
     Wav2Vec2Builder,
     Wav2Vec2Config,
     Wav2Vec2EncoderBuilder,
@@ -21,15 +21,17 @@ from fairseq2.models.wav2vec2.model import Wav2Vec2Model
 from fairseq2.nn.transformer import SDPA, ShawRelativePositionSDPA, create_default_sdpa
 from fairseq2.typing import DataType, Device
 
+CONFORMER_SHAW_FAMILY: Final = "conformer_shaw"
+
 
 @dataclass
 class ShawRelativePositionSDPAConfig:
     """Holds the configuration of the :class:ShawRelativePositionSDPA module."""
 
-    max_left_rel_pos: int
+    max_left_rel_pos: int = 64
     """The left clipping value for relative positions."""
 
-    max_right_rel_pos: Optional[int]
+    max_right_rel_pos: Optional[int] = 8
     """The right clipping value for relative positions."""
 
     use_rel_pos_values: bool = False
@@ -40,18 +42,23 @@ class ConformerShawEncoderConfig(Wav2Vec2EncoderConfig):
     """Holds the configuration of a conformer shaw encoder."""
 
-    shaw_rel_pos_sdpa_config: Optional[ShawRelativePositionSDPAConfig]
+    shaw_rel_pos_sdpa_config: Optional[ShawRelativePositionSDPAConfig] = None
     """The parameters for ShawRelativePositionSDPA."""
 
 
-conformer_shaw_archs = ArchitectureRegistry[ConformerShawEncoderConfig](
-    "conformer_shaw"
-)
+@dataclass
+class ConformerShawConfig(Wav2Vec2Config):
+    """Holds the configuration of a conformer shaw model."""
+
+    encoder_config: ConformerShawEncoderConfig = field(
+        default_factory=ConformerShawEncoderConfig
+    )
 
-conformer_shaw_arch = conformer_shaw_archs.decorator
 
+conformer_shaw_archs = ModelArchitectureRegistry[ConformerShawConfig]()
+
+conformer_shaw_arch = conformer_shaw_archs.decorator
 
-@conformer_shaw_arch("600m")
 def _conformer_shaw_600m_encoder() -> ConformerShawEncoderConfig:
     w2vbert_config = w2vbert_archs.get_config("600m")
     w2v2_encoder_config = w2vbert_config.w2v2_config.encoder_config
@@ -68,18 +75,20 @@ def _conformer_shaw_600m_encoder() -> ConformerShawEncoderConfig:
     return conformer_shaw_encoder_config
 
 
-@wav2vec2_arch("conformer_shaw_600m")
-def _conformer_shaw_600m() -> Wav2Vec2Config:
+@conformer_shaw_arch("conformer_shaw_600m")
+def _conformer_shaw_600m() -> ConformerShawConfig:
     encoder_config = _conformer_shaw_600m_encoder()
 
-    return Wav2Vec2Config(
+    return ConformerShawConfig(
         encoder_config,
         final_dim=768,
         final_proj_bias=True,
         temporal_mask_span_len=10,
         max_temporal_mask_prob=0.65,
+        min_num_temporal_mask_spans=2,
         spatial_mask_span_len=10,
         max_spatial_mask_prob=0.0,
+        min_num_spatial_mask_spans=2,
         quantized_dim=768,
         num_codebooks=2,
         num_codebook_entries=320,
@@ -101,6 +110,8 @@ class ConformerShawEncoderBuilder(Wav2Vec2EncoderBuilder):
     """
 
     config: ConformerShawEncoderConfig
+    device: Optional[Device]
+    dtype: Optional[DataType]
 
     def __init__(
         self,
@@ -119,11 +130,15 @@ class ConformerShawEncoderBuilder(Wav2Vec2EncoderBuilder):
         """
         super().__init__(config, device=device, dtype=dtype)
 
+        self.config = config
+
         assert self.config.use_conformer, "This architecture only supports a Conformer."
         assert (
             self.config.pos_encoder_type == "shaw_relative"
         ), "This architecture only supports ShawRelativePositionSDPA."
 
+        self.device, self.dtype = device, dtype
+
     def build_sdpa(self) -> SDPA:
         if self.config.shaw_rel_pos_sdpa_config is None:
             raise ValueError(
@@ -157,7 +172,7 @@ class ConformerShawEncoderBuilder(Wav2Vec2EncoderBuilder):
 
 
 def create_conformer_shaw_model(
-    config: Wav2Vec2Config,
+    config: ConformerShawConfig,
     *,
     device: Optional[Device] = None,
     dtype: Optional[DataType] = None,
@@ -171,12 +186,12 @@ def create_conformer_shaw_model(
     :param dtype:
         The data type of module parameters and buffers.
     """
-    assert isinstance(config.encoder_config, ConformerShawEncoderConfig)
-
     encoder_builder = ConformerShawEncoderBuilder(
         config.encoder_config, device=device, dtype=dtype
     )
 
-    builder = Wav2Vec2Builder(config, encoder_builder, device=device, dtype=dtype)
+    builder = Wav2Vec2Builder(
+        CONFORMER_SHAW_FAMILY, config, encoder_builder, device=device, dtype=dtype
+    )
 
     return builder.build_model()

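The registry change above recurs throughout the commit: ArchitectureRegistry took the family name as a constructor argument, while v0.3's ModelArchitectureRegistry is anonymous, the family name having moved to the loader and the model class. A hedged sketch with a hypothetical config:

from dataclasses import dataclass

from fairseq2.models.architecture_registry import ModelArchitectureRegistry


@dataclass
class ToyConfig:  # hypothetical config for illustration
    model_dim: int = 512


# v0.2: ArchitectureRegistry[ToyConfig]("toy")
toy_archs = ModelArchitectureRegistry[ToyConfig]()

toy_arch = toy_archs.decorator


@toy_arch("base")
def _toy_base() -> ToyConfig:
    return ToyConfig(model_dim=512)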
+ 12 - 12
src/seamless_communication/models/conformer_shaw/loader.py

@@ -4,25 +4,25 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, Mapping
+from typing import Any, Dict
 
 import torch
 
-from fairseq2.assets import asset_store, download_manager
-from fairseq2.models.utils import ModelLoader
+from fairseq2.models import setup_model_family
 from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
-from fairseq2.models.wav2vec2.builder import Wav2Vec2Config
-from fairseq2.models.wav2vec2.loader import load_wav2vec2_config
-from fairseq2.models.wav2vec2.model import Wav2Vec2Model
+from fairseq2.models.wav2vec2 import Wav2Vec2Model
 
 from seamless_communication.models.conformer_shaw.builder import (
+    CONFORMER_SHAW_FAMILY,
+    ConformerShawConfig,
+    conformer_shaw_archs,
     create_conformer_shaw_model,
 )
 
 
 def convert_conformer_shaw_checkpoint(
-    checkpoint: Mapping[str, Any], config: Wav2Vec2Config
-) -> Mapping[str, Any]:
+    checkpoint: Dict[str, Any], config: ConformerShawConfig
+) -> Dict[str, Any]:
     """Convert a fairseq conformer shaw checkpoint to fairseq2."""
     state_dict = checkpoint["model"]
 
@@ -73,10 +73,10 @@ def convert_conformer_shaw_checkpoint(
     return convert_fairseq_checkpoint(checkpoint, key_map)
 
 
-load_conformer_shaw_model = ModelLoader[Wav2Vec2Model, Wav2Vec2Config](
-    asset_store,
-    download_manager,
-    load_wav2vec2_config,
+load_conformer_shaw_model, load_conformer_shaw_config = setup_model_family(
+    CONFORMER_SHAW_FAMILY,
+    ConformerShawConfig,
     create_conformer_shaw_model,
+    conformer_shaw_archs,
     convert_conformer_shaw_checkpoint,
 )

+ 2 - 3
src/seamless_communication/models/generator/builder.py

@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from typing import Any, Dict, List, Literal, Optional, Tuple
 
 from fairseq2.data import VocabularyInfo
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.nn.embedding import StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
 from fairseq2.nn.projection import Linear
@@ -110,8 +110,7 @@
     gcmvn_stats: Dict[str, List]  # type: ignore[type-arg]
 
 
-vocoder_archs = ArchitectureRegistry[VocoderConfig]("vocoder_pretssel")
-
+vocoder_archs = ModelArchitectureRegistry[VocoderConfig]()
 
 vocoder_arch = vocoder_archs.decorator
 

+ 2 - 2
src/seamless_communication/models/generator/ecapa_tdnn_builder.py

@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from typing import List, Optional
 
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.typing import DataType, Device
 
 from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN
@@ -27,7 +27,7 @@ class EcapaTDNNConfig:
     input_dim: int
 
 
-ecapa_tdnn_archs = ArchitectureRegistry[EcapaTDNNConfig]("ecapa_tdnn")
+ecapa_tdnn_archs = ModelArchitectureRegistry[EcapaTDNNConfig]()
 
 ecapa_tdnn_arch = ecapa_tdnn_archs.decorator
 

+ 6 - 12
src/seamless_communication/models/generator/loader.py

@@ -5,25 +5,19 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 
-from typing import Any, Mapping
-
-from fairseq2.assets import asset_store, download_manager
-from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models import setup_model_family
 
+from seamless_communication.models.generator.vocoder import PRETSSEL_VOCODER_FAMILY
 from seamless_communication.models.generator.builder import (
     VocoderConfig,
     create_vocoder_model,
     vocoder_archs,
 )
-from seamless_communication.models.generator.vocoder import PretsselVocoder
-
-load_pretssel_vocoder_config = ConfigLoader[VocoderConfig](asset_store, vocoder_archs)
-
 
-load_pretssel_vocoder_model = ModelLoader[PretsselVocoder, VocoderConfig](
-    asset_store,
-    download_manager,
-    load_pretssel_vocoder_config,
+load_pretssel_vocoder_model, load_pretssel_vocoder_config = setup_model_family(
+    PRETSSEL_VOCODER_FAMILY,
+    VocoderConfig,
     create_vocoder_model,
+    vocoder_archs,
     restrict_checkpoints=False,
 )

+ 1 - 1
src/seamless_communication/models/generator/streamable.py

@@ -6,7 +6,7 @@
 
 import math
 import warnings
-from typing import Any, Dict, List, Literal, Optional, Tuple, TypeVar
+from typing import Any, Dict, List, Literal, Optional, Tuple
 
 import torch
 from fairseq2.typing import DataType, Device

+ 7 - 3
src/seamless_communication/models/generator/vocoder.py

@@ -4,10 +4,11 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, Final, List, Literal, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
+from fairseq2.models import Model
 from fairseq2.nn.embedding import Embedding, StandardEmbedding
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.position_encoder import PositionEncoder
@@ -44,6 +45,9 @@ from .streamable import (
     StreamableResnetBlock,
 )
 
+
+PRETSSEL_VOCODER_FAMILY: Final = "vocoder_pretssel"
+
 ELU_PARAMS: Dict[str, Any] = {"alpha": 1.0}
 
 
@@ -162,7 +166,7 @@ class PretsselDecoderFrontend(Module):
         return seqs, padding_mask
 
 
-class PretsselVocoder(Module):
+class PretsselVocoder(Model):
     """The expressivity-preserving vocoder"""
 
     encoder_frontend: PretsselEncoderFrontend
@@ -212,7 +216,7 @@ class PretsselVocoder(Module):
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
     ):
-        super().__init__()
+        super().__init__(PRETSSEL_VOCODER_FAMILY)
         self.encoder_frontend = encoder_frontend
         self.encoder = encoder
         self.decoder_frontend = decoder_frontend

+ 2 - 4
src/seamless_communication/models/monotonic_decoder/builder.py

@@ -12,7 +12,7 @@ from fairseq2.models.transformer import (
     TransformerEmbeddingFrontend,
     TransformerFrontend,
 )
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
 from fairseq2.nn.projection import TiedProjection
@@ -77,9 +77,7 @@ class MonotonicDecoderConfig:
     in the PChooseLayer."""
 
 
-monotonic_decoder_archs = ArchitectureRegistry[MonotonicDecoderConfig](
-    "monotonic_decoder"
-)
+monotonic_decoder_archs = ModelArchitectureRegistry[MonotonicDecoderConfig]()
 
 monotonic_decoder_arch = monotonic_decoder_archs.decorator
 

+ 9 - 17
src/seamless_communication/models/monotonic_decoder/loader.py

@@ -4,11 +4,10 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import Any, Mapping
+from typing import Any, Dict
 
 import torch
-from fairseq2.assets import asset_store, download_manager
-from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models import setup_model_family
 from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
 
 from seamless_communication.models.monotonic_decoder.builder import (
@@ -16,12 +15,12 @@ from seamless_communication.models.monotonic_decoder.builder import (
     create_monotonic_decoder_model,
     monotonic_decoder_archs,
 )
-from seamless_communication.models.monotonic_decoder.model import MonotonicDecoderModel
+from seamless_communication.models.monotonic_decoder.model import MONOTONIC_DECODER_FAMILY
 
 
 def convert_monotonic_checkpoint(
-    checkpoint: Mapping[str, Any], config: MonotonicDecoderConfig
-) -> Mapping[str, Any]:
+    checkpoint: Dict[str, Any], config: MonotonicDecoderConfig
+) -> Dict[str, Any]:
     state_dict = checkpoint["model"]
 
     # Check if we have a fairseq2 checkpoint.
@@ -75,18 +74,11 @@ def convert_monotonic_checkpoint(
     return checkpoint
 
 
-load_monotonic_decoder_config = ConfigLoader[MonotonicDecoderConfig](
-    asset_store, monotonic_decoder_archs
-)
-
-
-load_monotonic_decoder_model = ModelLoader[
-    MonotonicDecoderModel, MonotonicDecoderConfig
-](
-    asset_store,
-    download_manager,
-    load_monotonic_decoder_config,
+load_monotonic_decoder_model, load_monotonic_decoder_config = setup_model_family(
+    MONOTONIC_DECODER_FAMILY,
+    MonotonicDecoderConfig,
     create_monotonic_decoder_model,
+    monotonic_decoder_archs,
     convert_monotonic_checkpoint,
     restrict_checkpoints=False,
 )

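Migration note: the v0.3 loader pattern replaces the hand-wired ConfigLoader/ModelLoader pair with a single setup_model_family call that returns the model loader and config loader together; the asset store and download manager are no longer passed explicitly. A sketch of how the returned loaders would be used, assuming the call-by-asset-card-name convention is unchanged (the card name below is the one shipped in this repo's cards/ directory):

    from seamless_communication.models.monotonic_decoder.loader import (
        load_monotonic_decoder_config,
        load_monotonic_decoder_model,
    )

    # Resolves the asset card, downloads the checkpoint, runs
    # convert_monotonic_checkpoint, and builds the model.
    config = load_monotonic_decoder_config("seamless_streaming_monotonic_decoder")
    model = load_monotonic_decoder_model("seamless_streaming_monotonic_decoder")
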
+ 10 - 7
src/seamless_communication/models/monotonic_decoder/model.py

@@ -4,23 +4,26 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.

-from typing import Optional, Tuple, final
+from typing import Final, Optional, Tuple, final

+from fairseq2.models import Model
 from fairseq2.models.transformer.frontend import TransformerFrontend
 from fairseq2.nn.incremental_state import IncrementalStateBag
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.projection import Projection
-from overrides import final as finaloverride
+from overrides import final as override
 from torch import Tensor
-from torch.nn import Module

 from seamless_communication.models.monotonic_decoder.monotonic_decoder import (
     MonotonicTransformerDecoder,
 )


+MONOTONIC_DECODER_FAMILY: Final = "monotonic_decoder"
+
+
 @final
-class MonotonicDecoderModel(Module):
+class MonotonicDecoderModel(Model):
     text_decoder_frontend: TransformerFrontend
     text_decoder: MonotonicTransformerDecoder
     final_proj: Projection
@@ -31,13 +34,13 @@ class MonotonicDecoderModel(Module):
         text_decoder: MonotonicTransformerDecoder,
         final_proj: Projection,
     ) -> None:
-        super().__init__()
+        super().__init__(MONOTONIC_DECODER_FAMILY)

         self.text_decoder_frontend = text_decoder_frontend
         self.text_decoder = text_decoder
         self.final_proj = final_proj

-    @finaloverride
+    @override
     def decode(
         self,
         seqs: Tensor,
@@ -59,7 +62,7 @@ class MonotonicDecoderModel(Module):
             state_bag=state_bag,
         )

-    @finaloverride
+    @override
     def project(self, decoder_output: Tensor) -> Tensor:
         logits = self.final_proj(decoder_output)
 

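Migration note: MonotonicDecoderModel now derives from fairseq2.models.Model instead of a bare torch.nn.Module, and passes its family string to the base constructor; the MONOTONIC_DECODER_FAMILY constant is what ties the model class to the setup_model_family registration in the loader above. The pattern in isolation, with an illustrative family name and assuming Model.__init__ takes just the family string, as this diff suggests:

    from typing import Final

    from fairseq2.models import Model

    MY_FAMILY: Final = "my_family"  # illustrative

    class MyModel(Model):
        def __init__(self) -> None:
            # The family string lets loaders and tokenizer registries key
            # off the model instance rather than its concrete class.
            super().__init__(MY_FAMILY)
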
+ 2 - 2
src/seamless_communication/models/monotonic_decoder/monotonic_decoder.py

@@ -16,7 +16,7 @@ from fairseq2.nn.transformer import (
     CausalAttentionMaskFactory,
     create_standard_layer_norm,
 )
-from fairseq2.typing import DataType, Device, finaloverride
+from fairseq2.typing import DataType, Device, override
 from torch import Tensor
 from torch.nn import Module
 
@@ -62,7 +62,7 @@ class MonotonicTransformerDecoder(Module):
             self.model_dim, device=device, dtype=dtype
         )

-    @finaloverride
+    @override
     def forward(
         self,
         seqs: Tensor,

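Migration note: the finaloverride helper is gone in v0.3; every decorated method in this commit switches to the plain override re-exported from fairseq2.typing (as the changed imports above show), which only asserts that a base class declares the method. For example:

    from fairseq2.typing import override
    from torch import Tensor
    from torch.nn import Module

    class Scaled(Module):
        @override  # valid: Module declares forward
        def forward(self, x: Tensor) -> Tensor:
            return 2.0 * x
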
+ 2 - 2
src/seamless_communication/models/monotonic_decoder/monotonic_decoder_layer.py

@@ -15,7 +15,7 @@ from fairseq2.nn.transformer import (
     MultiheadAttention,
     create_standard_layer_norm,
 )
-from fairseq2.typing import DataType, Device, finaloverride
+from fairseq2.typing import DataType, Device, override
 from torch import Tensor
 from torch.nn import Dropout, Module
 
@@ -104,7 +104,7 @@ class MonotonicTransformerDecoderLayer(Module):
         else:
             self.register_module("ffn_dropout", None)

-    @finaloverride
+    @override
     def forward(
         self,
         seqs: Tensor,

+ 2 - 2
src/seamless_communication/models/monotonic_decoder/p_choose.py

@@ -8,7 +8,7 @@ from typing import Optional, final

 import torch
 from fairseq2.nn.projection import Linear
-from fairseq2.typing import DataType, Device, finaloverride
+from fairseq2.typing import DataType, Device, override
 from torch import Tensor
 from torch.nn import AvgPool1d, Module, ModuleList, ReLU
 from torch.nn.parameter import Parameter
@@ -116,7 +116,7 @@ class PChooseLayer(Module):
             ceil_mode=True,
         )

-    @finaloverride
+    @override
     def forward(self, seqs: Tensor, keys: Tensor) -> Tensor:
         q = self.q_energy_proj(seqs)
 

+ 2 - 2
src/seamless_communication/models/pretssel/ecapa_tdnn_builder.py

@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from typing import List, Optional

-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.typing import DataType, Device

 from seamless_communication.models.pretssel.ecapa_tdnn import ECAPA_TDNN
@@ -27,7 +27,7 @@ class EcapaTDNNConfig:
     input_dim: int


-ecapa_tdnn_archs = ArchitectureRegistry[EcapaTDNNConfig]("ecapa_tdnn")
+ecapa_tdnn_archs = ModelArchitectureRegistry[EcapaTDNNConfig]()

 ecapa_tdnn_arch = ecapa_tdnn_archs.decorator
 

+ 12 - 29
src/seamless_communication/models/tokenizer.py

@@ -4,32 +4,26 @@
 # This source code is licensed under the BSD-style license found in the
 # MIT_LICENSE file in the root directory of this source tree.

+from pathlib import Path
 from typing import Optional, Sequence, Set, final

 from fairseq2.data.text import (
-    SentencePieceDecoder,
+    SentencePieceTokenizer,
     SentencePieceEncoder,
-    SentencePieceModel,
-    TextTokenDecoder,
-    TextTokenEncoder,
-    TextTokenizer,
-    vocab_info_from_sentencepiece,
 )
-from fairseq2.data.typing import PathLike
-from fairseq2.typing import Device, finaloverride
+from fairseq2.typing import Device, override


 @final
-class SPMTokenizer(TextTokenizer):
+class SPMTokenizer(SentencePieceTokenizer):
     """Represents standard SPM-based tokenizer used in MT tasks"""

-    model: SentencePieceModel
     langs: Set[str]
     prepend_target_langtok_to_target: bool

     def __init__(
         self,
-        pathname: PathLike,
+        path: Path,
         langs: Sequence[str],
         prepend_target_langtok_to_target: bool = True,
     ) -> None:
@@ -41,20 +35,19 @@ class SPMTokenizer(TextTokenizer):
         :param default_lang:
             The fall-back language if no language is specified.
         """
-        self.langs = set(langs)
-        self.prepend_target_langtok_to_target = prepend_target_langtok_to_target
-
         # Each language is represented by a `__lang__` control symbol.
         control_symbols = [self._lang_tok_to_internal(lang) for lang in sorted(langs)]
-        self.model = SentencePieceModel(pathname, control_symbols)
-        vocab_info = vocab_info_from_sentencepiece(self.model)
-        super().__init__(vocab_info)
+
+        super().__init__(path, control_symbols)
+
+        self.langs = set(langs)
+        self.prepend_target_langtok_to_target = prepend_target_langtok_to_target

     @classmethod
     def _lang_tok_to_internal(cls, lang: str) -> str:
         return f"__{lang}__"

-    @finaloverride
+    @override
     def create_encoder(
         self,
         *,
@@ -63,7 +56,7 @@ class SPMTokenizer(TextTokenizer):
         mode: Optional[str] = None,
         device: Optional[Device] = None,
         pin_memory: bool = False,
-    ) -> TextTokenEncoder:
+    ) -> SentencePieceEncoder:
         """Create a token encoder.

         :param task:
@@ -110,13 +103,3 @@ class SPMTokenizer(TextTokenizer):
             device=device,
             pin_memory=pin_memory,
         )
-
-    @finaloverride
-    def create_raw_encoder(
-        self, *, device: Optional[Device] = None, pin_memory: bool = False
-    ) -> TextTokenEncoder:
-        return SentencePieceEncoder(self.model, device=device, pin_memory=pin_memory)
-
-    @finaloverride
-    def create_decoder(self) -> TextTokenDecoder:
-        return SentencePieceDecoder(self.model)

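Migration note: SPMTokenizer now subclasses the new SentencePieceTokenizer base, which owns the SentencePieceModel, derives the vocabulary info, and supplies create_raw_encoder/create_decoder itself; hence the overrides deleted at the end of this hunk. A minimal sketch of a subclass under that assumption (the path and control symbol are illustrative):

    from pathlib import Path

    from fairseq2.data.text import SentencePieceTokenizer

    class MyTokenizer(SentencePieceTokenizer):
        def __init__(self, path: Path) -> None:
            # The base class loads the .model file and registers the
            # control symbol; no manual vocab_info plumbing is needed.
            super().__init__(path, ["__eng__"])
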
+ 2 - 0
src/seamless_communication/models/unit_extractor/wav2vec2_layer_output.py

@@ -63,8 +63,10 @@ def _xlsr2_1b_v2() -> Wav2Vec2Config:
         final_proj_bias=True,
         temporal_mask_span_len=10,
         max_temporal_mask_prob=0.65,
+        min_num_temporal_mask_spans=2,
         spatial_mask_span_len=10,
         max_spatial_mask_prob=0.0,
+        min_num_spatial_mask_spans=2,
         quantized_dim=1024,
         num_codebooks=2,
         num_codebook_entries=320,

+ 62 - 28
src/seamless_communication/models/unity/builder.py

@@ -7,9 +7,10 @@
 from dataclasses import dataclass
 from typing import Optional, Union

+from fairseq2.data import VocabularyInfo
 from fairseq2.models.conformer import ConformerBlock, ConformerConvolution
 from fairseq2.models.nllb import NllbBuilder, NllbConfig, nllb_archs
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.models.w2vbert import w2vbert_archs
 from fairseq2.models.wav2vec2 import Wav2Vec2EncoderBuilder, Wav2Vec2EncoderConfig
 from fairseq2.nn.projection import TiedProjection
@@ -36,7 +37,7 @@ from seamless_communication.models.unity.adaptor_block import (
     UnitYEncoderAdaptor,
     UnitYTransformerAdaptorLayer,
 )
-from seamless_communication.models.unity.model import UnitYModel
+from seamless_communication.models.unity.model import UNITY_FAMILY, UnitYModel
 from seamless_communication.models.unity.t2u_builder import (
     UnitYNART2UBuilder,
     UnitYT2UBuilder,
@@ -100,7 +101,7 @@ class UnitYConfig:
     """The dropout probability in Transformer layers of the adaptor block."""
     """The dropout probability in Transformer layers of the adaptor block."""
 
 
 
 
-unity_archs = ArchitectureRegistry[UnitYConfig]("unity")
+unity_archs = ModelArchitectureRegistry[UnitYConfig]()
 
 
 unity_arch = unity_archs.decorator
 unity_arch = unity_archs.decorator
 
 
@@ -111,7 +112,15 @@ def _base() -> UnitYConfig:

     mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")

-    mt_model_config.vocab_info.size = 256102  # NLLB-100
+    vocab_info = mt_model_config.vocab_info
+
+    mt_model_config.vocab_info = VocabularyInfo(
+        size=256102,  # NLLB-100
+        unk_idx=vocab_info.unk_idx,
+        bos_idx=vocab_info.bos_idx,
+        eos_idx=vocab_info.eos_idx,
+        pad_idx=vocab_info.pad_idx,
+    )

     t2u_config = unity_t2u_archs.get_config("base")
 
@@ -139,7 +148,15 @@ def _medium() -> UnitYConfig:

     mt_model_config: NllbConfig = nllb_archs.get_config("dense_600m")

-    mt_model_config.vocab_info.size = 256206  # NLLB-200
+    vocab_info = mt_model_config.vocab_info
+
+    mt_model_config.vocab_info = VocabularyInfo(
+        size=256206,  # NLLB-200
+        unk_idx=vocab_info.unk_idx,
+        bos_idx=vocab_info.bos_idx,
+        eos_idx=vocab_info.eos_idx,
+        pad_idx=vocab_info.pad_idx,
+    )

     t2u_config = unity_t2u_archs.get_config("medium")
 
@@ -163,11 +180,19 @@

 @unity_arch("base_v2")
 def _base_v2() -> UnitYConfig:
-    conformer_shaw_encoder_config = conformer_shaw_archs.get_config("600m")
+    conformer_shaw_config = conformer_shaw_archs.get_config("conformer_shaw_600m")

     mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")

-    mt_model_config.vocab_info.size = 256102  # NLLB-100
+    vocab_info = mt_model_config.vocab_info
+
+    mt_model_config.vocab_info = VocabularyInfo(
+        size=256102,  # NLLB-100
+        unk_idx=vocab_info.unk_idx,
+        bos_idx=vocab_info.bos_idx,
+        eos_idx=vocab_info.eos_idx,
+        pad_idx=vocab_info.pad_idx,
+    )

     mt_model_config.max_seq_len = 4096
 
@@ -175,7 +200,7 @@ def _base_v2() -> UnitYConfig:

     return UnitYConfig(
         model_dim=1024,
-        w2v2_encoder_config=conformer_shaw_encoder_config,
+        w2v2_encoder_config=conformer_shaw_config.encoder_config,
         mt_model_config=mt_model_config,
         t2u_config=t2u_config,
         prosody_encoder_config=None,
@@ -193,11 +218,19 @@

 @unity_arch("expressivity_v2")
 def _expressivity_v2() -> UnitYConfig:
-    conformer_shaw_encoder_config = conformer_shaw_archs.get_config("600m")
+    conformer_shaw_config = conformer_shaw_archs.get_config("conformer_shaw_600m")

     mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")

-    mt_model_config.vocab_info.size = 256102  # NLLB-100
+    vocab_info = mt_model_config.vocab_info
+
+    mt_model_config.vocab_info = VocabularyInfo(
+        size=256102,  # NLLB-100
+        unk_idx=vocab_info.unk_idx,
+        bos_idx=vocab_info.bos_idx,
+        eos_idx=vocab_info.eos_idx,
+        pad_idx=vocab_info.pad_idx,
+    )

     mt_model_config.max_seq_len = 10000
 
@@ -207,7 +240,7 @@

     return UnitYConfig(
         model_dim=1024,
-        w2v2_encoder_config=conformer_shaw_encoder_config,
+        w2v2_encoder_config=conformer_shaw_config.encoder_config,
         mt_model_config=mt_model_config,
         t2u_config=t2u_config,
         prosody_encoder_config=prosody_encoder_config,
@@ -263,19 +296,19 @@ class UnitYBuilder:
         :param dtype:
             The data type of module parameters and buffers.
         """
-        if w2v2_encoder_builder.config.model_dim != config.model_dim:
+        if config.w2v2_encoder_config.model_dim != config.model_dim:
             raise ValueError(
-                f"`model_dim` and `model_dim` of `w2v2_encoder_builder.config` must be equal, but are {config.model_dim} and {w2v2_encoder_builder.config.model_dim} instead."
+                f"`config.model_dim` and `config.w2v2_encoder_config.model_dim` must be equal, but are {config.model_dim} and {config.w2v2_encoder_config.model_dim} instead."
             )

-        if mt_model_builder.config.model_dim != config.model_dim:
+        if config.mt_model_config.model_dim != config.model_dim:
             raise ValueError(
-                f"`model_dim` and `model_dim` of `mt_model_builder.config` must be equal, but are {config.model_dim} and {mt_model_builder.config.model_dim} instead."
+                f"`config.model_dim` and `config.mt_model_config.model_dim` must be equal, but are {config.model_dim} and {config.mt_model_config.model_dim} instead."
             )

-        if t2u_builder is not None and t2u_builder.config.model_dim != config.model_dim:
+        if config.t2u_config is not None and config.t2u_config.model_dim != config.model_dim:
             raise ValueError(
-                f"`model_dim` and `model_dim` of `t2u_builder.config` must be equal, but are {config.model_dim} and {t2u_builder.config.model_dim} instead."
+                f"`config.model_dim` and `config.t2u_config.model_dim` must be equal, but are {config.model_dim} and {config.t2u_config.model_dim} instead."
             )

         self.config = config
@@ -337,6 +370,7 @@ class UnitYBuilder:
             text_decoder,
             final_proj,
             t2u_model,
+            self.config.mt_model_config.max_seq_len or 0,
             self.config.mt_model_config.vocab_info,
             prosody_encoder_model,
         )
@@ -367,12 +401,12 @@ class UnitYBuilder:
     def build_adaptor_layer(self, idx: int) -> TransformerEncoderLayer:
         """Build a Transformer-based encoder adaptor layer."""
         self_attn = self.build_adaptor_attention(
-            self.w2v2_encoder_builder.config.num_encoder_attn_heads
+            self.config.w2v2_encoder_config.num_encoder_attn_heads
         )

         ffn = StandardFeedForwardNetwork(
             self.config.model_dim,
-            self.w2v2_encoder_builder.config.ffn_inner_dim,
+            self.config.w2v2_encoder_config.ffn_inner_dim,
             inner_activation=GELU() if self.config.use_gelu else ReLU(),
             bias=True,
             device=self.device,
@@ -396,12 +430,12 @@ class UnitYBuilder:
         # Empirically shown that, in adaptor layers, vanilla MHA performs better
         # than MHA with relative positional encoding.
         self_attn = self.build_adaptor_attention(
-            self.w2v2_encoder_builder.config.num_encoder_attn_heads
+            self.config.w2v2_encoder_config.num_encoder_attn_heads
         )

         conv = ConformerConvolution(
-            self.w2v2_encoder_builder.config.model_dim,
-            self.w2v2_encoder_builder.config.depthwise_conv_kernel_size,
+            self.config.w2v2_encoder_config.model_dim,
+            self.config.w2v2_encoder_config.depthwise_conv_kernel_size,
             device=self.device,
             dtype=self.dtype,
         )
@@ -446,13 +480,13 @@ class NllbWithGELUBuilder(NllbBuilder):
     @override
     def build_ffn(self) -> FeedForwardNetwork:
         return StandardFeedForwardNetwork(
-            self.config.model_dim,
-            self.config.ffn_inner_dim,
+            self._config.model_dim,
+            self._config.ffn_inner_dim,
             bias=True,
             inner_activation=GELU(),
             norm_order=TransformerNormOrder.PRE,
-            device=self.device,
-            dtype=self.dtype,
+            device=self._device,
+            dtype=self._dtype,
         )


@@ -497,11 +531,11 @@ def create_unity_model(

     if config.use_gelu:
         mt_model_builder: NllbBuilder = NllbWithGELUBuilder(
-            config.mt_model_config, device=device, dtype=dtype
+            UNITY_FAMILY, config.mt_model_config, device=device, dtype=dtype
         )
     else:
         mt_model_builder = NllbBuilder(
-            config.mt_model_config, device=device, dtype=dtype
+            UNITY_FAMILY, config.mt_model_config, device=device, dtype=dtype
         )

     unity_builder = UnitYBuilder(

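Migration note: all four architecture presets stop mutating vocab_info.size in place; VocabularyInfo is treated as immutable in v0.3, so a fresh instance is built with the overridden size and the remaining indices copied over. The pattern in isolation (index values are illustrative):

    from fairseq2.data import VocabularyInfo

    old = VocabularyInfo(size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0)

    # Replace, don't mutate: only `size` changes.
    new = VocabularyInfo(
        size=256102,
        unk_idx=old.unk_idx,
        bos_idx=old.bos_idx,
        eos_idx=old.eos_idx,
        pad_idx=old.pad_idx,
    )
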
+ 8 - 33
src/seamless_communication/models/unity/char_tokenizer.py

@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.

+from pathlib import Path
 from typing import Optional, Union, final

 from fairseq2.assets import (
@@ -14,36 +15,24 @@ from fairseq2.assets import (
 )
 from fairseq2.assets.card import AssetCard
 from fairseq2.data.text import (
-    SentencePieceDecoder,
+    SentencePieceTokenizer,
     SentencePieceEncoder,
-    SentencePieceModel,
-    TextTokenDecoder,
-    TextTokenEncoder,
-    TextTokenizer,
-    vocab_info_from_sentencepiece,
 )
-from fairseq2.data.typing import PathLike
-from fairseq2.typing import Device, finaloverride
+from fairseq2.typing import Device, override


 @final
-class CharTokenizer(TextTokenizer):
+class CharTokenizer(SentencePieceTokenizer):
     """A character-level tokenizer used during non-autoregressive T2U decoding."""

-    model: SentencePieceModel
-
-    def __init__(self, pathname: PathLike) -> None:
+    def __init__(self, path: Path) -> None:
         """
         :param pathname:
             The pathname of the SentencePiece model file.
         """
-        self.model = SentencePieceModel(pathname)
-
-        vocab_info = vocab_info_from_sentencepiece(self.model)
-
-        super().__init__(vocab_info)
+        super().__init__(path)

-    @finaloverride
+    @override
     def create_encoder(
         self,
         task: Optional[str] = None,
@@ -51,24 +40,10 @@ class CharTokenizer(TextTokenizer):
         mode: Optional[str] = None,
         device: Optional[Device] = None,
         pin_memory: bool = False,
-    ) -> TextTokenEncoder:
+    ) -> SentencePieceEncoder:
         """Creates a character level encoder."""
-        return SentencePieceEncoder(
-            self.model,
-            device=device,
-            pin_memory=pin_memory,
-        )
-
-    @finaloverride
-    def create_raw_encoder(
-        self, *, device: Optional[Device] = None, pin_memory: bool = False
-    ) -> TextTokenEncoder:
         return SentencePieceEncoder(self.model, device=device, pin_memory=pin_memory)

-    @finaloverride
-    def create_decoder(self) -> TextTokenDecoder:
-        return SentencePieceDecoder(self.model)
-

 class UnitYCharTokenizerLoader:
     """Loads character-level tokenizers of UnitY models."""

+ 2 - 2
src/seamless_communication/models/unity/fft_decoder.py

@@ -10,7 +10,7 @@ from fairseq2.nn.module_list import ModuleList
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.transformer import TransformerNormOrder, create_standard_layer_norm
-from fairseq2.typing import DataType, Device, finaloverride
+from fairseq2.typing import DataType, Device, override
 from torch import Tensor
 from torch.nn import Module
 
@@ -61,7 +61,7 @@

         self.norm_order = norm_order

-    @finaloverride
+    @override
     def forward(
         self,
         seqs: Tensor,

+ 3 - 3
src/seamless_communication/models/unity/fft_decoder_layer.py

@@ -9,7 +9,7 @@ from typing import Optional, Tuple, final
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
 from fairseq2.nn.transformer import MultiheadAttention, create_standard_layer_norm
-from fairseq2.typing import DataType, Device, finaloverride
+from fairseq2.typing import DataType, Device, override
 from torch import Tensor
 from torch.nn import Conv1d, Dropout, Module, ReLU
 
@@ -71,7 +71,7 @@ class Conv1dBlock(Module):
             dtype=dtype,
         )

-    @finaloverride
+    @override
     def forward(self, seqs: Tensor, padding_mask: Optional[PaddingMask]) -> Tensor:
         # Ensure that we do not leak padded positions in the convolution layer.
         seqs = apply_padding_mask(seqs, padding_mask)
@@ -173,7 +173,7 @@ class FeedForwardTransformerLayer(Module):
         else:
             self.register_module("film", None)

-    @finaloverride
+    @override
     def forward(
         self,
         seqs: Tensor,

+ 78 - 76
src/seamless_communication/models/unity/loader.py

@@ -4,14 +4,14 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, List, Mapping, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union

 import torch
-from fairseq2.assets import AssetStore, asset_store, download_manager
+from fairseq2.assets import AssetStore, asset_store
 from fairseq2.assets.card import AssetCard, AssetCardFieldNotFoundError
-from fairseq2.models.nllb import NllbConfig
-from fairseq2.models.nllb.loader import NllbTokenizerLoader
-from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models.nllb import NllbConfig, load_nllb_tokenizer
+from fairseq2.models import setup_model_family
+from fairseq2.data.text import register_text_tokenizer
 from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint

 from seamless_communication.models.unity.builder import (
@@ -20,13 +20,13 @@ from seamless_communication.models.unity.builder import (
     unity_archs,
 )
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
-from seamless_communication.models.unity.model import UnitYModel
+from seamless_communication.models.unity.model import UNITY_FAMILY
 from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer


 def convert_unity_checkpoint(
-    checkpoint: Mapping[str, Any], config: UnitYConfig
-) -> Mapping[str, Any]:
+    checkpoint: Dict[str, Any], config: UnitYConfig
+) -> Dict[str, Any]:
     state_dict = checkpoint["model"]

     # Check if we have a fairseq2 checkpoint.
@@ -39,7 +39,11 @@ def convert_unity_checkpoint(

     state_dict = checkpoint["model"]

-    keys_to_delete = []
+    keys_to_delete = [
+        "speech_encoder_frontend.pos_encoder.conv.bias",
+        "speech_encoder_frontend.pos_encoder.conv.weight_g",
+        "speech_encoder_frontend.pos_encoder.conv.weight_v",
+    ]

     # ExpressiveUnitY model (from multi_arch codebase)
     if config.prosody_encoder_config is not None:
@@ -203,42 +207,42 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
         # fmt: off

         # Speech Encoder
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.pos_conv\.0\.":                                    r"speech_encoder_frontend.pos_encoder.conv.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.layer_norm\.":                                              r"speech_encoder_frontend.post_extract_layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.post_extract_proj\.":                                       r"speech_encoder_frontend.model_dim_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.0\.":             r"speech_encoder_frontend.feature_extractor.layers.\1.conv.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.":          r"speech_encoder_frontend.feature_extractor.layers.\1.layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.":                    r"speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
-
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.inner.layers.\1.conv.batch_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.":     r"speech_encoder.inner.layers.\1.conv.layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.inner.layers.\1.conv.depthwise_conv.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.inner.layers.\1.conv_layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv1.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv2.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         r"speech_encoder.inner.layers.\1.ffn\2_layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                r"speech_encoder.inner.layers.\1.ffn\2.inner_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                r"speech_encoder.inner.layers.\1.ffn\2.output_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":         r"speech_encoder.inner.layers.\1.self_attn_layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.":          r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.":          r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.":          r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.":        r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.":   r"speech_encoder.inner.layers.\1.self_attn.sdpa.rel_k_embed.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":          r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.":        r"speech_encoder.inner.layers.\1.self_attn.sdpa.r_proj.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.u_bias",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.v_bias",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.final_layer_norm\.":             r"speech_encoder.inner.layers.\1.layer_norm.",
-        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.":                                     r"speech_encoder.inner.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.pos_conv\.0\.":                                    "speech_encoder_frontend.pos_encoder.conv.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.layer_norm\.":                                              "speech_encoder_frontend.post_extract_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.post_extract_proj\.":                                       "speech_encoder_frontend.model_dim_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.0\.":             "speech_encoder_frontend.feature_extractor.layers.\\1.conv.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.":          "speech_encoder_frontend.feature_extractor.layers.\\1.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.":                    "speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
+
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      "speech_encoder.inner.layers.\\1.conv.batch_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.":     "speech_encoder.inner.layers.\\1.conv.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  "speech_encoder.inner.layers.\\1.conv.depthwise_conv.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      "speech_encoder.inner.layers.\\1.conv_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": "speech_encoder.inner.layers.\\1.conv.pointwise_conv1.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": "speech_encoder.inner.layers.\\1.conv.pointwise_conv2.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         "speech_encoder.inner.layers.\\1.ffn\\2_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                "speech_encoder.inner.layers.\\1.ffn\\2.inner_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                "speech_encoder.inner.layers.\\1.ffn\\2.output_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":         "speech_encoder.inner.layers.\\1.self_attn_layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.":          "speech_encoder.inner.layers.\\1.self_attn.q_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.":          "speech_encoder.inner.layers.\\1.self_attn.k_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.":          "speech_encoder.inner.layers.\\1.self_attn.v_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.":        "speech_encoder.inner.layers.\\1.self_attn.output_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.":            "speech_encoder.inner.layers.\\1.self_attn.q_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.":            "speech_encoder.inner.layers.\\1.self_attn.k_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.":            "speech_encoder.inner.layers.\\1.self_attn.v_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.":   "speech_encoder.inner.layers.\\1.self_attn.sdpa.rel_k_embed.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":          "speech_encoder.inner.layers.\\1.self_attn.output_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.":        "speech_encoder.inner.layers.\\1.self_attn.sdpa.r_proj.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u":          "speech_encoder.inner.layers.\\1.self_attn.sdpa.u_bias",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v":          "speech_encoder.inner.layers.\\1.self_attn.sdpa.v_bias",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.final_layer_norm\.":             "speech_encoder.inner.layers.\\1.layer_norm.",
+        fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.":                                     "speech_encoder.inner.layer_norm.",
 
 
         # Speech Encoder Adaptor
         # Speech Encoder Adaptor
-        fr"^{encoder_key}\.adaptor\.proj\.0\.": r"speech_encoder.proj1.",
-        fr"^{encoder_key}\.adaptor\.proj\.2\.": r"speech_encoder.proj2.",
-        fr"^{encoder_key}\.adaptor\.out_ln\.":  r"speech_encoder.layer_norm.",
+        fr"^{encoder_key}\.adaptor\.proj\.0\.": "speech_encoder.proj1.",
+        fr"^{encoder_key}\.adaptor\.proj\.2\.": "speech_encoder.proj2.",
+        fr"^{encoder_key}\.adaptor\.out_ln\.":  "speech_encoder.layer_norm.",
 
 
         # Text Encoder
         # Text Encoder
         r"^text_encoder\.embed_tokens\.":                              r"text_encoder_frontend.embed.",
         r"^text_encoder\.embed_tokens\.":                              r"text_encoder_frontend.embed.",
@@ -264,13 +268,13 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
     if config.w2v2_encoder_config.use_conformer:
         key_map.update(
             {
-                fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner_layer_norm."
+                fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": "speech_encoder.inner_layer_norm."
             }
         )
     else:
         key_map.update(
             {
-                rf"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm."
+                rf"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": "speech_encoder.inner.layer_norm."
             }
         )
     # fmt: on
@@ -279,20 +283,20 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
         key_map.update(
             {
                 # fmt: off
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":          r"speech_encoder.adaptor_layers.\1.block.self_attn.output_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":                    r"speech_encoder.adaptor_layers.\1.block.self_attn.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.":         r"speech_encoder.adaptor_layers.\1.block.self_attn_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         r"speech_encoder.adaptor_layers.\1.block.ffn\2_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                r"speech_encoder.adaptor_layers.\1.block.ffn\2.inner_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                r"speech_encoder.adaptor_layers.\1.block.ffn\2.output_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.adaptor_layers.\1.block.conv.batch_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.adaptor_layers.\1.block.conv.depthwise_conv.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.adaptor_layers.\1.block.conv_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv1.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv2.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":             r"speech_encoder.adaptor_layers.\1.block.layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_ln\.":                      r"speech_encoder.adaptor_layers.\1.layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_pool\.1\.":                 r"speech_encoder.adaptor_layers.\1.conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":          "speech_encoder.adaptor_layers.\\1.block.self_attn.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":                    "speech_encoder.adaptor_layers.\\1.block.self_attn.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.":         "speech_encoder.adaptor_layers.\\1.block.self_attn_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         "speech_encoder.adaptor_layers.\\1.block.ffn\\2_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                "speech_encoder.adaptor_layers.\\1.block.ffn\\2.inner_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                "speech_encoder.adaptor_layers.\\1.block.ffn\\2.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      "speech_encoder.adaptor_layers.\\1.block.conv.batch_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  "speech_encoder.adaptor_layers.\\1.block.conv.depthwise_conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      "speech_encoder.adaptor_layers.\\1.block.conv_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": "speech_encoder.adaptor_layers.\\1.block.conv.pointwise_conv1.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": "speech_encoder.adaptor_layers.\\1.block.conv.pointwise_conv2.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":             "speech_encoder.adaptor_layers.\\1.block.layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_ln\.":                      "speech_encoder.adaptor_layers.\\1.layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_pool\.1\.":                 "speech_encoder.adaptor_layers.\\1.conv.",
                 # fmt: on
             }
         )
@@ -300,15 +304,15 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
         key_map.update(
             {
                 # fmt: off
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_layer_norm\.":  r"speech_encoder.adaptor_layers.\1.residual_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_pool\.1\.":     r"speech_encoder.adaptor_layers.\1.residual_conv.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.attn_pool\.1\.":         r"speech_encoder.adaptor_layers.\1.self_attn_conv.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":  r"speech_encoder.adaptor_layers.\1.self_attn.output_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":            r"speech_encoder.adaptor_layers.\1.self_attn.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.adaptor_layers.\1.self_attn_layer_norm.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc1\.":                  r"speech_encoder.adaptor_layers.\1.ffn.inner_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc2\.":                  r"speech_encoder.adaptor_layers.\1.ffn.output_proj.",
-                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":     r"speech_encoder.adaptor_layers.\1.ffn_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_layer_norm\.":  "speech_encoder.adaptor_layers.\\1.residual_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_pool\.1\.":     "speech_encoder.adaptor_layers.\\1.residual_conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.attn_pool\.1\.":         "speech_encoder.adaptor_layers.\\1.self_attn_conv.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":  "speech_encoder.adaptor_layers.\\1.self_attn.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":            "speech_encoder.adaptor_layers.\\1.self_attn.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": "speech_encoder.adaptor_layers.\\1.self_attn_layer_norm.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc1\.":                  "speech_encoder.adaptor_layers.\\1.ffn.inner_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc2\.":                  "speech_encoder.adaptor_layers.\\1.ffn.output_proj.",
+                fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":     "speech_encoder.adaptor_layers.\\1.ffn_layer_norm.",
                 # fmt: on
             }
         )
@@ -389,20 +393,18 @@ def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
     return key_map


-load_unity_config = ConfigLoader[UnitYConfig](asset_store, unity_archs)
-
-
-load_unity_model = ModelLoader[UnitYModel, UnitYConfig](
-    asset_store,
-    download_manager,
-    load_unity_config,
+load_unity_model, load_unity_config = setup_model_family(
+    UNITY_FAMILY,
+    UnitYConfig,
     create_unity_model,
+    unity_archs,
     convert_unity_checkpoint,
     restrict_checkpoints=False,
 )

+load_unity_text_tokenizer = load_nllb_tokenizer

-load_unity_text_tokenizer = NllbTokenizerLoader(asset_store, download_manager)
+register_text_tokenizer(UNITY_FAMILY, load_unity_text_tokenizer)


 class UnitYUnitTokenizerLoader:

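Migration note: two incidental cleanups ride along in this file. First, the replacement strings in _fairseq_key_map drop the raw-string prefix; r"prefix.\1." and "prefix.\\1." spell the same backreference, as the check below shows. Second, the NLLB tokenizer loader is now attached to the "unity" family through register_text_tokenizer instead of being instantiated by hand.

    import re

    pattern = re.compile(r"^layers\.([0-9]+)\.")

    a = pattern.sub(r"blocks.\1.", "layers.3.weight")
    b = pattern.sub("blocks.\\1.", "layers.3.weight")

    assert a == b == "blocks.3.weight"
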
+ 13 - 8
src/seamless_communication/models/unity/model.py

@@ -5,7 +5,7 @@
 # MIT_LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass
-from typing import Optional, Tuple, Union, final
+from typing import Final, Optional, Tuple, Union, final

 from fairseq2.data import VocabularyInfo
 from fairseq2.models.encoder_decoder import EncoderDecoderModel
@@ -23,6 +23,8 @@ from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN
 from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
 from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend

+UNITY_FAMILY: Final = "unity"
+

 @final
 class UnitYModel(EncoderDecoderModel):
@@ -55,13 +57,14 @@ class UnitYModel(EncoderDecoderModel):
         text_decoder: Optional[TransformerDecoder],
         final_proj: Optional[Projection],
         t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None],
+        max_target_seq_len: int,
         target_vocab_info: VocabularyInfo,
         prosody_encoder_model: Optional[ECAPA_TDNN] = None,
         input_modality: str = "speech",
     ) -> None:
         model_dim = speech_encoder.model_dim

-        super().__init__(model_dim, target_vocab_info)
+        super().__init__(UNITY_FAMILY, model_dim, max_target_seq_len, target_vocab_info)

         self.input_modality = input_modality
 
@@ -190,7 +193,7 @@

         logits = self.final_proj(decoder_output)

-        return SequenceModelOutput(logits, self.target_vocab_info)
+        return SequenceModelOutput(logits, self.target_vocab_info.pad_idx)


 @final
@@ -209,11 +212,12 @@ class UnitYX2TModel(EncoderDecoderModel):
         decoder_frontend: TransformerFrontend,
         decoder: TransformerDecoder,
         final_proj: Projection,
+        max_target_seq_len: int,
         target_vocab_info: VocabularyInfo,
     ) -> None:
         model_dim = encoder.model_dim

-        super().__init__(model_dim, target_vocab_info)
+        super().__init__(UNITY_FAMILY, model_dim, max_target_seq_len, target_vocab_info)

         self.encoder_frontend = encoder_frontend
         self.encoder = encoder
@@ -257,7 +261,7 @@ class UnitYX2TModel(EncoderDecoderModel):
     ) -> SequenceModelOutput:
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
         logits = self.final_proj(decoder_output)
 
 
-        return SequenceModelOutput(logits, self.target_vocab_info)
+        return SequenceModelOutput(logits, self.target_vocab_info.pad_idx)
 
 
 
 
 @final
 @final
@@ -276,9 +280,10 @@ class UnitYT2UModel(EncoderDecoderModel):
         decoder_frontend: TransformerFrontend,
         decoder_frontend: TransformerFrontend,
         decoder: TransformerDecoder,
         decoder: TransformerDecoder,
         final_proj: Projection,
         final_proj: Projection,
+        max_target_seq_len: int,
         target_vocab_info: VocabularyInfo,
         target_vocab_info: VocabularyInfo,
     ) -> None:
     ) -> None:
-        super().__init__(decoder.model_dim, target_vocab_info)
+        super().__init__(UNITY_FAMILY, decoder.model_dim, max_target_seq_len, target_vocab_info)
 
 
         if encoder is not None:
         if encoder is not None:
             self.encoder = encoder
             self.encoder = encoder
@@ -324,7 +329,7 @@ class UnitYT2UModel(EncoderDecoderModel):
     ) -> SequenceModelOutput:
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
         logits = self.final_proj(decoder_output)
 
 
-        return SequenceModelOutput(logits, self.target_vocab_info)
+        return SequenceModelOutput(logits, self.target_vocab_info.pad_idx)
 
 
 
 
 @final
 @final
@@ -438,7 +443,7 @@ class UnitYNART2UModel(Module):
     def project(self, decoder_output: Tensor) -> SequenceModelOutput:
     def project(self, decoder_output: Tensor) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
         logits = self.final_proj(decoder_output)
 
 
-        return SequenceModelOutput(logits, self.target_vocab_info)
+        return SequenceModelOutput(logits, self.target_vocab_info.pad_idx)
 
 
 
 
 @dataclass
 @dataclass
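
The same pattern recurs in every class of this file: fairseq2 v0.3's EncoderDecoderModel constructor takes the family name and the maximum target sequence length, and SequenceModelOutput is now built from the pad index alone. A minimal sketch of a conforming subclass, assuming only the signatures visible in this diff:

    # Sketch: the v0.3 EncoderDecoderModel constructor contract.
    from typing import Final

    from fairseq2.data import VocabularyInfo
    from fairseq2.models.encoder_decoder import EncoderDecoderModel

    MY_FAMILY: Final = "my_family"  # hypothetical family name

    class MyModel(EncoderDecoderModel):
        """Sketch only; the abstract encode/decode methods are omitted."""

        def __init__(
            self,
            model_dim: int,
            max_target_seq_len: int,
            target_vocab_info: VocabularyInfo,
        ) -> None:
            # Family name and max target length now live in the base class.
            super().__init__(MY_FAMILY, model_dim, max_target_seq_len, target_vocab_info)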

+ 2 - 2
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -15,7 +15,7 @@ from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.position_encoder import PositionEncoder
 from fairseq2.nn.transformer import create_standard_layer_norm
-from fairseq2.typing import DataType, Device, finaloverride
+from fairseq2.typing import DataType, Device, override
 from torch import Tensor
 from torch.nn import Dropout, Module, Parameter

@@ -296,7 +296,7 @@ class NARDecoderFrontend(Module):

         return seqs

-    @finaloverride
+    @override
     def forward(
         self,
         encoder_output: Tensor,
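
Here v0.3 simply renames the decorator: finaloverride becomes override, with the same role of marking intentional overrides of a base-class method. A minimal sketch:

    # Sketch: v0.3's override decorator (replaces v0.2's finaloverride).
    from fairseq2.typing import override
    from torch.nn import Module

    class Base(Module):
        def describe(self) -> str:
            return "base"

    class Child(Base):
        @override  # flags the intentional override for static checkers
        def describe(self) -> str:
            return "child"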

+ 5 - 6
src/seamless_communication/models/unity/t2u_builder.py

@@ -6,15 +6,14 @@
 from dataclasses import dataclass
 from typing import Literal, Optional, Union

-from fairseq2.assets import asset_store, download_manager
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import VocabularyInfo
-from fairseq2.models.nllb.loader import NllbTokenizerLoader
+from fairseq2.models.nllb import load_nllb_tokenizer
 from fairseq2.models.transformer import (
     TransformerEmbeddingFrontend,
     TransformerFrontend,
 )
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
 from fairseq2.nn.projection import Linear, Projection, TiedProjection
@@ -131,8 +130,7 @@ class UnitYT2UConfig:
     """The dimensionality of prosody encoder (e.g. ECAPA_TDNN) output"""


-unity_t2u_archs = ArchitectureRegistry[UnitYT2UConfig]("unity_t2u")
-
+unity_t2u_archs = ModelArchitectureRegistry[UnitYT2UConfig]()

 unity_t2u_arch = unity_t2u_archs.decorator

@@ -329,6 +327,7 @@ class UnitYT2UBuilder:
             decoder_frontend,
             decoder,
             final_proj,
+            self.config.unit_max_seq_len,
             self.config.target_vocab_info,
         )

@@ -598,7 +597,7 @@ class UnitYNART2UBuilder:
             self.config.nar_decoder_frontend_config
         )

-        nllb_tokenizer = NllbTokenizerLoader(asset_store, download_manager)(
+        nllb_tokenizer = load_nllb_tokenizer(
             self.config.nar_decoder_config.model_name_or_card
         )
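The registry change is the same across all builders in this commit: ModelArchitectureRegistry no longer takes a name string, since the family name now travels with the loader setup. A sketch of registering an architecture under the v0.3 registry, assuming the decorator keeps the name-plus-factory form used by unity_t2u_arch above (the config class and arch name are illustrative):

    # Sketch: v0.3 architecture registration.
    from dataclasses import dataclass

    from fairseq2.models.architecture_registry import ModelArchitectureRegistry

    @dataclass
    class MyConfig:
        model_dim: int = 512

    my_archs = ModelArchitectureRegistry[MyConfig]()
    my_arch = my_archs.decorator

    @my_arch("base")  # "base" is a hypothetical architecture name
    def _base() -> MyConfig:
        return MyConfig(model_dim=1024)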

+ 2 - 2
src/seamless_communication/models/vocoder/builder.py

@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional

-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.architecture_registry import ModelArchitectureRegistry
 from fairseq2.typing import DataType, Device

 from seamless_communication.models.vocoder.codehifigan import CodeGenerator
@@ -34,7 +34,7 @@ class VocoderConfig:
     lang_spkr_idx_map: Dict[str, Any]


-vocoder_archs = ArchitectureRegistry[VocoderConfig]("vocoder_code_hifigan")
+vocoder_archs = ModelArchitectureRegistry[VocoderConfig]()

 vocoder_arch = vocoder_archs.decorator


+ 9 - 13
src/seamless_communication/models/vocoder/loader.py

@@ -4,22 +4,21 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.

-from typing import Any, Mapping
+from typing import Any, Dict

-from fairseq2.assets import asset_store, download_manager
-from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models import setup_model_family

 from seamless_communication.models.vocoder.builder import (
     VocoderConfig,
     create_vocoder_model,
     vocoder_archs,
 )
-from seamless_communication.models.vocoder.vocoder import Vocoder
+from seamless_communication.models.vocoder.vocoder import VOCODER_CODE_HIFIGAN_FAMILY


 def convert_vocoder_checkpoint(
-    checkpoint: Mapping[str, Any], config: VocoderConfig
-) -> Mapping[str, Any]:
+    checkpoint: Dict[str, Any], config: VocoderConfig
+) -> Dict[str, Any]:
     if (
         "model" in checkpoint
         and "code_generator.resblocks.0.convs1.0.weight_g" in checkpoint["model"]
@@ -36,13 +35,10 @@ def convert_vocoder_checkpoint(
     return checkpoint


-load_vocoder_config = ConfigLoader[VocoderConfig](asset_store, vocoder_archs)
-
-
-load_vocoder_model = ModelLoader[Vocoder, VocoderConfig](
-    asset_store,
-    download_manager,
-    load_vocoder_config,
+load_vocoder_model, load_vocoder_config = setup_model_family(
+    VOCODER_CODE_HIFIGAN_FAMILY,
+    VocoderConfig,
     create_vocoder_model,
+    vocoder_archs,
     convert_vocoder_checkpoint,
 )
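
setup_model_family collapses the old ConfigLoader/ModelLoader pair into a single call that returns both loaders, wired to the family name, config class, model factory, architecture registry, and checkpoint converter. A usage sketch, assuming the returned callables accept an asset card name as the earlier loaders did:

    # Sketch: loading a vocoder through the v0.3 family loaders.
    from seamless_communication.models.vocoder.loader import (
        load_vocoder_config,
        load_vocoder_model,
    )

    config = load_vocoder_config("vocoder_v2")   # card name used for illustration
    vocoder = load_vocoder_model("vocoder_v2")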

+ 6 - 5
src/seamless_communication/models/vocoder/vocoder.py

@@ -4,21 +4,22 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, Optional, List, Union
+from typing import Any, Dict, Optional, Final, List, Union
 import torch
 from torch import Tensor
-from torch.nn import Module
+from fairseq2.models import Model

 from seamless_communication.models.vocoder.codehifigan import CodeGenerator

+VOCODER_CODE_HIFIGAN_FAMILY: Final = "vocoder_code_hifigan"

-class Vocoder(Module):
+class Vocoder(Model):
     def __init__(
         self,
         code_generator: CodeGenerator,
         lang_spkr_idx_map: Dict[str, Any],
     ):
-        super().__init__()
+        super().__init__(VOCODER_CODE_HIFIGAN_FAMILY)
         self.code_generator = code_generator
         self.lang_spkr_idx_map = lang_spkr_idx_map

@@ -29,7 +30,7 @@ class Vocoder(Module):
         spkr_list: Union[Optional[List[int]], int] = None,
         dur_prediction: bool = True,
     ) -> Tensor:
-        # TODO: Do we need this backward compatibility, or just update all calling sites?
+        # TODO: Do we need this backward compatibility, or just update all calling sites?
         if len(units.shape) == 1:
             units = units.unsqueeze(0) # add batch dim
         if isinstance(lang_list, str):
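
Models that previously derived from torch.nn.Module now derive from fairseq2's Model, passing their family name to the base constructor; the Vocoder change above is the template. A minimal sketch, with the family name marked as hypothetical:

    # Sketch: a v0.3 Model subclass carrying its family name.
    from typing import Final

    from fairseq2.models import Model

    MY_FAMILY: Final = "my_family"  # hypothetical family name

    class MyVocoder(Model):
        def __init__(self) -> None:
            # The family string ties the module to its loaders and registries.
            super().__init__(MY_FAMILY)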

+ 1 - 2
src/seamless_communication/toxicity/etox_bad_word_checker.py

@@ -16,7 +16,6 @@ from fairseq2.assets import (
     asset_store as base_asset_store,
     download_manager as base_download_manager,
 )
-from fairseq2.data import StringLike
 from fairseq2.data.text import SentencePieceEncoder, SentencePieceModel


@@ -116,7 +115,7 @@ class ETOXBadWordChecker:

     @staticmethod
     def _contains_tokens(
-        text_tokens: List[StringLike], word_tokens: List[StringLike]
+        text_tokens: List[str], word_tokens: List[str]
     ) -> bool:
         for i in range(len(text_tokens) - len(word_tokens) + 1):
             for j in range(len(word_tokens)):

+ 7 - 8
src/seamless_communication/toxicity/mintox.py

@@ -18,7 +18,6 @@ from seamless_communication.toxicity.etox_bad_word_checker import (
 )
 from fairseq2.generation import BannedSequenceProcessor
 from fairseq2.data.text.text_tokenizer import TextTokenizer
-from fairseq2.data.typing import StringLike
 from fairseq2.typing import Device
 from fairseq2.data import SequenceData
 from fairseq2.nn.padding import get_seqs_and_padding_mask
@@ -32,8 +31,8 @@ logger = logging.getLogger(__name__)


 def _extract_bad_words_with_batch_indices(
-    source_texts: List[StringLike],
-    target_texts: List[StringLike],
+    source_texts: List[str],
+    target_texts: List[str],
     source_lang: str,
     target_lang: str,
     bad_word_checker: ETOXBadWordChecker,
@@ -54,9 +53,9 @@ def _extract_bad_words_with_batch_indices(


 def _replace_with_new_text_output_in_batch(
-    original_texts: List[StringLike],
+    original_texts: List[str],
     indices_with_toxicity: List[int],
-    new_texts: List[StringLike],
+    new_texts: List[str],
 ) -> None:
     new_idx = 0
     # indices_with_toxicity is a small list, using list should be fast enough.
@@ -100,8 +99,8 @@ def mintox_pipeline(
     model_input: SequenceData,
     input_modality: "Modality",
     output_modality: "Modality",
-    src_texts: List[StringLike],
-    original_texts: List[StringLike],
+    src_texts: List[str],
+    original_texts: List[str],
     original_units: Optional[Tensor] = None,
     unit_generation_ngram_filtering: bool = False,
     text_generation_opts: Optional[SequenceGeneratorOptions] = None,
@@ -109,7 +108,7 @@ def mintox_pipeline(
     bad_word_checker: ETOXBadWordChecker = None,
     duration_factor: float = 1.0,
     prosody_encoder_input: Optional[SequenceData] = None,
-) -> Tuple[List[StringLike], Optional[Tensor]]:
+) -> Tuple[List[str], Optional[Tensor]]:
     """MinTox: Mitigation at INference time of added TOXicity."""
     from seamless_communication.inference.translator import Modality, Translator