Browse Source

Revise project structure (#82)

Can Balioglu 1 year ago
parent
commit
9daf4692ef
84 changed files with 323 additions and 405 deletions
  1. 3 3
      README.md
  2. 1 1
      demo/app.py
  3. 2 1
      dev_requirements.txt
  4. 32 0
      pyproject.toml
  5. 0 7
      requirements.txt
  6. 0 0
      scripts/install_devfair.sh
  7. 0 0
      scripts/install_fairaws.sh
  8. 16 50
      setup.py
  9. 15 0
      src/seamless_communication/__init__.py
  10. 0 9
      src/seamless_communication/assets/__init__.py
  11. 0 27
      src/seamless_communication/assets/download_manager.py
  12. 0 22
      src/seamless_communication/assets/store.py
  13. 0 0
      src/seamless_communication/cards/seamlessM4T_large.yaml
  14. 0 0
      src/seamless_communication/cards/seamlessM4T_medium.yaml
  15. 0 0
      src/seamless_communication/cards/seamlessM4T_v2_large.yaml
  16. 0 0
      src/seamless_communication/cards/unity_nllb-100.yaml
  17. 0 0
      src/seamless_communication/cards/unity_nllb-200.yaml
  18. 0 0
      src/seamless_communication/cards/vocoder_36langs.yaml
  19. 0 0
      src/seamless_communication/cards/vocoder_v2.yaml
  20. 0 0
      src/seamless_communication/cards/xlsr2_1b_v2.yaml
  21. 0 0
      src/seamless_communication/cli/__init__.py
  22. 0 0
      src/seamless_communication/cli/eval_utils/__init__.py
  23. 6 5
      src/seamless_communication/cli/eval_utils/compute_metrics.py
  24. 0 1
      src/seamless_communication/cli/eval_utils/lang_mapping.py
  25. 0 0
      src/seamless_communication/cli/m4t/__init__.py
  26. 0 0
      src/seamless_communication/cli/m4t/audio_to_units/README.md
  27. 0 0
      src/seamless_communication/cli/m4t/audio_to_units/__init__.py
  28. 2 1
      src/seamless_communication/cli/m4t/audio_to_units/audio_to_units.py
  29. 0 0
      src/seamless_communication/cli/m4t/evaluate/README.md
  30. 0 0
      src/seamless_communication/cli/m4t/evaluate/__init__.py
  31. 15 18
      src/seamless_communication/cli/m4t/evaluate/evaluate.py
  32. 0 0
      src/seamless_communication/cli/m4t/finetune/README.md
  33. 0 0
      src/seamless_communication/cli/m4t/finetune/__init__.py
  34. 0 0
      src/seamless_communication/cli/m4t/finetune/dataloader.py
  35. 1 1
      src/seamless_communication/cli/m4t/finetune/dataset.py
  36. 0 0
      src/seamless_communication/cli/m4t/finetune/dist_utils.py
  37. 1 1
      src/seamless_communication/cli/m4t/finetune/finetune.py
  38. 11 9
      src/seamless_communication/cli/m4t/finetune/trainer.py
  39. 0 0
      src/seamless_communication/cli/m4t/predict/README.md
  40. 4 2
      src/seamless_communication/cli/m4t/predict/__init__.py
  41. 6 9
      src/seamless_communication/cli/m4t/predict/predict.py
  42. 0 0
      src/seamless_communication/cli/m4t/train/__init__.py
  43. 3 3
      src/seamless_communication/cli/m4t/train/configs.py
  44. 10 7
      src/seamless_communication/cli/m4t/train/dataloader.py
  45. 0 0
      src/seamless_communication/cli/m4t/train/dist_utils.py
  46. 26 17
      src/seamless_communication/cli/m4t/train/model.py
  47. 0 0
      src/seamless_communication/cli/m4t/train/recipes/asr_small.yaml
  48. 0 0
      src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_transc.yaml
  49. 0 0
      src/seamless_communication/cli/m4t/train/recipes/large_M4T_v1.yaml
  50. 0 0
      src/seamless_communication/cli/m4t/train/recipes/m4t_v1_train_manifests.txt
  51. 6 5
      src/seamless_communication/cli/m4t/train/run_training.py
  52. 0 1
      src/seamless_communication/cli/m4t/train/run_with_slurm.py
  53. 17 13
      src/seamless_communication/cli/m4t/train/trainer.py
  54. 16 0
      src/seamless_communication/inference/__init__.py
  55. 5 7
      src/seamless_communication/inference/generator.py
  56. 3 2
      src/seamless_communication/inference/ngram_repeat_block_processor.py
  57. 9 9
      src/seamless_communication/inference/translator.py
  58. 0 14
      src/seamless_communication/models/inference/__init__.py
  59. 5 5
      src/seamless_communication/models/unit_extractor/__init__.py
  60. 3 3
      src/seamless_communication/models/unit_extractor/kmeans.py
  61. 13 16
      src/seamless_communication/models/unit_extractor/unit_extractor.py
  62. 10 30
      src/seamless_communication/models/unit_extractor/wav2vec2_layer_output.py
  63. 9 12
      src/seamless_communication/models/unity/__init__.py
  64. 0 1
      src/seamless_communication/models/unity/adaptor_block.py
  65. 3 5
      src/seamless_communication/models/unity/builder.py
  66. 6 3
      src/seamless_communication/models/unity/char_tokenizer.py
  67. 5 7
      src/seamless_communication/models/unity/length_regulator.py
  68. 7 8
      src/seamless_communication/models/unity/loader.py
  69. 4 7
      src/seamless_communication/models/unity/nar_decoder.py
  70. 5 9
      src/seamless_communication/models/unity/nar_decoder_frontend.py
  71. 4 5
      src/seamless_communication/models/unity/nar_decoder_layer.py
  72. 14 16
      src/seamless_communication/models/unity/t2u_builder.py
  73. 1 1
      src/seamless_communication/models/vocoder/codehifigan.py
  74. 1 1
      src/seamless_communication/models/vocoder/loader.py
  75. 3 3
      src/seamless_communication/models/wav2vec2_chunk/__init__.py
  76. 2 2
      src/seamless_communication/models/wav2vec2_chunk/builder.py
  77. 2 3
      src/seamless_communication/models/wav2vec2_chunk/chunk_attention_mask.py
  78. 4 11
      src/seamless_communication/models/wav2vec2_chunk/encoder.py
  79. 0 0
      src/seamless_communication/py.typed
  80. 1 2
      tests/common.py
  81. 2 2
      tests/conftest.py
  82. 0 0
      tests/integration/inference/__init__.py
  83. 3 2
      tests/integration/inference/test_translator.py
  84. 6 6
      tests/integration/models/test_unit_extractor.py

+ 3 - 3
README.md

@@ -45,7 +45,7 @@ T2TT task:
 m4t_predict <input_text> t2tt <tgt_lang> --src_lang <src_lang>
 m4t_predict <input_text> t2tt <tgt_lang> --src_lang <src_lang>
 ```
 ```
 
 
-Please refer to the [inference README](scripts/m4t/predict) for detailed instruction on how to run inference and the list of supported languages on the source, target sides for speech, text modalities.
+Please refer to the [inference README](src/seamless_communication/cli/m4t/predict) for detailed instruction on how to run inference and the list of supported languages on the source, target sides for speech, text modalities.
 
 
 ## Running [Gradio](https://github.com/gradio-app/gradio) demo locally
 ## Running [Gradio](https://github.com/gradio-app/gradio) demo locally
 
 
@@ -86,10 +86,10 @@ We provide the extensive evaluation results of seamlessM4T-Large and SeamlessM4T
 To reproduce our results, or to evaluate using the same metrics over your own test sets, please check out the [README here](docs/m4t/eval_README.md).
 To reproduce our results, or to evaluate using the same metrics over your own test sets, please check out the [README here](docs/m4t/eval_README.md).
 
 
 ## Finetuning SeamlessM4T models
 ## Finetuning SeamlessM4T models
-Please check out the [README here](scripts/m4t/finetune/README.md).
+Please check out the [README here](src/seamless_communication/cli/m4t/finetune/README.md).
 
 
 ## Converting raw audio to units
 ## Converting raw audio to units
-Please check out the [README here](scripts/m4t/audio_to_units/README.md).
+Please check out the [README here](src/seamless_communication/cli/m4t/audio_to_units/README.md).
 
 
 ## On-device models
 ## On-device models
 Apart from Seamless-M4T large (2.3B) and medium (1.2B) models, we are also releasing a small model (281M) targeted for on-device inference. To learn more about the usage and model details check out the [README here](docs/m4t/on_device_README.md).
 Apart from Seamless-M4T large (2.3B) and medium (1.2B) models, we are also releasing a small model (281M) targeted for on-device inference. To learn more about the usage and model details check out the [README here](docs/m4t/on_device_README.md).

+ 1 - 1
demo/app.py

@@ -11,8 +11,8 @@ import numpy as np
 import torch
 import torch
 import torchaudio
 import torchaudio
 from huggingface_hub import hf_hub_download
 from huggingface_hub import hf_hub_download
-from seamless_communication.models.inference.translator import Translator
 
 
+from seamless_communication.inference import Translator
 
 
 DESCRIPTION = """# SeamlessM4T
 DESCRIPTION = """# SeamlessM4T
 
 

+ 2 - 1
dev_requirements.txt

@@ -1,5 +1,6 @@
-pytest
 black
 black
 flake8
 flake8
 isort
 isort
 mypy
 mypy
+pre-commit
+pytest

+ 32 - 0
pyproject.toml

@@ -0,0 +1,32 @@
+[build-system]
+requires = ["packaging~=23.1", "setuptools~=67.8", "wheel~=0.40"]
+build-backend = "setuptools.build_meta"
+
+[tool.flake8]
+extend_ignore = ["E", "Y"]  # Black
+per-file-ignores = [
+    "__init__.py:F401",
+]
+
+[tool.isort]
+profile = "black"
+
+[tool.mypy]
+disable_error_code = "type-abstract"
+disallow_untyped_calls = false
+disallow_untyped_decorators = false
+ignore_missing_imports = true
+python_version = 3.8
+show_error_codes = true
+show_error_context = true
+strict = true
+warn_unused_configs = false
+warn_unused_ignores = false
+
+[tool.pytest.ini_options]
+minversion = "7.1"
+testpaths = ["tests"]
+filterwarnings = [
+    "ignore:torch.nn.utils.weight_norm is deprecated in favor of",
+    "ignore:TypedStorage is deprecated",
+]

+ 0 - 7
requirements.txt

@@ -1,7 +0,0 @@
-pre-commit
-datasets
-torchaudio
-tqdm
-soundfile
-librosa
-fairseq2==0.2.*

+ 0 - 0
scripts/m4t/train/install_devfair.sh → scripts/install_devfair.sh


+ 0 - 0
scripts/m4t/train/install_fairaws.sh → scripts/install_fairaws.sh


+ 16 - 50
setup.py

@@ -4,53 +4,14 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
-from pathlib import Path
-import os
-from typing import Iterable
-
-import pkg_resources
 from setuptools import find_packages, setup
 from setuptools import find_packages, setup
-from setuptools.command.develop import develop
-
-
-def _load_requirements(fname: str) -> Iterable[str]:
-    with open(Path(__file__).parent / fname) as fp_in:
-        for req in pkg_resources.parse_requirements(fp_in):
-            yield str(req)
-
-
-def _add_symlinks():
-    root = Path(__file__).parent
-    sc_root = root / "src/seamless_communication"
-    sc_link = root / "seamless_communication"
-    m4t_scripts_root = root / "scripts/m4t"
-    m4t_scripts_link = root / "m4t_scripts"
-    if not sc_link.exists():
-        os.symlink(sc_root, sc_link, target_is_directory=True)
-    if not m4t_scripts_link.exists():
-        os.symlink(m4t_scripts_root, m4t_scripts_link, target_is_directory=True)
-
-
-class cmd_for_editable_mode(develop):
-    def run(self):
-        # add symlinks for modules if install in editable mode
-        _add_symlinks()
-        super().run()
-
-
-default_requirements = list(_load_requirements("requirements.txt"))
-dev_requirements = list(_load_requirements("dev_requirements.txt"))
 
 
 setup(
 setup(
     name="seamless_communication",
     name="seamless_communication",
     version="1.0.0",
     version="1.0.0",
-    packages=find_packages(where="src")
-    + ["m4t_scripts.finetune", "m4t_scripts.predict"],
-    package_dir={
-        "m4t_scripts": "scripts/m4t",
-        "seamless_communication": "src/seamless_communication",
-    },
-    package_data={"": ["assets/cards/*.yaml"]},
+    packages=find_packages(where="src"),
+    package_dir={"": "src"},
+    package_data={"": ["py.typed", "cards/*.yaml"]},
     description="SeamlessM4T -- Massively Multilingual & Multimodal Machine Translation Model",
     description="SeamlessM4T -- Massively Multilingual & Multimodal Machine Translation Model",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
     long_description_content_type="text/markdown",
@@ -59,17 +20,22 @@ setup(
     author="Fundamental AI Research (FAIR) at Meta",
     author="Fundamental AI Research (FAIR) at Meta",
     url="https://github.com/facebookresearch/seamless_communication",
     url="https://github.com/facebookresearch/seamless_communication",
     license="Creative Commons",
     license="Creative Commons",
-    install_requires=default_requirements,
-    extras_require={"dev": default_requirements + dev_requirements},
+    install_requires=[
+        "datasets",
+        "fairseq2==0.2.*",
+        "librosa",
+        "soundfile",
+        "torchaudio",
+        "tqdm",
+    ],
     entry_points={
     entry_points={
         "console_scripts": [
         "console_scripts": [
-            "m4t_evaluate=m4t_scripts.evaluate.evaluate:main",
-            "m4t_predict=m4t_scripts.predict.predict:main",
-            "m4t_finetune=m4t_scripts.finetune.finetune:main",
-            "m4t_prepare_dataset=m4t_scripts.finetune.dataset:main",
-            "m4t_audio_to_units=m4t_scripts.audio_to_units.audio_to_units:main",
+            "m4t_evaluate=seamless_communication.cli.m4t.evaluate.evaluate:main",
+            "m4t_predict=seamless_communication.cli.m4t.predict.predict:main",
+            "m4t_finetune=seamless_communication.cli.m4t.finetune.finetune:main",
+            "m4t_prepare_dataset=seamless_communication.cli.m4t.finetune.dataset:main",
+            "m4t_audio_to_units=seamless_communication.cli.m4t.audio_to_units.audio_to_units:main",
         ],
         ],
     },
     },
-    cmdclass={"develop": cmd_for_editable_mode},
     include_package_data=True,
     include_package_data=True,
 )
 )

+ 15 - 0
src/seamless_communication/__init__.py

@@ -4,4 +4,19 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
+from pathlib import Path
+
+from fairseq2.assets import LocalAssetCardStorage, asset_store
+
 __version__ = "0.1.0"
 __version__ = "0.1.0"
+
+
+def _update_asset_store() -> None:
+    pathname = Path(__file__).parent.joinpath("cards")
+
+    card_storage = LocalAssetCardStorage(pathname)
+
+    asset_store.add_storage(card_storage)
+
+
+_update_asset_store()

+ 0 - 9
src/seamless_communication/assets/__init__.py

@@ -1,9 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-from seamless_communication.assets.download_manager import (
-    download_manager as download_manager,
-)
-from seamless_communication.assets.store import asset_store as asset_store

+ 0 - 27
src/seamless_communication/assets/download_manager.py

@@ -1,27 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from pathlib import Path
-
-import torch
-from fairseq2.assets import DefaultAssetDownloadManager
-
-
-class SCAssetDownloadManager(DefaultAssetDownloadManager):
-    @classmethod
-    def _get_pathname(cls, uri: str, sub_dir: str) -> Path:
-        hub_dir = Path(torch.hub.get_dir()).expanduser()
-
-        hsh = cls._get_uri_hash(uri)
-
-        filename = cls._get_filename(uri)
-
-        return hub_dir.joinpath(
-            "seamless_communication", "assets", sub_dir, hsh, filename
-        )
-
-
-download_manager = SCAssetDownloadManager()

+ 0 - 22
src/seamless_communication/assets/store.py

@@ -1,22 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from pathlib import Path
-
-from fairseq2.assets import AssetStore
-from fairseq2.assets.card_storage import LocalAssetCardStorage
-from fairseq2.assets.store import DefaultAssetStore
-
-
-def create_default_asset_store() -> AssetStore:
-    pathname = Path(__file__).parent.joinpath("cards")
-
-    card_storage = LocalAssetCardStorage(pathname)
-
-    return DefaultAssetStore(card_storage)
-
-
-asset_store = create_default_asset_store()

+ 0 - 0
src/seamless_communication/assets/cards/seamlessM4T_large.yaml → src/seamless_communication/cards/seamlessM4T_large.yaml


+ 0 - 0
src/seamless_communication/assets/cards/seamlessM4T_medium.yaml → src/seamless_communication/cards/seamlessM4T_medium.yaml


+ 0 - 0
src/seamless_communication/assets/cards/seamlessM4T_v2_large.yaml → src/seamless_communication/cards/seamlessM4T_v2_large.yaml


+ 0 - 0
src/seamless_communication/assets/cards/unity_nllb-100.yaml → src/seamless_communication/cards/unity_nllb-100.yaml


+ 0 - 0
src/seamless_communication/assets/cards/unity_nllb-200.yaml → src/seamless_communication/cards/unity_nllb-200.yaml


+ 0 - 0
src/seamless_communication/assets/cards/vocoder_36langs.yaml → src/seamless_communication/cards/vocoder_36langs.yaml


+ 0 - 0
src/seamless_communication/assets/cards/vocoder_v2.yaml → src/seamless_communication/cards/vocoder_v2.yaml


+ 0 - 0
src/seamless_communication/assets/cards/xlsr2_1b_v2.yaml → src/seamless_communication/cards/xlsr2_1b_v2.yaml


+ 0 - 0
scripts/m4t/train/__init__.py → src/seamless_communication/cli/__init__.py


+ 0 - 0
src/seamless_communication/cli/eval_utils/__init__.py


+ 6 - 5
scripts/eval_utils/compute_metrics.py → src/seamless_communication/cli/eval_utils/compute_metrics.py

@@ -4,17 +4,18 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
-from pathlib import Path
 import logging
 import logging
+from pathlib import Path
+from typing import Optional
+
 import pandas as pd
 import pandas as pd
 import sacrebleu
 import sacrebleu
 import whisper
 import whisper
-from jiwer import wer, cer
+from jiwer import cer, wer
 from tqdm import tqdm
 from tqdm import tqdm
-from typing import Optional
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 
 
-from scripts.eval_utils.lang_mapping import LANG3_LANG2
+from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 
 
 logging.basicConfig(
 logging.basicConfig(
     level=logging.INFO,
     level=logging.INFO,
@@ -316,7 +317,7 @@ def compute_quality_metrics(
             whisper_normalize_text=True,
             whisper_normalize_text=True,
         )
         )
         transcripts_df.to_csv(
         transcripts_df.to_csv(
-            (Path(output_dir) / f"whisper_audio_transcriptions.tsv"),
+            (Path(output_dir) / "whisper_audio_transcriptions.tsv"),
             sep="\t",
             sep="\t",
             index=False,
             index=False,
             encoding="utf-8",
             encoding="utf-8",

+ 0 - 1
scripts/eval_utils/lang_mapping.py → src/seamless_communication/cli/eval_utils/lang_mapping.py

@@ -174,4 +174,3 @@ LANG2_LANG3 = {
     "tk": "tuk",
     "tk": "tuk",
 }
 }
 LANG3_LANG2 = {v: k for k, v in LANG2_LANG3.items()}
 LANG3_LANG2 = {v: k for k, v in LANG2_LANG3.items()}
-

+ 0 - 0
src/seamless_communication/cli/m4t/__init__.py


+ 0 - 0
scripts/m4t/audio_to_units/README.md → src/seamless_communication/cli/m4t/audio_to_units/README.md


+ 0 - 0
scripts/m4t/audio_to_units/__init__.py → src/seamless_communication/cli/m4t/audio_to_units/__init__.py


+ 2 - 1
scripts/m4t/audio_to_units/audio_to_units.py → src/seamless_communication/cli/m4t/audio_to_units/audio_to_units.py

@@ -5,9 +5,10 @@
 
 
 import argparse
 import argparse
 import logging
 import logging
+
 import torch
 import torch
-from seamless_communication.models.unit_extraction import UnitExtractor
 
 
+from seamless_communication.models.unit_extractor import UnitExtractor
 
 
 logging.basicConfig(level=logging.INFO)
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)

+ 0 - 0
scripts/m4t/evaluate/README.md → src/seamless_communication/cli/m4t/evaluate/README.md


+ 0 - 0
scripts/m4t/evaluate/__init__.py → src/seamless_communication/cli/m4t/evaluate/__init__.py


+ 15 - 18
scripts/m4t/evaluate/evaluate.py → src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -9,33 +9,31 @@ import contextlib
 import itertools
 import itertools
 import logging
 import logging
 import subprocess
 import subprocess
-import torch
-import torchaudio
-
 from argparse import Namespace
 from argparse import Namespace
 from dataclasses import dataclass
 from dataclasses import dataclass
 from pathlib import Path
 from pathlib import Path
-from torch import Tensor
-from tqdm import tqdm
-from typing import List, Optional, Tuple, Dict
+from typing import Dict, List, Optional, Tuple
 
 
+import torch
+import torchaudio
 from fairseq2.data import Collater, DataPipeline, FileMapper
 from fairseq2.data import Collater, DataPipeline, FileMapper
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
 from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
 from fairseq2.data.typing import StringLike
 from fairseq2.data.typing import StringLike
 from fairseq2.generation import SequenceGeneratorOptions
 from fairseq2.generation import SequenceGeneratorOptions
-from fairseq2.typing import Device, DataType
+from fairseq2.typing import DataType, Device
+from torch import Tensor
+from tqdm import tqdm
 
 
-from m4t_scripts.predict import add_inference_arguments, set_generation_opts
-from seamless_communication.models.inference import (
-    BatchedSpeechOutput,
-    Modality,
-    Translator,
-)
-from seamless_communication.models.unity import load_unity_text_tokenizer
-from scripts.eval_utils.compute_metrics import (
+from seamless_communication.cli.eval_utils.compute_metrics import (
     compute_quality_metrics,
     compute_quality_metrics,
 )
 )
+from seamless_communication.cli.m4t.predict import (
+    add_inference_arguments,
+    set_generation_opts,
+)
+from seamless_communication.inference import BatchedSpeechOutput, Modality, Translator
+from seamless_communication.models.unity import load_unity_text_tokenizer
 
 
 logging.basicConfig(
 logging.basicConfig(
     level=logging.INFO,
     level=logging.INFO,
@@ -247,9 +245,9 @@ def run_eval(
     ) as unit_file:
     ) as unit_file:
         sample_id = 0
         sample_id = 0
         if ctx.output_modality == Modality.SPEECH:
         if ctx.output_modality == Modality.SPEECH:
-            hyp_file.write(f"ref_tgt_text\tpred_tgt_text\tpred_tgt_audio\n")
+            hyp_file.write("ref_tgt_text\tpred_tgt_text\tpred_tgt_audio\n")
         else:
         else:
-            hyp_file.write(f"ref_tgt_text\tpred_tgt_text\n")
+            hyp_file.write("ref_tgt_text\tpred_tgt_text\n")
         for example in pipeline:
         for example in pipeline:
             valid_sequences: Optional[Tensor] = None
             valid_sequences: Optional[Tensor] = None
             if ctx.input_modality == Modality.SPEECH:
             if ctx.input_modality == Modality.SPEECH:
@@ -302,7 +300,6 @@ def run_eval(
             refs = [str(s) for s in example[ctx.ref_field]]
             refs = [str(s) for s in example[ctx.ref_field]]
 
 
             for i in range(len(text_output)):
             for i in range(len(text_output)):
-                t = text_output[i]
                 if ctx.output_modality == Modality.SPEECH:
                 if ctx.output_modality == Modality.SPEECH:
                     assert speech_output is not None
                     assert speech_output is not None
                     u = speech_output.units[i]
                     u = speech_output.units[i]

+ 0 - 0
scripts/m4t/finetune/README.md → src/seamless_communication/cli/m4t/finetune/README.md


+ 0 - 0
scripts/m4t/finetune/__init__.py → src/seamless_communication/cli/m4t/finetune/__init__.py


+ 0 - 0
scripts/m4t/finetune/dataloader.py → src/seamless_communication/cli/m4t/finetune/dataloader.py


+ 1 - 1
scripts/m4t/finetune/dataset.py → src/seamless_communication/cli/m4t/finetune/dataset.py

@@ -18,7 +18,7 @@ from seamless_communication.datasets.huggingface import (
     Speech2SpeechFleursDatasetBuilder,
     Speech2SpeechFleursDatasetBuilder,
     SpeechTokenizer,
     SpeechTokenizer,
 )
 )
-from seamless_communication.models.unit_extraction import UnitExtractor
+from seamless_communication.models.unit_extractor import UnitExtractor
 
 
 logging.basicConfig(
 logging.basicConfig(
     level=logging.INFO,
     level=logging.INFO,

+ 0 - 0
scripts/m4t/finetune/dist_utils.py → src/seamless_communication/cli/m4t/finetune/dist_utils.py


+ 1 - 1
scripts/m4t/finetune/finetune.py → src/seamless_communication/cli/m4t/finetune/finetune.py

@@ -11,8 +11,8 @@ from pathlib import Path
 
 
 import torch
 import torch
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
-from m4t_scripts.finetune import dataloader, dist_utils, trainer
 
 
+from seamless_communication.cli.m4t.finetune import dataloader, dist_utils, trainer
 from seamless_communication.models.unity import (
 from seamless_communication.models.unity import (
     UnitTokenizer,
     UnitTokenizer,
     UnitYModel,
     UnitYModel,

+ 11 - 9
scripts/m4t/finetune/trainer.py → src/seamless_communication/cli/m4t/finetune/trainer.py

@@ -15,12 +15,14 @@ from typing import Optional, Tuple
 import torch
 import torch
 import torch.distributed as dist
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn as nn
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.models.sequence import SequenceModelOutput
+from fairseq2.nn.padding import PaddingMask
 from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.typing import Device
 from fairseq2.typing import Device
-from m4t_scripts.finetune import dataloader, dist_utils
 from torch.optim import Adam
 from torch.optim import Adam
 
 
+from seamless_communication.cli.m4t.finetune import dataloader, dist_utils
 from seamless_communication.models.unity import UnitYModel
 from seamless_communication.models.unity import UnitYModel
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
@@ -136,12 +138,12 @@ class CalcLoss:
     def __init__(
     def __init__(
         self,
         self,
         label_smoothing: float,
         label_smoothing: float,
-        s2t_pad_idx: Optional[int],
-        t2u_pad_idx: Optional[int],
+        s2t_vocab_info: VocabularyInfo,
+        t2u_vocab_info: VocabularyInfo,
     ):
     ):
         self.label_smoothing = label_smoothing
         self.label_smoothing = label_smoothing
-        self.s2t_pad_idx = s2t_pad_idx
-        self.t2u_pad_idx = t2u_pad_idx
+        self.s2t_vocab_info = s2t_vocab_info
+        self.t2u_vocab_info = t2u_vocab_info
 
 
     def __call__(
     def __call__(
         self,
         self,
@@ -154,7 +156,7 @@ class CalcLoss:
             text_logits.device
             text_logits.device
         )
         )
         s2t_loss = SequenceModelOutput(
         s2t_loss = SequenceModelOutput(
-            logits=text_logits, pad_idx=self.s2t_pad_idx
+            logits=text_logits, vocab_info=self.s2t_vocab_info
         ).compute_loss(
         ).compute_loss(
             targets=batch.speech_to_text.target_tokens.to(text_logits.device),
             targets=batch.speech_to_text.target_tokens.to(text_logits.device),
             ignore_prefix_size=1,
             ignore_prefix_size=1,
@@ -165,7 +167,7 @@ class CalcLoss:
         assert batch.text_to_units.target_lengths is not None
         assert batch.text_to_units.target_lengths is not None
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
         s2u_loss = SequenceModelOutput(
         s2u_loss = SequenceModelOutput(
-            logits=unit_logits, pad_idx=self.t2u_pad_idx
+            logits=unit_logits, vocab_info=self.t2u_vocab_info
         ).compute_loss(
         ).compute_loss(
             targets=batch.text_to_units.target_tokens.to(unit_logits.device),
             targets=batch.text_to_units.target_tokens.to(unit_logits.device),
             ignore_prefix_size=1,
             ignore_prefix_size=1,
@@ -227,8 +229,8 @@ class UnitYFinetune:
         assert model.t2u_model is not None
         assert model.t2u_model is not None
         self.calc_loss = CalcLoss(
         self.calc_loss = CalcLoss(
             label_smoothing=self.params.label_smoothing,
             label_smoothing=self.params.label_smoothing,
-            s2t_pad_idx=model.pad_idx,
-            t2u_pad_idx=model.t2u_model.pad_idx,
+            s2t_vocab_info=model.target_vocab_info,
+            t2u_vocab_info=model.t2u_model.target_vocab_info,
         )
         )
         self.model = self._wrap_model_for_trainining(model=model)
         self.model = self._wrap_model_for_trainining(model=model)
         self.train_data_loader = train_data_loader
         self.train_data_loader = train_data_loader

+ 0 - 0
scripts/m4t/predict/README.md → src/seamless_communication/cli/m4t/predict/README.md


+ 4 - 2
scripts/m4t/predict/__init__.py → src/seamless_communication/cli/m4t/predict/__init__.py

@@ -4,7 +4,9 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
-from m4t_scripts.predict.predict import (
+from seamless_communication.cli.m4t.predict.predict import (
     add_inference_arguments as add_inference_arguments,
     add_inference_arguments as add_inference_arguments,
 )
 )
-from m4t_scripts.predict.predict import set_generation_opts as set_generation_opts
+from seamless_communication.cli.m4t.predict.predict import (
+    set_generation_opts as set_generation_opts,
+)

+ 6 - 9
scripts/m4t/predict/predict.py → src/seamless_communication/cli/m4t/predict/predict.py

@@ -5,17 +5,14 @@
 
 
 import argparse
 import argparse
 import logging
 import logging
+from argparse import Namespace
+from typing import Tuple
+
 import torch
 import torch
 import torchaudio
 import torchaudio
-
-from argparse import Namespace
 from fairseq2.generation import SequenceGeneratorOptions
 from fairseq2.generation import SequenceGeneratorOptions
-from seamless_communication.models.inference import (
-    NGramRepeatBlockProcessor,
-    Translator,
-)
-from typing import Tuple
 
 
+from seamless_communication.inference import NGramRepeatBlockProcessor, Translator
 
 
 logging.basicConfig(
 logging.basicConfig(
     level=logging.INFO,
     level=logging.INFO,
@@ -152,7 +149,7 @@ def set_generation_opts(
         ),
         ),
     )
     )
     if args.text_generation_ngram_blocking:
     if args.text_generation_ngram_blocking:
-        text_generation_opts.logits_processor = NGramRepeatBlockProcessor(
+        text_generation_opts.step_processor = NGramRepeatBlockProcessor(
             no_repeat_ngram_size=args.no_repeat_ngram_size
             no_repeat_ngram_size=args.no_repeat_ngram_size
         )
         )
 
 
@@ -164,7 +161,7 @@ def set_generation_opts(
         ),
         ),
     )
     )
     if args.unit_generation_ngram_blocking:
     if args.unit_generation_ngram_blocking:
-        unit_generation_opts.logits_processor = NGramRepeatBlockProcessor(
+        unit_generation_opts.step_processor = NGramRepeatBlockProcessor(
             no_repeat_ngram_size=args.no_repeat_ngram_size
             no_repeat_ngram_size=args.no_repeat_ngram_size
         )
         )
     return text_generation_opts, unit_generation_opts
     return text_generation_opts, unit_generation_opts

+ 0 - 0
src/seamless_communication/cli/m4t/train/__init__.py


+ 3 - 3
scripts/m4t/train/configs.py → src/seamless_communication/cli/m4t/train/configs.py

@@ -4,10 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
-import yaml
-
 from dataclasses import dataclass
 from dataclasses import dataclass
-from typing import Dict, Any, Union, get_origin, get_args, List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Union, get_args, get_origin
+
+import yaml
 
 
 
 
 @dataclass
 @dataclass

+ 10 - 7
scripts/m4t/train/dataloader.py → src/seamless_communication/cli/m4t/train/dataloader.py

@@ -5,15 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
 
 
+import ctypes
 import logging
 import logging
 import os
 import os
 from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Union
 from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Union
-import ctypes
 
 
 import torch
 import torch
-from m4t_scripts.train.configs import AudioProcessingConfig, DataLoadingConfig
-from torch import Tensor
-
 from fairseq2.data import (
 from fairseq2.data import (
     CollateOptionsOverride,
     CollateOptionsOverride,
     Collater,
     Collater,
@@ -24,6 +21,12 @@ from fairseq2.data import (
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.text import SentencePieceEncoder, StrSplitter, read_text
 from fairseq2.data.text import SentencePieceEncoder, StrSplitter, read_text
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
+from torch import Tensor
+
+from seamless_communication.cli.m4t.train.configs import (
+    AudioProcessingConfig,
+    DataLoadingConfig,
+)
 from seamless_communication.models.tokenizer import SPMTokenizer
 from seamless_communication.models.tokenizer import SPMTokenizer
 from seamless_communication.models.unity import (
 from seamless_communication.models.unity import (
     UnitTokenizer,
     UnitTokenizer,
@@ -419,15 +422,15 @@ class UnityDataLoader:
             overrides=[
             overrides=[
                 CollateOptionsOverride(
                 CollateOptionsOverride(
                     selector=f"{self.ROOT_COLUMN}.{self.AUDIO_COLUMN_NAME}.data.fbank",
                     selector=f"{self.ROOT_COLUMN}.{self.AUDIO_COLUMN_NAME}.data.fbank",
-                    pad_idx=self.config.fbank_feats_pad_idx,
+                    pad_value=self.config.fbank_feats_pad_idx,
                 ),
                 ),
                 CollateOptionsOverride(
                 CollateOptionsOverride(
                     selector=f"{self.ROOT_COLUMN}.{self.TARGET_TEXT_COLUMN}",
                     selector=f"{self.ROOT_COLUMN}.{self.TARGET_TEXT_COLUMN}",
-                    pad_idx=self.text_tokenizer.vocab_info.pad_idx,
+                    pad_value=self.text_tokenizer.vocab_info.pad_idx,
                 ),
                 ),
                 CollateOptionsOverride(
                 CollateOptionsOverride(
                     selector=f"{self.ROOT_COLUMN}.{self.TARGET_UNITS_COLUMN}",
                     selector=f"{self.ROOT_COLUMN}.{self.TARGET_UNITS_COLUMN}",
-                    pad_idx=self.unit_tokenizer.vocab_info.pad_idx,
+                    pad_value=self.unit_tokenizer.vocab_info.pad_idx,
                 ),
                 ),
             ],
             ],
         )
         )

+ 0 - 0
scripts/m4t/train/dist_utils.py → src/seamless_communication/cli/m4t/train/dist_utils.py


+ 26 - 17
scripts/m4t/train/model.py → src/seamless_communication/cli/m4t/train/model.py

@@ -7,27 +7,26 @@
 
 
 import logging
 import logging
 import os
 import os
-from typing import Dict, Any
+from typing import Any, Dict
 
 
 import torch
 import torch
-from m4t_scripts.train.configs import CustomModelParams, ModelConfig
+from fairseq2.data import VocabularyInfo
+from fairseq2.models.nllb.builder import NllbConfig
+from fairseq2.models.nllb.loader import NllbLoader
+from fairseq2.models.utils.checkpoint_loader import convert_model_state_dict
+from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig
+from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
+from fairseq2.nn.transformer import TransformerNormOrder
 
 
+from seamless_communication.cli.m4t.train.configs import CustomModelParams, ModelConfig
 from seamless_communication.models.unity import (
 from seamless_communication.models.unity import (
     UnitYConfig,
     UnitYConfig,
     UnitYModel,
     UnitYModel,
-    load_unity_model,
+    UnitYT2UConfig,
     create_unity_model,
     create_unity_model,
+    load_unity_model,
 )
 )
-from seamless_communication.models.unity.loader import load_unity_config
-from seamless_communication.models.unity import UnitYT2UConfig
-from fairseq2.nn.transformer import TransformerNormOrder
-from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig
-from fairseq2.models.nllb.builder import NllbConfig
-from fairseq2.models.utils.checkpoint_loader import convert_model_state_dict
-from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
-from seamless_communication.models.unity.loader import UnitYLoader
-
-from fairseq2.models.nllb.loader import NllbLoader
+from seamless_communication.models.unity.loader import UnitYLoader, load_unity_config
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -257,8 +256,13 @@ class ModelBuilder:
             mt_model_config=NllbConfig(
             mt_model_config=NllbConfig(
                 model_dim=config.model_embed_dim,
                 model_dim=config.model_embed_dim,
                 max_seq_len=1024,
                 max_seq_len=1024,
-                vocabulary_size=config.nllb_vocabulary_size,  # num_tokens + langs + spec symbols
-                pad_idx=0,
+                vocab_info=VocabularyInfo(
+                    size=config.nllb_vocabulary_size,
+                    unk_idx=1,
+                    bos_idx=2,
+                    eos_idx=3,
+                    pad_idx=0,
+                ),
                 num_encoder_layers=config.nllb_encoder_layers,
                 num_encoder_layers=config.nllb_encoder_layers,
                 num_decoder_layers=config.nllb_decoder_layers,
                 num_decoder_layers=config.nllb_decoder_layers,
                 num_encoder_attn_heads=16,
                 num_encoder_attn_heads=16,
@@ -269,8 +273,13 @@ class ModelBuilder:
             t2u_config=UnitYT2UConfig(
             t2u_config=UnitYT2UConfig(
                 model_dim=config.model_embed_dim,
                 model_dim=config.model_embed_dim,
                 unit_max_seq_len=2048,
                 unit_max_seq_len=2048,
-                unit_vocabulary_size=config.unit_vocabulary_size,
-                unit_pad_idx=1,
+                target_vocab_info=VocabularyInfo(
+                    size=config.unit_vocabulary_size,
+                    unk_idx=3,
+                    bos_idx=0,
+                    eos_idx=2,
+                    pad_idx=1,
+                ),
                 num_encoder_layers=config.t2u_encoder_layers,
                 num_encoder_layers=config.t2u_encoder_layers,
                 num_decoder_layers=config.t2u_decoder_layers,
                 num_decoder_layers=config.t2u_decoder_layers,
                 nar_decoder_frontend_config=None,
                 nar_decoder_frontend_config=None,

+ 0 - 0
scripts/m4t/train/recipes/asr_small.yaml → src/seamless_communication/cli/m4t/train/recipes/asr_small.yaml


+ 0 - 0
scripts/m4t/train/recipes/asr_small_wh_transc.yaml → src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_transc.yaml


+ 0 - 0
scripts/m4t/train/recipes/large_M4T_v1.yaml → src/seamless_communication/cli/m4t/train/recipes/large_M4T_v1.yaml


+ 0 - 0
scripts/m4t/train/recipes/m4t_v1_train_manifests.txt → src/seamless_communication/cli/m4t/train/recipes/m4t_v1_train_manifests.txt


+ 6 - 5
scripts/m4t/train/run_training.py → src/seamless_communication/cli/m4t/train/run_training.py

@@ -15,11 +15,12 @@ from typing import List
 
 
 import torch
 import torch
 import yaml
 import yaml
-from m4t_scripts.train import dataloader as _dataloader
-from m4t_scripts.train import dist_utils
-from m4t_scripts.train import model as _model
-from m4t_scripts.train import trainer as _trainer
-from m4t_scripts.train.configs import WorkflowParams
+
+from seamless_communication.cli.m4t.train import dataloader as _dataloader
+from seamless_communication.cli.m4t.train import dist_utils
+from seamless_communication.cli.m4t.train import model as _model
+from seamless_communication.cli.m4t.train import trainer as _trainer
+from seamless_communication.cli.m4t.train.configs import WorkflowParams
 
 
 logging_format = f"%(asctime)s - {platform.node()} - %(process)s - %(levelname)s - %(name)s: %(message)s"
 logging_format = f"%(asctime)s - {platform.node()} - %(process)s - %(levelname)s - %(name)s: %(message)s"
 logging.basicConfig(
 logging.basicConfig(

+ 0 - 1
scripts/m4t/train/run_with_slurm.py → src/seamless_communication/cli/m4t/train/run_with_slurm.py

@@ -7,7 +7,6 @@ import subprocess
 import time
 import time
 from pathlib import Path
 from pathlib import Path
 
 
-
 logging_format = f"%(asctime)s - {platform.node()} - %(process)s - %(levelname)s - %(name)s: %(message)s"
 logging_format = f"%(asctime)s - {platform.node()} - %(process)s - %(levelname)s - %(name)s: %(message)s"
 logging.basicConfig(
 logging.basicConfig(
     level=logging.INFO,
     level=logging.INFO,

+ 17 - 13
scripts/m4t/train/trainer.py → src/seamless_communication/cli/m4t/train/trainer.py

@@ -6,21 +6,22 @@
 
 
 
 
 import logging
 import logging
-from typing import Any, Optional, Tuple, Dict, List
-
 import os
 import os
 import time
 import time
+from typing import Any, Dict, List, Optional, Tuple
+
 import torch
 import torch
 import torch.distributed as dist
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn as nn
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.optim.lr_scheduler import MyleLR
-from m4t_scripts.train import dataloader, dist_utils
 from torch.optim import Adam
 from torch.optim import Adam
 
 
+from seamless_communication.cli.m4t.train import dataloader, dist_utils
+from seamless_communication.cli.m4t.train.configs import TrainingParams
 from seamless_communication.models.unity import UnitYModel, UnitYT2UModel
 from seamless_communication.models.unity import UnitYModel, UnitYT2UModel
-from m4t_scripts.train.configs import TrainingParams
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -67,7 +68,10 @@ class UnitYTrainWrapper(nn.Module):
         )
         )
         text_logits = self.model.final_proj(text_decoder_out)
         text_logits = self.model.final_proj(text_decoder_out)
         # t2u
         # t2u
-        (unit_encoder_out, unit_encoder_padding_mask,) = self.t2u.encode(
+        (
+            unit_encoder_out,
+            unit_encoder_padding_mask,
+        ) = self.t2u.encode(
             text_decoder_output=text_decoder_out,
             text_decoder_output=text_decoder_out,
             text_decoder_padding_mask=text_decoder_padding_mask,
             text_decoder_padding_mask=text_decoder_padding_mask,
         )
         )
@@ -91,13 +95,13 @@ class CalcLoss:
     def __init__(
     def __init__(
         self,
         self,
         label_smoothing: float,
         label_smoothing: float,
-        s2t_pad_idx: Optional[int],
-        t2u_pad_idx: Optional[int],
+        s2t_vocab_info: VocabularyInfo,
+        t2u_vocab_info: VocabularyInfo,
         s2t_skip_langtok_loss: bool = False,
         s2t_skip_langtok_loss: bool = False,
     ):
     ):
         self.label_smoothing = label_smoothing
         self.label_smoothing = label_smoothing
-        self.s2t_pad_idx = s2t_pad_idx
-        self.t2u_pad_idx = t2u_pad_idx
+        self.s2t_vocab_info = s2t_vocab_info
+        self.t2u_vocab_info = t2u_vocab_info
         self.s2t_ignore_prefix_size = 1 if s2t_skip_langtok_loss else 0
         self.s2t_ignore_prefix_size = 1 if s2t_skip_langtok_loss else 0
         self.t2u_ignore_prefix_size = 1
         self.t2u_ignore_prefix_size = 1
 
 
@@ -112,7 +116,7 @@ class CalcLoss:
             text_logits.device
             text_logits.device
         )
         )
         s2t_loss = SequenceModelOutput(
         s2t_loss = SequenceModelOutput(
-            logits=text_logits, pad_idx=self.s2t_pad_idx
+            logits=text_logits, vocab_info=self.s2t_vocab_info
         ).compute_loss(
         ).compute_loss(
             targets=batch.speech_to_text.target_tokens.to(text_logits.device),
             targets=batch.speech_to_text.target_tokens.to(text_logits.device),
             ignore_prefix_size=self.s2t_ignore_prefix_size,
             ignore_prefix_size=self.s2t_ignore_prefix_size,
@@ -121,7 +125,7 @@ class CalcLoss:
         assert batch.text_to_units.target_lengths is not None
         assert batch.text_to_units.target_lengths is not None
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
         s2u_loss = SequenceModelOutput(
         s2u_loss = SequenceModelOutput(
-            logits=unit_logits, pad_idx=self.t2u_pad_idx
+            logits=unit_logits, vocab_info=self.t2u_vocab_info
         ).compute_loss(
         ).compute_loss(
             targets=batch.text_to_units.target_tokens.to(unit_logits.device),
             targets=batch.text_to_units.target_tokens.to(unit_logits.device),
             ignore_prefix_size=1,
             ignore_prefix_size=1,
@@ -192,8 +196,8 @@ class UnitYTrainer:
         assert model.t2u_model is not None
         assert model.t2u_model is not None
         self.calc_loss = CalcLoss(
         self.calc_loss = CalcLoss(
             label_smoothing=self.params.label_smoothing,
             label_smoothing=self.params.label_smoothing,
-            s2t_pad_idx=model.pad_idx,
-            t2u_pad_idx=model.t2u_model.pad_idx,
+            s2t_vocab_info=model.target_vocab_info,
+            t2u_vocab_info=model.t2u_model.target_vocab_info,
         )
         )
         self._try_load_checkpoint(model=model)
         self._try_load_checkpoint(model=model)
         self.model = self._wrap_model_for_trainining(model=model)
         self.model = self._wrap_model_for_trainining(model=model)

+ 16 - 0
src/seamless_communication/inference/__init__.py

@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from seamless_communication.inference.generator import UnitYGenerator as UnitYGenerator
+from seamless_communication.inference.ngram_repeat_block_processor import (
+    NGramRepeatBlockProcessor as NGramRepeatBlockProcessor,
+)
+from seamless_communication.inference.translator import (
+    BatchedSpeechOutput as BatchedSpeechOutput,
+)
+from seamless_communication.inference.translator import Modality as Modality
+from seamless_communication.inference.translator import Task as Task
+from seamless_communication.inference.translator import Translator as Translator

+ 5 - 7
src/seamless_communication/models/unity/generator.py → src/seamless_communication/inference/generator.py

@@ -5,12 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
 from dataclasses import dataclass
 from dataclasses import dataclass
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple
 
 
 import torch
 import torch
-
-from torch import Tensor
-from fairseq2.data import VocabularyInfo
 from fairseq2.data.text import TextTokenizer
 from fairseq2.data.text import TextTokenizer
 from fairseq2.generation import (
 from fairseq2.generation import (
     Seq2SeqGenerator,
     Seq2SeqGenerator,
@@ -21,11 +18,12 @@ from fairseq2.generation import (
 )
 )
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
 from fairseq2.nn.utils.module import infer_device
 from fairseq2.nn.utils.module import infer_device
+from torch import Tensor
 
 
 from seamless_communication.models.unity.model import (
 from seamless_communication.models.unity.model import (
     UnitYModel,
     UnitYModel,
-    UnitYX2TModel,
     UnitYT2UModel,
     UnitYT2UModel,
+    UnitYX2TModel,
 )
 )
 from seamless_communication.models.unity.unit_tokenizer import (
 from seamless_communication.models.unity.unit_tokenizer import (
     UnitTokenDecoder,
     UnitTokenDecoder,
@@ -35,7 +33,7 @@ from seamless_communication.models.unity.unit_tokenizer import (
 
 
 def remove_consecutive_repeated_ngrams(
 def remove_consecutive_repeated_ngrams(
     sequence: List[int], min_size: int = 1, max_size: int = 40
     sequence: List[int], min_size: int = 1, max_size: int = 40
-):
+) -> List[int]:
     assert 1 <= min_size <= max_size
     assert 1 <= min_size <= max_size
     drop_idx = set()  # indices that will be dropped from the sequence
     drop_idx = set()  # indices that will be dropped from the sequence
 
 
@@ -188,7 +186,7 @@ class UnitYGenerator:
             )
             )
         elif input_modality == "text" and self.t2t_generator is None:
         elif input_modality == "text" and self.t2t_generator is None:
             raise ValueError(
             raise ValueError(
-                f"Please set use_text_encoder to True in your model config to encode text."
+                "Please set use_text_encoder to True in your model config to encode text."
             )
             )
         else:
         else:
             raise ValueError(f"Unsupported input_modality: {input_modality}")
             raise ValueError(f"Unsupported input_modality: {input_modality}")

+ 3 - 2
src/seamless_communication/models/inference/ngram_repeat_block_processor.py → src/seamless_communication/inference/ngram_repeat_block_processor.py

@@ -4,10 +4,11 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
-from fairseq2.generation import StepProcessor
 from typing import List
 from typing import List
-from torch import Tensor
+
 import torch
 import torch
+from fairseq2.generation import StepProcessor
+from torch import Tensor
 
 
 
 
 class NGramRepeatBlockProcessor(StepProcessor):
 class NGramRepeatBlockProcessor(StepProcessor):

+ 9 - 9
src/seamless_communication/models/inference/translator.py → src/seamless_communication/inference/translator.py

@@ -3,29 +3,31 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
+import logging
 from dataclasses import dataclass
 from dataclasses import dataclass
 from enum import Enum, auto
 from enum import Enum, auto
 from pathlib import Path
 from pathlib import Path
-from torch import Tensor
 from typing import Callable, List, Optional, Tuple, Union, cast
 from typing import Callable, List, Optional, Tuple, Union, cast
 
 
-import logging
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
-
 from fairseq2.assets.card import AssetCard
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater, SequenceData
 from fairseq2.data import Collater, SequenceData
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.text import TextTokenizer
 from fairseq2.data.text import TextTokenizer
 from fairseq2.data.typing import StringLike
 from fairseq2.data.typing import StringLike
-from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
+from fairseq2.generation import SequenceGeneratorOptions, SequenceToTextOutput
 from fairseq2.memory import MemoryBlock
 from fairseq2.memory import MemoryBlock
 from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from fairseq2.typing import DataType, Device
+from torch import Tensor
 
 
+from seamless_communication.inference.generator import (
+    SequenceToUnitOutput,
+    UnitYGenerator,
+)
 from seamless_communication.models.unity import (
 from seamless_communication.models.unity import (
     UnitTokenizer,
     UnitTokenizer,
-    UnitYGenerator,
     UnitYModel,
     UnitYModel,
     UnitYNART2UModel,
     UnitYNART2UModel,
     UnitYT2UModel,
     UnitYT2UModel,
@@ -33,9 +35,7 @@ from seamless_communication.models.unity import (
     load_unity_text_tokenizer,
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
     load_unity_unit_tokenizer,
 )
 )
-from seamless_communication.models.unity.generator import SequenceToUnitOutput
-from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
-
+from seamless_communication.models.vocoder import Vocoder, load_vocoder_model
 
 
 logging.basicConfig(
 logging.basicConfig(
     level=logging.INFO,
     level=logging.INFO,
@@ -246,7 +246,7 @@ class Translator(nn.Module):
                     audio = audio.unsqueeze(1)
                     audio = audio.unsqueeze(1)
                 elif audio.dim() == 2 and audio.size(0) < audio.size(1):
                 elif audio.dim() == 2 and audio.size(0) < audio.size(1):
                     logger.warning(
                     logger.warning(
-                        f"Transposing audio tensor from (bsz, seq_len) -> (seq_len, bsz)."
+                        "Transposing audio tensor from (bsz, seq_len) -> (seq_len, bsz)."
                     )
                     )
                     audio = audio.transpose(0, 1)
                     audio = audio.transpose(0, 1)
 
 

+ 0 - 14
src/seamless_communication/models/inference/__init__.py

@@ -1,14 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-from seamless_communication.models.inference.ngram_repeat_block_processor import (
-    NGramRepeatBlockProcessor as NGramRepeatBlockProcessor,
-)
-from seamless_communication.models.inference.translator import (
-    BatchedSpeechOutput as BatchedSpeechOutput,
-)
-from seamless_communication.models.inference.translator import Modality as Modality
-from seamless_communication.models.inference.translator import Task as Task
-from seamless_communication.models.inference.translator import Translator as Translator

+ 5 - 5
src/seamless_communication/models/unit_extraction/__init__.py → src/seamless_communication/models/unit_extractor/__init__.py

@@ -4,12 +4,12 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
-from seamless_communication.models.unit_extraction.unit_extraction import (
-    UnitExtractor as UnitExtractor,
-)
-from seamless_communication.models.unit_extraction.kmeans import (
+from seamless_communication.models.unit_extractor.kmeans import (
     KmeansModel as KmeansModel,
     KmeansModel as KmeansModel,
 )
 )
-from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
+from seamless_communication.models.unit_extractor.unit_extractor import (
+    UnitExtractor as UnitExtractor,
+)
+from seamless_communication.models.unit_extractor.wav2vec2_layer_output import (
     Wav2Vec2LayerOutputModel as Wav2Vec2LayerOutputModel,
     Wav2Vec2LayerOutputModel as Wav2Vec2LayerOutputModel,
 )
 )

+ 3 - 3
src/seamless_communication/models/unit_extraction/kmeans.py → src/seamless_communication/models/unit_extractor/kmeans.py

@@ -4,11 +4,11 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
-import torch
-from torch import Tensor, nn
 import numpy as np
 import numpy as np
+import torch
+from fairseq2.assets import download_manager
 from fairseq2.typing import Device
 from fairseq2.typing import Device
-from seamless_communication.assets import download_manager
+from torch import Tensor, nn
 
 
 
 
 class KmeansModel(nn.Module):
 class KmeansModel(nn.Module):

+ 13 - 16
src/seamless_communication/models/unit_extraction/unit_extraction.py → src/seamless_communication/models/unit_extractor/unit_extractor.py

@@ -4,32 +4,29 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
+import logging
 from itertools import groupby
 from itertools import groupby
 from pathlib import Path
 from pathlib import Path
-from torch import Tensor, nn
-from typing import Union
+from typing import List, Union
 
 
-import logging
 import torch
 import torch
 import torch.nn.functional as F
 import torch.nn.functional as F
-
 from fairseq2.assets.card import AssetCard
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater
 from fairseq2.data import Collater
 from fairseq2.data.audio import AudioDecoder
 from fairseq2.data.audio import AudioDecoder
 from fairseq2.memory import MemoryBlock
 from fairseq2.memory import MemoryBlock
-from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.models.sequence import SequenceBatch
 from fairseq2.models.sequence import SequenceBatch
-from fairseq2.models.wav2vec2 import Wav2Vec2Model
+from fairseq2.models.wav2vec2 import Wav2Vec2Model, load_wav2vec2_model
+from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from fairseq2.typing import DataType, Device
+from torch import Tensor, nn
 
 
-from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
-    load_wav2vec2_model,
+from seamless_communication.inference import Translator
+from seamless_communication.models.unit_extractor.kmeans import KmeansModel
+from seamless_communication.models.unit_extractor.wav2vec2_layer_output import (
     Wav2Vec2LayerOutputModel,
     Wav2Vec2LayerOutputModel,
 )
 )
-from seamless_communication.models.unit_extraction.kmeans import KmeansModel
-from seamless_communication.models.inference import Translator
-from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
-
+from seamless_communication.models.vocoder import Vocoder, load_vocoder_model
 
 
 logging.basicConfig(
 logging.basicConfig(
     level=logging.INFO,
     level=logging.INFO,
@@ -77,7 +74,7 @@ class UnitExtractor(nn.Module):
                 audio = audio.unsqueeze(1)
                 audio = audio.unsqueeze(1)
             elif audio.dim() == 2 and audio.size(0) < audio.size(1):
             elif audio.dim() == 2 and audio.size(0) < audio.size(1):
                 logger.warning(
                 logger.warning(
-                    f"Transposing audio tensor from (bsz, seq_len) -> (seq_len, bsz)."
+                    "Transposing audio tensor from (bsz, seq_len) -> (seq_len, bsz)."
                 )
                 )
                 audio = audio.transpose(0, 1)
                 audio = audio.transpose(0, 1)
 
 
@@ -93,7 +90,7 @@ class UnitExtractor(nn.Module):
         batch = SequenceBatch(seqs=seqs, padding_mask=padding_mask)
         batch = SequenceBatch(seqs=seqs, padding_mask=padding_mask)
         features = self.model(batch, out_layer_idx).squeeze(0)
         features = self.model(batch, out_layer_idx).squeeze(0)
         units = self.kmeans_model(features)
         units = self.kmeans_model(features)
-        return units
+        return units  # type: ignore[no-any-return]
 
 
     @staticmethod
     @staticmethod
     def resynthesize_audio(
     def resynthesize_audio(
@@ -102,7 +99,7 @@ class UnitExtractor(nn.Module):
         device: Device,
         device: Device,
         vocoder_name: str = "vocoder_36langs",
         vocoder_name: str = "vocoder_36langs",
     ) -> Tensor:
     ) -> Tensor:
-        def reduce_list(lst):
+        def reduce_list(lst: List[Tensor]) -> List[Tensor]:
             return [key for key, _ in groupby(lst)]
             return [key for key, _ in groupby(lst)]
 
 
         reduced_units = reduce_list(units.cpu().tolist())
         reduced_units = reduce_list(units.cpu().tolist())
@@ -112,4 +109,4 @@ class UnitExtractor(nn.Module):
         )
         )
         assert isinstance(vocoder, Vocoder)
         assert isinstance(vocoder, Vocoder)
         wav = vocoder(reduced_units, src_lang, spkr=-1, dur_prediction=True)
         wav = vocoder(reduced_units, src_lang, spkr=-1, dur_prediction=True)
-        return wav
+        return wav  # type: ignore[no-any-return]

+ 10 - 30
src/seamless_communication/models/unit_extraction/wav2vec2_layer_output.py → src/seamless_communication/models/unit_extractor/wav2vec2_layer_output.py

@@ -3,33 +3,21 @@
 #
 #
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
-from fairseq2.nn.padding import PaddingMask
-from fairseq2.nn.transformer import TransformerEncoder, TransformerNormOrder
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from fairseq2.models.sequence import SequenceBatch
 from fairseq2.models.wav2vec2 import (
 from fairseq2.models.wav2vec2 import (
-    Wav2Vec2EncoderConfig,
     Wav2Vec2Config,
     Wav2Vec2Config,
-    wav2vec2_arch,
-    Wav2Vec2Model,
-    create_wav2vec2_model,
+    Wav2Vec2EncoderConfig,
     Wav2Vec2Frontend,
     Wav2Vec2Frontend,
+    Wav2Vec2Model,
+    wav2vec2_arch,
 )
 )
-from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
-from fairseq2.models.sequence import SequenceBatch
-
-
-from seamless_communication.assets import asset_store, download_manager
-
-
-import torch
-from typing import Optional
-
+from fairseq2.nn.padding import PaddingMask
+from fairseq2.nn.transformer import TransformerEncoder, TransformerNormOrder
 from torch import Tensor
 from torch import Tensor
-import torch.nn as nn
-
-
-wav2vec2_archs = ArchitectureRegistry[Wav2Vec2Config]("wav2vec2")
-wav2vec2_arch = wav2vec2_archs.marker
 
 
 
 
 def _encoder_xlsr2_1b_v2() -> Wav2Vec2EncoderConfig:
 def _encoder_xlsr2_1b_v2() -> Wav2Vec2EncoderConfig:
@@ -87,14 +75,6 @@ def _xlsr2_1b_v2() -> Wav2Vec2Config:
     )
     )
 
 
 
 
-load_wav2vec2_model = Wav2Vec2Loader(
-    asset_store,
-    download_manager,
-    create_wav2vec2_model,
-    wav2vec2_archs,
-)
-
-
 class Wav2Vec2LayerOutputModel(nn.Module):
 class Wav2Vec2LayerOutputModel(nn.Module):
     encoder_frontend: Wav2Vec2Frontend
     encoder_frontend: Wav2Vec2Frontend
     encoder: TransformerEncoder
     encoder: TransformerEncoder

+ 9 - 12
src/seamless_communication/models/unity/__init__.py

@@ -24,10 +24,10 @@ from seamless_communication.models.unity.length_regulator import (
     HardUpsampling as HardUpsampling,
     HardUpsampling as HardUpsampling,
 )
 )
 from seamless_communication.models.unity.length_regulator import (
 from seamless_communication.models.unity.length_regulator import (
-    VariancePredictor as VariancePredictor,
+    VarianceAdaptor as VarianceAdaptor,
 )
 )
 from seamless_communication.models.unity.length_regulator import (
 from seamless_communication.models.unity.length_regulator import (
-    VarianceAdaptor as VarianceAdaptor,
+    VariancePredictor as VariancePredictor,
 )
 )
 from seamless_communication.models.unity.loader import UnitYLoader as UnitYLoader
 from seamless_communication.models.unity.loader import UnitYLoader as UnitYLoader
 from seamless_communication.models.unity.loader import (
 from seamless_communication.models.unity.loader import (
@@ -40,26 +40,26 @@ from seamless_communication.models.unity.loader import (
     load_unity_unit_tokenizer as load_unity_unit_tokenizer,
     load_unity_unit_tokenizer as load_unity_unit_tokenizer,
 )
 )
 from seamless_communication.models.unity.model import UnitYModel as UnitYModel
 from seamless_communication.models.unity.model import UnitYModel as UnitYModel
-from seamless_communication.models.unity.model import UnitYX2TModel as UnitYX2TModel
-from seamless_communication.models.unity.model import UnitYT2UModel as UnitYT2UModel
 from seamless_communication.models.unity.model import (
 from seamless_communication.models.unity.model import (
     UnitYNART2UModel as UnitYNART2UModel,
     UnitYNART2UModel as UnitYNART2UModel,
 )
 )
 from seamless_communication.models.unity.model import UnitYOutput as UnitYOutput
 from seamless_communication.models.unity.model import UnitYOutput as UnitYOutput
-from seamless_communication.models.unity.nar_decoder_frontend import (
-    NARDecoderFrontend as NARDecoderFrontend,
-)
+from seamless_communication.models.unity.model import UnitYT2UModel as UnitYT2UModel
+from seamless_communication.models.unity.model import UnitYX2TModel as UnitYX2TModel
 from seamless_communication.models.unity.nar_decoder import (
 from seamless_communication.models.unity.nar_decoder import (
     NARTransformerDecoder as NARTransformerDecoder,
     NARTransformerDecoder as NARTransformerDecoder,
 )
 )
+from seamless_communication.models.unity.nar_decoder_frontend import (
+    NARDecoderFrontend as NARDecoderFrontend,
+)
 from seamless_communication.models.unity.nar_decoder_layer import (
 from seamless_communication.models.unity.nar_decoder_layer import (
     NARTransformerDecoderLayer as NARTransformerDecoderLayer,
     NARTransformerDecoderLayer as NARTransformerDecoderLayer,
 )
 )
 from seamless_communication.models.unity.t2u_builder import (
 from seamless_communication.models.unity.t2u_builder import (
-    UnitYT2UBuilder as UnitYT2UBuilder,
+    UnitYNART2UBuilder as UnitYNART2UBuilder,
 )
 )
 from seamless_communication.models.unity.t2u_builder import (
 from seamless_communication.models.unity.t2u_builder import (
-    UnitYNART2UBuilder as UnitYNART2UBuilder,
+    UnitYT2UBuilder as UnitYT2UBuilder,
 )
 )
 from seamless_communication.models.unity.t2u_builder import (
 from seamless_communication.models.unity.t2u_builder import (
     UnitYT2UConfig as UnitYT2UConfig,
     UnitYT2UConfig as UnitYT2UConfig,
@@ -82,6 +82,3 @@ from seamless_communication.models.unity.unit_tokenizer import (
 from seamless_communication.models.unity.unit_tokenizer import (
 from seamless_communication.models.unity.unit_tokenizer import (
     UnitTokenizer as UnitTokenizer,
     UnitTokenizer as UnitTokenizer,
 )
 )
-from seamless_communication.models.unity.generator import (
-    UnitYGenerator as UnitYGenerator,
-)

+ 0 - 1
src/seamless_communication/models/unity/adaptor_block.py

@@ -14,7 +14,6 @@ from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.projection import Linear
 from fairseq2.nn.projection import Linear
 from fairseq2.nn.transformer import (
 from fairseq2.nn.transformer import (
     AttentionMask,
     AttentionMask,
-    EncoderLayerOutputHook,
     FeedForwardNetwork,
     FeedForwardNetwork,
     LayerNormFactory,
     LayerNormFactory,
     MultiheadAttention,
     MultiheadAttention,

+ 3 - 5
src/seamless_communication/models/unity/builder.py

@@ -5,9 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
 from dataclasses import dataclass
 from dataclasses import dataclass
-from typing import Union, Optional
+from typing import Optional, Union
 
 
-from fairseq2.data import VocabularyInfo
 from fairseq2.models.conformer import ConformerBlock, ConformerConvolution
 from fairseq2.models.conformer import ConformerBlock, ConformerConvolution
 from fairseq2.models.nllb import NllbBuilder, NllbConfig, nllb_archs
 from fairseq2.models.nllb import NllbBuilder, NllbConfig, nllb_archs
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
@@ -24,7 +23,6 @@ from fairseq2.nn.transformer import (
 )
 )
 from fairseq2.typing import DataType, Device
 from fairseq2.typing import DataType, Device
 
 
-
 from seamless_communication.models.unity.adaptor_block import (
 from seamless_communication.models.unity.adaptor_block import (
     UnitYConformerAdaptorLayer,
     UnitYConformerAdaptorLayer,
     UnitYEncoderAdaptor,
     UnitYEncoderAdaptor,
@@ -32,15 +30,15 @@ from seamless_communication.models.unity.adaptor_block import (
 )
 )
 from seamless_communication.models.unity.model import UnitYModel
 from seamless_communication.models.unity.model import UnitYModel
 from seamless_communication.models.unity.t2u_builder import (
 from seamless_communication.models.unity.t2u_builder import (
-    UnitYT2UBuilder,
     UnitYNART2UBuilder,
     UnitYNART2UBuilder,
+    UnitYT2UBuilder,
     UnitYT2UConfig,
     UnitYT2UConfig,
     unity_t2u_archs,
     unity_t2u_archs,
 )
 )
 from seamless_communication.models.wav2vec2_chunk import (
 from seamless_communication.models.wav2vec2_chunk import (
-    wav2vec2_chunk_archs,
     Wav2Vec2ChunkEncoderBuilder,
     Wav2Vec2ChunkEncoderBuilder,
     Wav2Vec2ChunkEncoderConfig,
     Wav2Vec2ChunkEncoderConfig,
+    wav2vec2_chunk_archs,
 )
 )
 
 
 
 

+ 6 - 3
src/seamless_communication/models/unity/char_tokenizer.py

@@ -6,7 +6,12 @@
 
 
 from typing import Optional, Union, final
 from typing import Optional, Union, final
 
 
-from fairseq2.assets import AssetStore, AssetDownloadManager, download_manager
+from fairseq2.assets import (
+    AssetDownloadManager,
+    AssetStore,
+    asset_store,
+    download_manager,
+)
 from fairseq2.assets.card import AssetCard
 from fairseq2.assets.card import AssetCard
 from fairseq2.data.text import (
 from fairseq2.data.text import (
     SentencePieceDecoder,
     SentencePieceDecoder,
@@ -20,8 +25,6 @@ from fairseq2.data.text import (
 from fairseq2.data.typing import PathLike
 from fairseq2.data.typing import PathLike
 from fairseq2.typing import Device, finaloverride
 from fairseq2.typing import Device, finaloverride
 
 
-from seamless_communication.assets import asset_store
-
 
 
 @final
 @final
 class CharTokenizer(TextTokenizer):
 class CharTokenizer(TextTokenizer):

+ 5 - 7
src/seamless_communication/models/unity/length_regulator.py

@@ -3,18 +3,16 @@
 #
 #
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
-import torch
-
-from torch import Tensor
-from torch.nn import Conv1d, Dropout, Module, ReLU, Sequential
-
 from typing import Optional, Tuple
 from typing import Optional, Tuple
 
 
-from fairseq2.typing import DataType, Device
-from fairseq2.nn.transformer import create_standard_layer_norm
+import torch
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
 from fairseq2.nn.projection import Linear
 from fairseq2.nn.projection import Linear
+from fairseq2.nn.transformer import create_standard_layer_norm
+from fairseq2.typing import DataType, Device
+from torch import Tensor
+from torch.nn import Conv1d, Dropout, Module, ReLU, Sequential
 
 
 
 
 class HardUpsampling(Module):
 class HardUpsampling(Module):

+ 7 - 8
src/seamless_communication/models/unity/loader.py

@@ -7,10 +7,14 @@
 from typing import Any, Dict, List, Mapping, Union, final
 from typing import Any, Dict, List, Mapping, Union, final
 
 
 import torch
 import torch
-from fairseq2.assets import AssetStore, download_manager
+from fairseq2.assets import AssetStore, asset_store, download_manager
 from fairseq2.assets.card import AssetCard
 from fairseq2.assets.card import AssetCard
 from fairseq2.models.nllb import NllbConfig
 from fairseq2.models.nllb import NllbConfig
 from fairseq2.models.nllb.loader import NllbTokenizerLoader
 from fairseq2.models.nllb.loader import NllbTokenizerLoader
+from fairseq2.models.utils.checkpoint_loader import upgrade_fairseq_checkpoint
+from fairseq2.models.utils.model_loader import ModelConfigLoader, ModelLoader
+from overrides import override as finaloverride
+
 from seamless_communication.models.unity.builder import (
 from seamless_communication.models.unity.builder import (
     UnitYConfig,
     UnitYConfig,
     create_unity_model,
     create_unity_model,
@@ -19,11 +23,6 @@ from seamless_communication.models.unity.builder import (
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
 from seamless_communication.models.unity.model import UnitYModel
 from seamless_communication.models.unity.model import UnitYModel
 from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer
 from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer
-from fairseq2.models.utils.checkpoint_loader import upgrade_fairseq_checkpoint
-from fairseq2.models.utils.model_loader import ModelConfigLoader, ModelLoader
-from overrides import override as finaloverride
-
-from seamless_communication.assets import asset_store
 
 
 
 
 @final
 @final
@@ -71,8 +70,8 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
         keys_to_delete.append(f"{encoder_key}.w2v_encoder.w2v_model.mask_emb")
         keys_to_delete.append(f"{encoder_key}.w2v_encoder.w2v_model.mask_emb")
 
 
-        keys_to_delete.append(f"decoder.char_upsampler.embed_positions._float_tensor")
-        keys_to_delete.append(f"decoder.char_upsampler.embed_tokens_char.weight")
+        keys_to_delete.append("decoder.char_upsampler.embed_positions._float_tensor")
+        keys_to_delete.append("decoder.char_upsampler.embed_tokens_char.weight")
 
 
         # Delete AlignmentEncoder keys for inference.
         # Delete AlignmentEncoder keys for inference.
         alignment_encoder_keys = [
         alignment_encoder_keys = [

+ 4 - 7
src/seamless_communication/models/unity/nar_decoder.py

@@ -6,17 +6,14 @@
 
 
 from typing import Iterable, Optional, Tuple, final
 from typing import Iterable, Optional, Tuple, final
 
 
-from torch import Tensor
-from torch.nn import Module
-
 from fairseq2.nn.module_list import ModuleList
 from fairseq2.nn.module_list import ModuleList
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.padding import PaddingMask
-from fairseq2.nn.transformer import (
-    TransformerNormOrder,
-    create_standard_layer_norm,
-)
+from fairseq2.nn.transformer import TransformerNormOrder, create_standard_layer_norm
 from fairseq2.typing import DataType, Device, finaloverride
 from fairseq2.typing import DataType, Device, finaloverride
+from torch import Tensor
+from torch.nn import Module
+
 from seamless_communication.models.unity.nar_decoder_layer import (
 from seamless_communication.models.unity.nar_decoder_layer import (
     NARTransformerDecoderLayer,
     NARTransformerDecoderLayer,
 )
 )

+ 5 - 9
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -4,11 +4,10 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
+import math
 from typing import List, Optional, Tuple, final
 from typing import List, Optional, Tuple, final
 
 
-from torch import Tensor
-from torch.nn import Dropout, Module, Parameter
-
+import torch
 from fairseq2.data import VocabularyInfo
 from fairseq2.data import VocabularyInfo
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
 from fairseq2.nn.embedding import Embedding
 from fairseq2.nn.embedding import Embedding
@@ -17,17 +16,14 @@ from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.position_encoder import PositionEncoder
 from fairseq2.nn.position_encoder import PositionEncoder
 from fairseq2.nn.transformer import create_standard_layer_norm
 from fairseq2.nn.transformer import create_standard_layer_norm
 from fairseq2.typing import DataType, Device, finaloverride
 from fairseq2.typing import DataType, Device, finaloverride
+from torch import Tensor
+from torch.nn import Dropout, Module, Parameter
 
 
-
+from seamless_communication.models.unity.char_tokenizer import CharTokenizer
 from seamless_communication.models.unity.length_regulator import (
 from seamless_communication.models.unity.length_regulator import (
     HardUpsampling,
     HardUpsampling,
     VarianceAdaptor,
     VarianceAdaptor,
 )
 )
-from seamless_communication.models.unity.char_tokenizer import CharTokenizer
-
-import math
-import torch
-
 
 
 SPACE = "▁"
 SPACE = "▁"
 
 

+ 4 - 5
src/seamless_communication/models/unity/nar_decoder_layer.py

@@ -4,15 +4,14 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
-from typing import Optional, final, Tuple
-
-from torch import Tensor
-from torch.nn import Conv1d, Dropout, Module, ReLU
+from typing import Optional, Tuple, final
 
 
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.normalization import LayerNorm
-from fairseq2.nn.transformer import MultiheadAttention, create_standard_layer_norm
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
+from fairseq2.nn.transformer import MultiheadAttention, create_standard_layer_norm
 from fairseq2.typing import DataType, Device, finaloverride
 from fairseq2.typing import DataType, Device, finaloverride
+from torch import Tensor
+from torch.nn import Conv1d, Dropout, Module, ReLU
 
 
 
 
 @final
 @final

+ 14 - 16
src/seamless_communication/models/unity/t2u_builder.py

@@ -6,9 +6,14 @@
 from dataclasses import dataclass
 from dataclasses import dataclass
 from typing import Literal, Optional, Union
 from typing import Literal, Optional, Union
 
 
-from fairseq2.assets import download_manager
+from fairseq2.assets import asset_store, download_manager
 from fairseq2.assets.card import AssetCard
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import VocabularyInfo
 from fairseq2.data import VocabularyInfo
+from fairseq2.models.nllb.loader import NllbTokenizerLoader
+from fairseq2.models.transformer import (
+    TransformerEmbeddingFrontend,
+    TransformerFrontend,
+)
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
@@ -30,25 +35,18 @@ from fairseq2.nn.transformer import (
     create_default_sdpa,
     create_default_sdpa,
 )
 )
 from fairseq2.typing import DataType, Device
 from fairseq2.typing import DataType, Device
-from fairseq2.models.transformer import (
-    TransformerEmbeddingFrontend,
-    TransformerFrontend,
-)
-from fairseq2.models.nllb.loader import NllbTokenizerLoader
 
 
-
-from seamless_communication.assets import asset_store
-from seamless_communication.models.unity.nar_decoder import NARTransformerDecoder
-from seamless_communication.models.unity.nar_decoder_layer import (
-    NARTransformerDecoderLayer,
-    Conv1dBlock,
-)
-from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
-from seamless_communication.models.unity.model import UnitYT2UModel, UnitYNART2UModel
 from seamless_communication.models.unity.length_regulator import (
 from seamless_communication.models.unity.length_regulator import (
-    VariancePredictor,
     VarianceAdaptor,
     VarianceAdaptor,
+    VariancePredictor,
+)
+from seamless_communication.models.unity.model import UnitYNART2UModel, UnitYT2UModel
+from seamless_communication.models.unity.nar_decoder import NARTransformerDecoder
+from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
+from seamless_communication.models.unity.nar_decoder_layer import (
+    Conv1dBlock,
+    NARTransformerDecoderLayer,
 )
 )
 
 
 
 

+ 1 - 1
src/seamless_communication/models/vocoder/codehifigan.py

@@ -9,8 +9,8 @@ import torch
 import torch.nn as nn
 import torch.nn as nn
 from torch import Tensor
 from torch import Tensor
 
 
-from seamless_communication.models.vocoder.hifigan import Generator
 from seamless_communication.models.unity import VariancePredictor
 from seamless_communication.models.unity import VariancePredictor
+from seamless_communication.models.vocoder.hifigan import Generator
 
 
 
 
 class CodeGenerator(Generator):
 class CodeGenerator(Generator):

+ 1 - 1
src/seamless_communication/models/vocoder/loader.py

@@ -6,10 +6,10 @@
 
 
 from typing import Any, Mapping, final
 from typing import Any, Mapping, final
 
 
+from fairseq2.assets import asset_store, download_manager
 from fairseq2.models.utils.model_loader import ModelLoader
 from fairseq2.models.utils.model_loader import ModelLoader
 from overrides import override as finaloverride
 from overrides import override as finaloverride
 
 
-from seamless_communication.assets import asset_store, download_manager
 from seamless_communication.models.vocoder.builder import (
 from seamless_communication.models.vocoder.builder import (
     VocoderConfig,
     VocoderConfig,
     create_vocoder_model,
     create_vocoder_model,

+ 3 - 3
src/seamless_communication/models/wav2vec2_chunk/__init__.py

@@ -4,12 +4,12 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
-from seamless_communication.models.wav2vec2_chunk.builder import (
-    wav2vec2_chunk_archs as wav2vec2_chunk_archs,
-)
 from seamless_communication.models.wav2vec2_chunk.builder import (
 from seamless_communication.models.wav2vec2_chunk.builder import (
     Wav2Vec2ChunkEncoderBuilder as Wav2Vec2ChunkEncoderBuilder,
     Wav2Vec2ChunkEncoderBuilder as Wav2Vec2ChunkEncoderBuilder,
 )
 )
 from seamless_communication.models.wav2vec2_chunk.builder import (
 from seamless_communication.models.wav2vec2_chunk.builder import (
     Wav2Vec2ChunkEncoderConfig as Wav2Vec2ChunkEncoderConfig,
     Wav2Vec2ChunkEncoderConfig as Wav2Vec2ChunkEncoderConfig,
 )
 )
+from seamless_communication.models.wav2vec2_chunk.builder import (
+    wav2vec2_chunk_archs as wav2vec2_chunk_archs,
+)

+ 2 - 2
src/seamless_communication/models/wav2vec2_chunk/builder.py

@@ -4,16 +4,16 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 from typing import Literal, Optional
 from typing import Literal, Optional
 
 
 from fairseq2.models.conformer import ConformerConvolution
 from fairseq2.models.conformer import ConformerConvolution
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.w2vbert import w2vbert_archs
 from fairseq2.models.wav2vec2.builder import (
 from fairseq2.models.wav2vec2.builder import (
     Wav2Vec2EncoderBuilder,
     Wav2Vec2EncoderBuilder,
     Wav2Vec2EncoderConfig,
     Wav2Vec2EncoderConfig,
 )
 )
-from fairseq2.models.w2vbert import w2vbert_archs
 from fairseq2.nn.transformer import SDPA, ShawRelativePositionSDPA
 from fairseq2.nn.transformer import SDPA, ShawRelativePositionSDPA
 from fairseq2.typing import DataType, Device
 from fairseq2.typing import DataType, Device
 
 

+ 2 - 3
src/seamless_communication/models/wav2vec2_chunk/chunk_attention_mask.py

@@ -7,10 +7,9 @@
 from typing import Optional
 from typing import Optional
 
 
 import torch
 import torch
-from torch import Tensor
-
-from fairseq2.nn.utils.mask import to_float_mask
 from fairseq2.nn.transformer import AttentionMask, CustomAttentionMask
 from fairseq2.nn.transformer import AttentionMask, CustomAttentionMask
+from fairseq2.nn.utils.mask import to_float_mask
+from torch import Tensor
 
 
 
 
 class ChunkAttentionMaskFactory:
 class ChunkAttentionMaskFactory:

+ 4 - 11
src/seamless_communication/models/wav2vec2_chunk/encoder.py

@@ -6,25 +6,18 @@
 
 
 from typing import Iterable, Optional, Tuple, final
 from typing import Iterable, Optional, Tuple, final
 
 
-from torch import Tensor
-from torch.nn import Dropout
-
 from fairseq2.nn.module_list import ModuleList
 from fairseq2.nn.module_list import ModuleList
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.padding import PaddingMask
-
-from fairseq2.nn.transformer import (
-    EncoderLayerOutputHook,
-    TransformerEncoder,
-    TransformerEncoderLayer,
-)
+from fairseq2.nn.transformer import TransformerEncoder, TransformerEncoderLayer
+from fairseq2.typing import finaloverride
+from torch import Tensor
+from torch.nn import Dropout
 
 
 from seamless_communication.models.wav2vec2_chunk.chunk_attention_mask import (
 from seamless_communication.models.wav2vec2_chunk.chunk_attention_mask import (
     ChunkAttentionMaskFactory,
     ChunkAttentionMaskFactory,
 )
 )
 
 
-from fairseq2.typing import finaloverride
-
 
 
 @final
 @final
 class ChunkTransformerEncoder(TransformerEncoder):
 class ChunkTransformerEncoder(TransformerEncoder):

+ 0 - 0
src/seamless_communication/py.typed


+ 1 - 2
tests/common.py

@@ -8,9 +8,8 @@ from contextlib import contextmanager
 from typing import Any, Generator, List, Union
 from typing import Any, Generator, List, Union
 
 
 import torch
 import torch
-from torch import Tensor
-
 from fairseq2.typing import Device
 from fairseq2.typing import Device
+from torch import Tensor
 
 
 # The default device that tests should use. Note that pytest can change it based
 # The default device that tests should use. Note that pytest can change it based
 # on the provided command line arguments.
 # on the provided command line arguments.

+ 2 - 2
tests/conftest.py

@@ -8,10 +8,10 @@ from argparse import ArgumentTypeError
 from typing import cast
 from typing import cast
 
 
 import pytest
 import pytest
-import tests.common
-
 from fairseq2.typing import Device
 from fairseq2.typing import Device
 
 
+import tests.common
+
 
 
 def parse_device_arg(value: str) -> Device:
 def parse_device_arg(value: str) -> Device:
     try:
     try:

+ 0 - 0
tests/integration/inference/__init__.py


+ 3 - 2
tests/integration/models/test_translator.py → tests/integration/inference/test_translator.py

@@ -4,11 +4,12 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
-import torch
 from typing import Final
 from typing import Final
 
 
+import torch
 from fairseq2.typing import Device
 from fairseq2.typing import Device
-from seamless_communication.models.inference import Translator
+
+from seamless_communication.inference import Translator
 from tests.common import device
 from tests.common import device
 
 
 # fmt: off
 # fmt: off

+ 6 - 6
tests/integration/models/test_unit_extraction.py → tests/integration/models/test_unit_extractor.py

@@ -4,22 +4,22 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
-import torch
-from torch import tensor
 from typing import Final
 from typing import Final
 
 
+import torch
 from fairseq2.typing import Device
 from fairseq2.typing import Device
-from seamless_communication.models.inference import Translator
-from seamless_communication.models.unit_extraction import UnitExtractor
-from tests.common import assert_equal, device
+from torch import tensor
 
 
+from seamless_communication.inference import Translator
+from seamless_communication.models.unit_extractor import UnitExtractor
+from tests.common import assert_equal, device
 
 
 # fmt: off
 # fmt: off
 REF_ENG_UNITS: Final = [8976, 8299, 0, 0, 9692, 5395, 785, 785, 7805, 6193, 2922, 4806, 3362, 3560, 8119, 8119, 4335, 205, 5424, 5424, 5064, 7421, 6547, 9952, 3728, 8544, 3321, 1093, 1443, 7962, 3978, 8063, 5168, 5491, 9133, 9275, 5912, 8729, 5097, 5495, 1650, 5048, 2839, 6756, 5665, 4191, 5205, 5205, 9568, 9568, 5932, 1190, 9339, 5839, 5839, 6244, 5320, 3454, 5216, 721, 6994, 6513, 7754, 3469, 296, 1849, 3254, 3254, 5042, 5042, 3961, 2079, 1907, 1846, 661, 2225, 944, 9295, 4712, 1785, 6060, 8701, 7646, 1355, 2876, 8199, 5901, 8199, 3861, 5153, 6420, 2897, 1389, 334, 6334]
 REF_ENG_UNITS: Final = [8976, 8299, 0, 0, 9692, 5395, 785, 785, 7805, 6193, 2922, 4806, 3362, 3560, 8119, 8119, 4335, 205, 5424, 5424, 5064, 7421, 6547, 9952, 3728, 8544, 3321, 1093, 1443, 7962, 3978, 8063, 5168, 5491, 9133, 9275, 5912, 8729, 5097, 5495, 1650, 5048, 2839, 6756, 5665, 4191, 5205, 5205, 9568, 9568, 5932, 1190, 9339, 5839, 5839, 6244, 5320, 3454, 5216, 721, 6994, 6513, 7754, 3469, 296, 1849, 3254, 3254, 5042, 5042, 3961, 2079, 1907, 1846, 661, 2225, 944, 9295, 4712, 1785, 6060, 8701, 7646, 1355, 2876, 8199, 5901, 8199, 3861, 5153, 6420, 2897, 1389, 334, 6334]
 # fmt: on
 # fmt: on
 
 
 
 
-def test_unit_extraction() -> None:
+def test_unit_extractor() -> None:
     model_name = "seamlessM4T_v2_large"
     model_name = "seamlessM4T_v2_large"
     english_text = "Hello! I hope you're all doing well."
     english_text = "Hello! I hope you're all doing well."