Эх сурвалжийг харах

Revise project structure (#82)

Can Balioglu 1 жил өмнө
parent
commit
9daf4692ef
84 өөрчлөгдсөн 323 нэмэгдсэн , 405 устгасан
  1. 3 3
      README.md
  2. 1 1
      demo/app.py
  3. 2 1
      dev_requirements.txt
  4. 32 0
      pyproject.toml
  5. 0 7
      requirements.txt
  6. 0 0
      scripts/install_devfair.sh
  7. 0 0
      scripts/install_fairaws.sh
  8. 16 50
      setup.py
  9. 15 0
      src/seamless_communication/__init__.py
  10. 0 9
      src/seamless_communication/assets/__init__.py
  11. 0 27
      src/seamless_communication/assets/download_manager.py
  12. 0 22
      src/seamless_communication/assets/store.py
  13. 0 0
      src/seamless_communication/cards/seamlessM4T_large.yaml
  14. 0 0
      src/seamless_communication/cards/seamlessM4T_medium.yaml
  15. 0 0
      src/seamless_communication/cards/seamlessM4T_v2_large.yaml
  16. 0 0
      src/seamless_communication/cards/unity_nllb-100.yaml
  17. 0 0
      src/seamless_communication/cards/unity_nllb-200.yaml
  18. 0 0
      src/seamless_communication/cards/vocoder_36langs.yaml
  19. 0 0
      src/seamless_communication/cards/vocoder_v2.yaml
  20. 0 0
      src/seamless_communication/cards/xlsr2_1b_v2.yaml
  21. 0 0
      src/seamless_communication/cli/__init__.py
  22. 0 0
      src/seamless_communication/cli/eval_utils/__init__.py
  23. 6 5
      src/seamless_communication/cli/eval_utils/compute_metrics.py
  24. 0 1
      src/seamless_communication/cli/eval_utils/lang_mapping.py
  25. 0 0
      src/seamless_communication/cli/m4t/__init__.py
  26. 0 0
      src/seamless_communication/cli/m4t/audio_to_units/README.md
  27. 0 0
      src/seamless_communication/cli/m4t/audio_to_units/__init__.py
  28. 2 1
      src/seamless_communication/cli/m4t/audio_to_units/audio_to_units.py
  29. 0 0
      src/seamless_communication/cli/m4t/evaluate/README.md
  30. 0 0
      src/seamless_communication/cli/m4t/evaluate/__init__.py
  31. 15 18
      src/seamless_communication/cli/m4t/evaluate/evaluate.py
  32. 0 0
      src/seamless_communication/cli/m4t/finetune/README.md
  33. 0 0
      src/seamless_communication/cli/m4t/finetune/__init__.py
  34. 0 0
      src/seamless_communication/cli/m4t/finetune/dataloader.py
  35. 1 1
      src/seamless_communication/cli/m4t/finetune/dataset.py
  36. 0 0
      src/seamless_communication/cli/m4t/finetune/dist_utils.py
  37. 1 1
      src/seamless_communication/cli/m4t/finetune/finetune.py
  38. 11 9
      src/seamless_communication/cli/m4t/finetune/trainer.py
  39. 0 0
      src/seamless_communication/cli/m4t/predict/README.md
  40. 4 2
      src/seamless_communication/cli/m4t/predict/__init__.py
  41. 6 9
      src/seamless_communication/cli/m4t/predict/predict.py
  42. 0 0
      src/seamless_communication/cli/m4t/train/__init__.py
  43. 3 3
      src/seamless_communication/cli/m4t/train/configs.py
  44. 10 7
      src/seamless_communication/cli/m4t/train/dataloader.py
  45. 0 0
      src/seamless_communication/cli/m4t/train/dist_utils.py
  46. 26 17
      src/seamless_communication/cli/m4t/train/model.py
  47. 0 0
      src/seamless_communication/cli/m4t/train/recipes/asr_small.yaml
  48. 0 0
      src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_transc.yaml
  49. 0 0
      src/seamless_communication/cli/m4t/train/recipes/large_M4T_v1.yaml
  50. 0 0
      src/seamless_communication/cli/m4t/train/recipes/m4t_v1_train_manifests.txt
  51. 6 5
      src/seamless_communication/cli/m4t/train/run_training.py
  52. 0 1
      src/seamless_communication/cli/m4t/train/run_with_slurm.py
  53. 17 13
      src/seamless_communication/cli/m4t/train/trainer.py
  54. 16 0
      src/seamless_communication/inference/__init__.py
  55. 5 7
      src/seamless_communication/inference/generator.py
  56. 3 2
      src/seamless_communication/inference/ngram_repeat_block_processor.py
  57. 9 9
      src/seamless_communication/inference/translator.py
  58. 0 14
      src/seamless_communication/models/inference/__init__.py
  59. 5 5
      src/seamless_communication/models/unit_extractor/__init__.py
  60. 3 3
      src/seamless_communication/models/unit_extractor/kmeans.py
  61. 13 16
      src/seamless_communication/models/unit_extractor/unit_extractor.py
  62. 10 30
      src/seamless_communication/models/unit_extractor/wav2vec2_layer_output.py
  63. 9 12
      src/seamless_communication/models/unity/__init__.py
  64. 0 1
      src/seamless_communication/models/unity/adaptor_block.py
  65. 3 5
      src/seamless_communication/models/unity/builder.py
  66. 6 3
      src/seamless_communication/models/unity/char_tokenizer.py
  67. 5 7
      src/seamless_communication/models/unity/length_regulator.py
  68. 7 8
      src/seamless_communication/models/unity/loader.py
  69. 4 7
      src/seamless_communication/models/unity/nar_decoder.py
  70. 5 9
      src/seamless_communication/models/unity/nar_decoder_frontend.py
  71. 4 5
      src/seamless_communication/models/unity/nar_decoder_layer.py
  72. 14 16
      src/seamless_communication/models/unity/t2u_builder.py
  73. 1 1
      src/seamless_communication/models/vocoder/codehifigan.py
  74. 1 1
      src/seamless_communication/models/vocoder/loader.py
  75. 3 3
      src/seamless_communication/models/wav2vec2_chunk/__init__.py
  76. 2 2
      src/seamless_communication/models/wav2vec2_chunk/builder.py
  77. 2 3
      src/seamless_communication/models/wav2vec2_chunk/chunk_attention_mask.py
  78. 4 11
      src/seamless_communication/models/wav2vec2_chunk/encoder.py
  79. 0 0
      src/seamless_communication/py.typed
  80. 1 2
      tests/common.py
  81. 2 2
      tests/conftest.py
  82. 0 0
      tests/integration/inference/__init__.py
  83. 3 2
      tests/integration/inference/test_translator.py
  84. 6 6
      tests/integration/models/test_unit_extractor.py

+ 3 - 3
README.md

@@ -45,7 +45,7 @@ T2TT task:
 m4t_predict <input_text> t2tt <tgt_lang> --src_lang <src_lang>
 ```
 
-Please refer to the [inference README](scripts/m4t/predict) for detailed instruction on how to run inference and the list of supported languages on the source, target sides for speech, text modalities.
+Please refer to the [inference README](src/seamless_communication/cli/m4t/predict) for detailed instructions on how to run inference and the list of supported languages on the source and target sides for the speech and text modalities.
 
 ## Running [Gradio](https://github.com/gradio-app/gradio) demo locally
 
@@ -86,10 +86,10 @@ We provide the extensive evaluation results of seamlessM4T-Large and SeamlessM4T
 To reproduce our results, or to evaluate using the same metrics over your own test sets, please check out the [README here](docs/m4t/eval_README.md).
 
 ## Finetuning SeamlessM4T models
-Please check out the [README here](scripts/m4t/finetune/README.md).
+Please check out the [README here](src/seamless_communication/cli/m4t/finetune/README.md).
 
 ## Converting raw audio to units
-Please check out the [README here](scripts/m4t/audio_to_units/README.md).
+Please check out the [README here](src/seamless_communication/cli/m4t/audio_to_units/README.md).
 
 ## On-device models
 Apart from the SeamlessM4T large (2.3B) and medium (1.2B) models, we are also releasing a small model (281M) targeted for on-device inference. To learn more about the usage and model details, check out the [README here](docs/m4t/on_device_README.md).

+ 1 - 1
demo/app.py

@@ -11,8 +11,8 @@ import numpy as np
 import torch
 import torchaudio
 from huggingface_hub import hf_hub_download
-from seamless_communication.models.inference.translator import Translator
 
+from seamless_communication.models.inference.translator import Translator
 
 DESCRIPTION = """# SeamlessM4T
 

+ 2 - 1
dev_requirements.txt

@@ -1,5 +1,6 @@
-pytest
 black
 flake8
 isort
 mypy
+pre-commit
+pytest

+ 32 - 0
pyproject.toml

@@ -0,0 +1,32 @@
+[build-system]
+requires = ["packaging~=23.1", "setuptools~=67.8", "wheel~=0.40"]
+build-backend = "setuptools.build_meta"
+
+[tool.flake8]
+extend_ignore = ["E", "Y"]  # Black
+per-file-ignores = [
+    "__init__.py:F401",
+]
+
+[tool.isort]
+profile = "black"
+
+[tool.mypy]
+disable_error_code = "type-abstract"
+disallow_untyped_calls = false
+disallow_untyped_decorators = false
+ignore_missing_imports = true
+python_version = 3.8
+show_error_codes = true
+show_error_context = true
+strict = true
+warn_unused_configs = false
+warn_unused_ignores = false
+
+[tool.pytest.ini_options]
+minversion = "7.1"
+testpaths = ["tests"]
+filterwarnings = [
+    "ignore:torch.nn.utils.weight_norm is deprecated in favor of",
+    "ignore:TypedStorage is deprecated",
+]

+ 0 - 7
requirements.txt

@@ -1,7 +0,0 @@
-pre-commit
-datasets
-torchaudio
-tqdm
-soundfile
-librosa
-fairseq2==0.2.*

+ 0 - 0
scripts/m4t/train/install_devfair.sh → scripts/install_devfair.sh


+ 0 - 0
scripts/m4t/train/install_fairaws.sh → scripts/install_fairaws.sh


+ 16 - 50
setup.py

@@ -4,53 +4,14 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from pathlib import Path
-import os
-from typing import Iterable
-
-import pkg_resources
 from setuptools import find_packages, setup
-from setuptools.command.develop import develop
-
-
-def _load_requirements(fname: str) -> Iterable[str]:
-    with open(Path(__file__).parent / fname) as fp_in:
-        for req in pkg_resources.parse_requirements(fp_in):
-            yield str(req)
-
-
-def _add_symlinks():
-    root = Path(__file__).parent
-    sc_root = root / "src/seamless_communication"
-    sc_link = root / "seamless_communication"
-    m4t_scripts_root = root / "scripts/m4t"
-    m4t_scripts_link = root / "m4t_scripts"
-    if not sc_link.exists():
-        os.symlink(sc_root, sc_link, target_is_directory=True)
-    if not m4t_scripts_link.exists():
-        os.symlink(m4t_scripts_root, m4t_scripts_link, target_is_directory=True)
-
-
-class cmd_for_editable_mode(develop):
-    def run(self):
-        # add symlinks for modules if install in editable mode
-        _add_symlinks()
-        super().run()
-
-
-default_requirements = list(_load_requirements("requirements.txt"))
-dev_requirements = list(_load_requirements("dev_requirements.txt"))
 
 setup(
     name="seamless_communication",
     version="1.0.0",
-    packages=find_packages(where="src")
-    + ["m4t_scripts.finetune", "m4t_scripts.predict"],
-    package_dir={
-        "m4t_scripts": "scripts/m4t",
-        "seamless_communication": "src/seamless_communication",
-    },
-    package_data={"": ["assets/cards/*.yaml"]},
+    packages=find_packages(where="src"),
+    package_dir={"": "src"},
+    package_data={"": ["py.typed", "cards/*.yaml"]},
     description="SeamlessM4T -- Massively Multilingual & Multimodal Machine Translation Model",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
@@ -59,17 +20,22 @@ setup(
     author="Fundamental AI Research (FAIR) at Meta",
     url="https://github.com/facebookresearch/seamless_communication",
     license="Creative Commons",
-    install_requires=default_requirements,
-    extras_require={"dev": default_requirements + dev_requirements},
+    install_requires=[
+        "datasets",
+        "fairseq2==0.2.*",
+        "librosa",
+        "soundfile",
+        "torchaudio",
+        "tqdm",
+    ],
     entry_points={
         "console_scripts": [
-            "m4t_evaluate=m4t_scripts.evaluate.evaluate:main",
-            "m4t_predict=m4t_scripts.predict.predict:main",
-            "m4t_finetune=m4t_scripts.finetune.finetune:main",
-            "m4t_prepare_dataset=m4t_scripts.finetune.dataset:main",
-            "m4t_audio_to_units=m4t_scripts.audio_to_units.audio_to_units:main",
+            "m4t_evaluate=seamless_communication.cli.m4t.evaluate.evaluate:main",
+            "m4t_predict=seamless_communication.cli.m4t.predict.predict:main",
+            "m4t_finetune=seamless_communication.cli.m4t.finetune.finetune:main",
+            "m4t_prepare_dataset=seamless_communication.cli.m4t.finetune.dataset:main",
+            "m4t_audio_to_units=seamless_communication.cli.m4t.audio_to_units.audio_to_units:main",
         ],
     },
-    cmdclass={"develop": cmd_for_editable_mode},
     include_package_data=True,
 )

+ 15 - 0
src/seamless_communication/__init__.py

@@ -4,4 +4,19 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from pathlib import Path
+
+from fairseq2.assets import LocalAssetCardStorage, asset_store
+
 __version__ = "0.1.0"
+
+
+def _update_asset_store() -> None:
+    pathname = Path(__file__).parent.joinpath("cards")
+
+    card_storage = LocalAssetCardStorage(pathname)
+
+    asset_store.add_storage(card_storage)
+
+
+_update_asset_store()

+ 0 - 9
src/seamless_communication/assets/__init__.py

@@ -1,9 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-from seamless_communication.assets.download_manager import (
-    download_manager as download_manager,
-)
-from seamless_communication.assets.store import asset_store as asset_store

+ 0 - 27
src/seamless_communication/assets/download_manager.py

@@ -1,27 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from pathlib import Path
-
-import torch
-from fairseq2.assets import DefaultAssetDownloadManager
-
-
-class SCAssetDownloadManager(DefaultAssetDownloadManager):
-    @classmethod
-    def _get_pathname(cls, uri: str, sub_dir: str) -> Path:
-        hub_dir = Path(torch.hub.get_dir()).expanduser()
-
-        hsh = cls._get_uri_hash(uri)
-
-        filename = cls._get_filename(uri)
-
-        return hub_dir.joinpath(
-            "seamless_communication", "assets", sub_dir, hsh, filename
-        )
-
-
-download_manager = SCAssetDownloadManager()

+ 0 - 22
src/seamless_communication/assets/store.py

@@ -1,22 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from pathlib import Path
-
-from fairseq2.assets import AssetStore
-from fairseq2.assets.card_storage import LocalAssetCardStorage
-from fairseq2.assets.store import DefaultAssetStore
-
-
-def create_default_asset_store() -> AssetStore:
-    pathname = Path(__file__).parent.joinpath("cards")
-
-    card_storage = LocalAssetCardStorage(pathname)
-
-    return DefaultAssetStore(card_storage)
-
-
-asset_store = create_default_asset_store()

+ 0 - 0
src/seamless_communication/assets/cards/seamlessM4T_large.yaml → src/seamless_communication/cards/seamlessM4T_large.yaml


+ 0 - 0
src/seamless_communication/assets/cards/seamlessM4T_medium.yaml → src/seamless_communication/cards/seamlessM4T_medium.yaml


+ 0 - 0
src/seamless_communication/assets/cards/seamlessM4T_v2_large.yaml → src/seamless_communication/cards/seamlessM4T_v2_large.yaml


+ 0 - 0
src/seamless_communication/assets/cards/unity_nllb-100.yaml → src/seamless_communication/cards/unity_nllb-100.yaml


+ 0 - 0
src/seamless_communication/assets/cards/unity_nllb-200.yaml → src/seamless_communication/cards/unity_nllb-200.yaml


+ 0 - 0
src/seamless_communication/assets/cards/vocoder_36langs.yaml → src/seamless_communication/cards/vocoder_36langs.yaml


+ 0 - 0
src/seamless_communication/assets/cards/vocoder_v2.yaml → src/seamless_communication/cards/vocoder_v2.yaml


+ 0 - 0
src/seamless_communication/assets/cards/xlsr2_1b_v2.yaml → src/seamless_communication/cards/xlsr2_1b_v2.yaml


+ 0 - 0
scripts/m4t/train/__init__.py → src/seamless_communication/cli/__init__.py


+ 0 - 0
src/seamless_communication/cli/eval_utils/__init__.py


+ 6 - 5
scripts/eval_utils/compute_metrics.py → src/seamless_communication/cli/eval_utils/compute_metrics.py

@@ -4,17 +4,18 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from pathlib import Path
 import logging
+from pathlib import Path
+from typing import Optional
+
 import pandas as pd
 import sacrebleu
 import whisper
-from jiwer import wer, cer
+from jiwer import cer, wer
 from tqdm import tqdm
-from typing import Optional
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 
-from scripts.eval_utils.lang_mapping import LANG3_LANG2
+from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 
 logging.basicConfig(
     level=logging.INFO,
@@ -316,7 +317,7 @@ def compute_quality_metrics(
             whisper_normalize_text=True,
         )
         transcripts_df.to_csv(
-            (Path(output_dir) / f"whisper_audio_transcriptions.tsv"),
+            (Path(output_dir) / "whisper_audio_transcriptions.tsv"),
             sep="\t",
             index=False,
             encoding="utf-8",

+ 0 - 1
scripts/eval_utils/lang_mapping.py → src/seamless_communication/cli/eval_utils/lang_mapping.py

@@ -174,4 +174,3 @@ LANG2_LANG3 = {
     "tk": "tuk",
 }
 LANG3_LANG2 = {v: k for k, v in LANG2_LANG3.items()}
-

+ 0 - 0
src/seamless_communication/cli/m4t/__init__.py


+ 0 - 0
scripts/m4t/audio_to_units/README.md → src/seamless_communication/cli/m4t/audio_to_units/README.md


+ 0 - 0
scripts/m4t/audio_to_units/__init__.py → src/seamless_communication/cli/m4t/audio_to_units/__init__.py


+ 2 - 1
scripts/m4t/audio_to_units/audio_to_units.py → src/seamless_communication/cli/m4t/audio_to_units/audio_to_units.py

@@ -5,9 +5,10 @@
 
 import argparse
 import logging
+
 import torch
-from seamless_communication.models.unit_extraction import UnitExtractor
 
+from seamless_communication.models.unit_extractor import UnitExtractor
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

+ 0 - 0
scripts/m4t/evaluate/README.md → src/seamless_communication/cli/m4t/evaluate/README.md


+ 0 - 0
scripts/m4t/evaluate/__init__.py → src/seamless_communication/cli/m4t/evaluate/__init__.py


+ 15 - 18
scripts/m4t/evaluate/evaluate.py → src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -9,33 +9,31 @@ import contextlib
 import itertools
 import logging
 import subprocess
-import torch
-import torchaudio
-
 from argparse import Namespace
 from dataclasses import dataclass
 from pathlib import Path
-from torch import Tensor
-from tqdm import tqdm
-from typing import List, Optional, Tuple, Dict
+from typing import Dict, List, Optional, Tuple
 
+import torch
+import torchaudio
 from fairseq2.data import Collater, DataPipeline, FileMapper
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
 from fairseq2.data.typing import StringLike
 from fairseq2.generation import SequenceGeneratorOptions
-from fairseq2.typing import Device, DataType
+from fairseq2.typing import DataType, Device
+from torch import Tensor
+from tqdm import tqdm
 
-from m4t_scripts.predict import add_inference_arguments, set_generation_opts
-from seamless_communication.models.inference import (
-    BatchedSpeechOutput,
-    Modality,
-    Translator,
-)
-from seamless_communication.models.unity import load_unity_text_tokenizer
-from scripts.eval_utils.compute_metrics import (
+from seamless_communication.cli.eval_utils.compute_metrics import (
     compute_quality_metrics,
 )
+from seamless_communication.cli.predict import (
+    add_inference_arguments,
+    set_generation_opts,
+)
+from seamless_communication.inference import BatchedSpeechOutput, Modality, Translator
+from seamless_communication.models.unity import load_unity_text_tokenizer
 
 logging.basicConfig(
     level=logging.INFO,
@@ -247,9 +245,9 @@ def run_eval(
     ) as unit_file:
         sample_id = 0
         if ctx.output_modality == Modality.SPEECH:
-            hyp_file.write(f"ref_tgt_text\tpred_tgt_text\tpred_tgt_audio\n")
+            hyp_file.write("ref_tgt_text\tpred_tgt_text\tpred_tgt_audio\n")
         else:
-            hyp_file.write(f"ref_tgt_text\tpred_tgt_text\n")
+            hyp_file.write("ref_tgt_text\tpred_tgt_text\n")
         for example in pipeline:
             valid_sequences: Optional[Tensor] = None
             if ctx.input_modality == Modality.SPEECH:
@@ -302,7 +300,6 @@ def run_eval(
             refs = [str(s) for s in example[ctx.ref_field]]
 
             for i in range(len(text_output)):
-                t = text_output[i]
                 if ctx.output_modality == Modality.SPEECH:
                     assert speech_output is not None
                     u = speech_output.units[i]

+ 0 - 0
scripts/m4t/finetune/README.md → src/seamless_communication/cli/m4t/finetune/README.md


+ 0 - 0
scripts/m4t/finetune/__init__.py → src/seamless_communication/cli/m4t/finetune/__init__.py


+ 0 - 0
scripts/m4t/finetune/dataloader.py → src/seamless_communication/cli/m4t/finetune/dataloader.py


+ 1 - 1
scripts/m4t/finetune/dataset.py → src/seamless_communication/cli/m4t/finetune/dataset.py

@@ -18,7 +18,7 @@ from seamless_communication.datasets.huggingface import (
     Speech2SpeechFleursDatasetBuilder,
     SpeechTokenizer,
 )
-from seamless_communication.models.unit_extraction import UnitExtractor
+from seamless_communication.models.unit_extractor import UnitExtractor
 
 logging.basicConfig(
     level=logging.INFO,

+ 0 - 0
scripts/m4t/finetune/dist_utils.py → src/seamless_communication/cli/m4t/finetune/dist_utils.py


+ 1 - 1
scripts/m4t/finetune/finetune.py → src/seamless_communication/cli/m4t/finetune/finetune.py

@@ -11,8 +11,8 @@ from pathlib import Path
 
 import torch
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
-from m4t_scripts.finetune import dataloader, dist_utils, trainer
 
+from seamless_communication.cli.m4t.finetune import dataloader, dist_utils, trainer
 from seamless_communication.models.unity import (
     UnitTokenizer,
     UnitYModel,

+ 11 - 9
scripts/m4t/finetune/trainer.py → src/seamless_communication/cli/m4t/finetune/trainer.py

@@ -15,12 +15,14 @@ from typing import Optional, Tuple
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.sequence import SequenceModelOutput
+from fairseq2.nn.padding import PaddingMask
 from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.typing import Device
-from m4t_scripts.finetune import dataloader, dist_utils
 from torch.optim import Adam
 
+from seamless_communication.cli.finetune import dataloader, dist_utils
 from seamless_communication.models.unity import UnitYModel
 
 logger = logging.getLogger(__name__)
@@ -136,12 +138,12 @@ class CalcLoss:
     def __init__(
         self,
         label_smoothing: float,
-        s2t_pad_idx: Optional[int],
-        t2u_pad_idx: Optional[int],
+        s2t_vocab_info: VocabularyInfo,
+        t2u_vocab_info: VocabularyInfo,
     ):
         self.label_smoothing = label_smoothing
-        self.s2t_pad_idx = s2t_pad_idx
-        self.t2u_pad_idx = t2u_pad_idx
+        self.s2t_vocab_info = s2t_vocab_info
+        self.t2u_vocab_info = t2u_vocab_info
 
     def __call__(
         self,
@@ -154,7 +156,7 @@ class CalcLoss:
             text_logits.device
         )
         s2t_loss = SequenceModelOutput(
-            logits=text_logits, pad_idx=self.s2t_pad_idx
+            logits=text_logits, vocab_info=self.s2t_vocab_info
         ).compute_loss(
             targets=batch.speech_to_text.target_tokens.to(text_logits.device),
             ignore_prefix_size=1,
@@ -165,7 +167,7 @@ class CalcLoss:
         assert batch.text_to_units.target_lengths is not None
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
         s2u_loss = SequenceModelOutput(
-            logits=unit_logits, pad_idx=self.t2u_pad_idx
+            logits=unit_logits, vocab_info=self.t2u_vocab_info
         ).compute_loss(
             targets=batch.text_to_units.target_tokens.to(unit_logits.device),
             ignore_prefix_size=1,
@@ -227,8 +229,8 @@ class UnitYFinetune:
         assert model.t2u_model is not None
         self.calc_loss = CalcLoss(
             label_smoothing=self.params.label_smoothing,
-            s2t_pad_idx=model.pad_idx,
-            t2u_pad_idx=model.t2u_model.pad_idx,
+            s2t_vocab_info=model.target_vocab_info,
+            t2u_vocab_info=model.t2u_model.target_vocab_info,
         )
         self.model = self._wrap_model_for_trainining(model=model)
         self.train_data_loader = train_data_loader

+ 0 - 0
scripts/m4t/predict/README.md → src/seamless_communication/cli/m4t/predict/README.md


+ 4 - 2
scripts/m4t/predict/__init__.py → src/seamless_communication/cli/m4t/predict/__init__.py

@@ -4,7 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from m4t_scripts.predict.predict import (
+from seamless_communication.cli.m4t.predict.predict import (
     add_inference_arguments as add_inference_arguments,
 )
-from m4t_scripts.predict.predict import set_generation_opts as set_generation_opts
+from seamless_communication.cli.m4t.predict.predict import (
+    set_generation_opts as set_generation_opts,
+)

+ 6 - 9
scripts/m4t/predict/predict.py → src/seamless_communication/cli/m4t/predict/predict.py

@@ -5,17 +5,14 @@
 
 import argparse
 import logging
+from argparse import Namespace
+from typing import Tuple
+
 import torch
 import torchaudio
-
-from argparse import Namespace
 from fairseq2.generation import SequenceGeneratorOptions
-from seamless_communication.models.inference import (
-    NGramRepeatBlockProcessor,
-    Translator,
-)
-from typing import Tuple
 
+from seamless_communication.inference import NGramRepeatBlockProcessor, Translator
 
 logging.basicConfig(
     level=logging.INFO,
@@ -152,7 +149,7 @@ def set_generation_opts(
         ),
     )
     if args.text_generation_ngram_blocking:
-        text_generation_opts.logits_processor = NGramRepeatBlockProcessor(
+        text_generation_opts.step_processor = NGramRepeatBlockProcessor(
             no_repeat_ngram_size=args.no_repeat_ngram_size
         )
 
@@ -164,7 +161,7 @@ def set_generation_opts(
         ),
     )
     if args.unit_generation_ngram_blocking:
-        unit_generation_opts.logits_processor = NGramRepeatBlockProcessor(
+        unit_generation_opts.step_processor = NGramRepeatBlockProcessor(
             no_repeat_ngram_size=args.no_repeat_ngram_size
         )
     return text_generation_opts, unit_generation_opts

+ 0 - 0
src/seamless_communication/cli/m4t/train/__init__.py


+ 3 - 3
scripts/m4t/train/configs.py → src/seamless_communication/cli/m4t/train/configs.py

@@ -4,10 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import yaml
-
 from dataclasses import dataclass
-from typing import Dict, Any, Union, get_origin, get_args, List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Union, get_args, get_origin
+
+import yaml
 
 
 @dataclass

+ 10 - 7
scripts/m4t/train/dataloader.py → src/seamless_communication/cli/m4t/train/dataloader.py

@@ -5,15 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import ctypes
 import logging
 import os
 from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Union
-import ctypes
 
 import torch
-from m4t_scripts.train.configs import AudioProcessingConfig, DataLoadingConfig
-from torch import Tensor
-
 from fairseq2.data import (
     CollateOptionsOverride,
     Collater,
@@ -24,6 +21,12 @@ from fairseq2.data import (
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.text import SentencePieceEncoder, StrSplitter, read_text
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
+from torch import Tensor
+
+from seamless_communication.cli.m4t.train.configs import (
+    AudioProcessingConfig,
+    DataLoadingConfig,
+)
 from seamless_communication.models.tokenizer import SPMTokenizer
 from seamless_communication.models.unity import (
     UnitTokenizer,
@@ -419,15 +422,15 @@ class UnityDataLoader:
             overrides=[
                 CollateOptionsOverride(
                     selector=f"{self.ROOT_COLUMN}.{self.AUDIO_COLUMN_NAME}.data.fbank",
-                    pad_idx=self.config.fbank_feats_pad_idx,
+                    pad_value=self.config.fbank_feats_pad_idx,
                 ),
                 CollateOptionsOverride(
                     selector=f"{self.ROOT_COLUMN}.{self.TARGET_TEXT_COLUMN}",
-                    pad_idx=self.text_tokenizer.vocab_info.pad_idx,
+                    pad_value=self.text_tokenizer.vocab_info.pad_idx,
                 ),
                 CollateOptionsOverride(
                     selector=f"{self.ROOT_COLUMN}.{self.TARGET_UNITS_COLUMN}",
-                    pad_idx=self.unit_tokenizer.vocab_info.pad_idx,
+                    pad_value=self.unit_tokenizer.vocab_info.pad_idx,
                 ),
             ],
         )

+ 0 - 0
scripts/m4t/train/dist_utils.py → src/seamless_communication/cli/m4t/train/dist_utils.py


+ 26 - 17
scripts/m4t/train/model.py → src/seamless_communication/cli/m4t/train/model.py

@@ -7,27 +7,26 @@
 
 import logging
 import os
-from typing import Dict, Any
+from typing import Any, Dict
 
 import torch
-from m4t_scripts.train.configs import CustomModelParams, ModelConfig
+from fairseq2.data import VocabularyInfo
+from fairseq2.models.nllb.builder import NllbConfig
+from fairseq2.models.nllb.loader import NllbLoader
+from fairseq2.models.utils.checkpoint_loader import convert_model_state_dict
+from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig
+from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
+from fairseq2.nn.transformer import TransformerNormOrder
 
+from seamless_communication.cli.m4t.train.configs import CustomModelParams, ModelConfig
 from seamless_communication.models.unity import (
     UnitYConfig,
     UnitYModel,
-    load_unity_model,
+    UnitYT2UConfig,
     create_unity_model,
+    load_unity_model,
 )
-from seamless_communication.models.unity.loader import load_unity_config
-from seamless_communication.models.unity import UnitYT2UConfig
-from fairseq2.nn.transformer import TransformerNormOrder
-from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig
-from fairseq2.models.nllb.builder import NllbConfig
-from fairseq2.models.utils.checkpoint_loader import convert_model_state_dict
-from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
-from seamless_communication.models.unity.loader import UnitYLoader
-
-from fairseq2.models.nllb.loader import NllbLoader
+from seamless_communication.models.unity.loader import UnitYLoader, load_unity_config
 
 logger = logging.getLogger(__name__)
 
@@ -257,8 +256,13 @@ class ModelBuilder:
             mt_model_config=NllbConfig(
                 model_dim=config.model_embed_dim,
                 max_seq_len=1024,
-                vocabulary_size=config.nllb_vocabulary_size,  # num_tokens + langs + spec symbols
-                pad_idx=0,
+                vocab_info=VocabularyInfo(
+                    size=config.nllb_vocabulary_size,
+                    unk_idx=1,
+                    bos_idx=2,
+                    eos_idx=3,
+                    pad_idx=0,
+                ),
                 num_encoder_layers=config.nllb_encoder_layers,
                 num_decoder_layers=config.nllb_decoder_layers,
                 num_encoder_attn_heads=16,
@@ -269,8 +273,13 @@ class ModelBuilder:
             t2u_config=UnitYT2UConfig(
                 model_dim=config.model_embed_dim,
                 unit_max_seq_len=2048,
-                unit_vocabulary_size=config.unit_vocabulary_size,
-                unit_pad_idx=1,
+                target_vocab_info=VocabularyInfo(
+                    size=config.unit_vocabulary_size,
+                    unk_idx=3,
+                    bos_idx=0,
+                    eos_idx=2,
+                    pad_idx=1,
+                ),
                 num_encoder_layers=config.t2u_encoder_layers,
                 num_decoder_layers=config.t2u_decoder_layers,
                 nar_decoder_frontend_config=None,

+ 0 - 0
scripts/m4t/train/recipes/asr_small.yaml → src/seamless_communication/cli/m4t/train/recipes/asr_small.yaml


+ 0 - 0
scripts/m4t/train/recipes/asr_small_wh_transc.yaml → src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_transc.yaml


+ 0 - 0
scripts/m4t/train/recipes/large_M4T_v1.yaml → src/seamless_communication/cli/m4t/train/recipes/large_M4T_v1.yaml


+ 0 - 0
scripts/m4t/train/recipes/m4t_v1_train_manifests.txt → src/seamless_communication/cli/m4t/train/recipes/m4t_v1_train_manifests.txt


+ 6 - 5
scripts/m4t/train/run_training.py → src/seamless_communication/cli/m4t/train/run_training.py

@@ -15,11 +15,12 @@ from typing import List
 
 import torch
 import yaml
-from m4t_scripts.train import dataloader as _dataloader
-from m4t_scripts.train import dist_utils
-from m4t_scripts.train import model as _model
-from m4t_scripts.train import trainer as _trainer
-from m4t_scripts.train.configs import WorkflowParams
+
+from seamless_communication.cli.m4t.train import dataloader as _dataloader
+from seamless_communication.cli.m4t.train import dist_utils
+from seamless_communication.cli.m4t.train import model as _model
+from seamless_communication.cli.m4t.train import trainer as _trainer
+from seamless_communication.cli.m4t.train.configs import WorkflowParams
 
 logging_format = f"%(asctime)s - {platform.node()} - %(process)s - %(levelname)s - %(name)s: %(message)s"
 logging.basicConfig(

+ 0 - 1
scripts/m4t/train/run_with_slurm.py → src/seamless_communication/cli/m4t/train/run_with_slurm.py

@@ -7,7 +7,6 @@ import subprocess
 import time
 from pathlib import Path
 
-
 logging_format = f"%(asctime)s - {platform.node()} - %(process)s - %(levelname)s - %(name)s: %(message)s"
 logging.basicConfig(
     level=logging.INFO,

+ 17 - 13
scripts/m4t/train/trainer.py → src/seamless_communication/cli/m4t/train/trainer.py

@@ -6,21 +6,22 @@
 
 
 import logging
-from typing import Any, Optional, Tuple, Dict, List
-
 import os
 import time
+from typing import Any, Dict, List, Optional, Tuple
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.optim.lr_scheduler import MyleLR
-from m4t_scripts.train import dataloader, dist_utils
 from torch.optim import Adam
 
+from seamless_communication.cli.m4t.train import dataloader, dist_utils
+from seamless_communication.cli.m4t.train.configs import TrainingParams
 from seamless_communication.models.unity import UnitYModel, UnitYT2UModel
-from m4t_scripts.train.configs import TrainingParams
 
 logger = logging.getLogger(__name__)
 
@@ -67,7 +68,10 @@ class UnitYTrainWrapper(nn.Module):
         )
         text_logits = self.model.final_proj(text_decoder_out)
         # t2u
-        (unit_encoder_out, unit_encoder_padding_mask,) = self.t2u.encode(
+        (
+            unit_encoder_out,
+            unit_encoder_padding_mask,
+        ) = self.t2u.encode(
             text_decoder_output=text_decoder_out,
             text_decoder_padding_mask=text_decoder_padding_mask,
         )
@@ -91,13 +95,13 @@ class CalcLoss:
     def __init__(
         self,
         label_smoothing: float,
-        s2t_pad_idx: Optional[int],
-        t2u_pad_idx: Optional[int],
+        s2t_vocab_info: VocabularyInfo,
+        t2u_vocab_info: VocabularyInfo,
         s2t_skip_langtok_loss: bool = False,
     ):
         self.label_smoothing = label_smoothing
-        self.s2t_pad_idx = s2t_pad_idx
-        self.t2u_pad_idx = t2u_pad_idx
+        self.s2t_vocab_info = s2t_vocab_info
+        self.t2u_vocab_info = t2u_vocab_info
         self.s2t_ignore_prefix_size = 1 if s2t_skip_langtok_loss else 0
         self.t2u_ignore_prefix_size = 1
 
@@ -112,7 +116,7 @@ class CalcLoss:
             text_logits.device
         )
         s2t_loss = SequenceModelOutput(
-            logits=text_logits, pad_idx=self.s2t_pad_idx
+            logits=text_logits, vocab_info=self.s2t_vocab_info
         ).compute_loss(
             targets=batch.speech_to_text.target_tokens.to(text_logits.device),
             ignore_prefix_size=self.s2t_ignore_prefix_size,
@@ -121,7 +125,7 @@ class CalcLoss:
         assert batch.text_to_units.target_lengths is not None
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
         s2u_loss = SequenceModelOutput(
-            logits=unit_logits, pad_idx=self.t2u_pad_idx
+            logits=unit_logits, vocab_info=self.t2u_vocab_info
         ).compute_loss(
             targets=batch.text_to_units.target_tokens.to(unit_logits.device),
             ignore_prefix_size=1,
@@ -192,8 +196,8 @@ class UnitYTrainer:
         assert model.t2u_model is not None
         self.calc_loss = CalcLoss(
             label_smoothing=self.params.label_smoothing,
-            s2t_pad_idx=model.pad_idx,
-            t2u_pad_idx=model.t2u_model.pad_idx,
+            s2t_vocab_info=model.target_vocab_info,
+            t2u_vocab_info=model.t2u_model.target_vocab_info,
         )
         self._try_load_checkpoint(model=model)
         self.model = self._wrap_model_for_trainining(model=model)

+ 16 - 0
src/seamless_communication/inference/__init__.py

@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from seamless_communication.inference.generator import UnitYGenerator as UnitYGenerator
+from seamless_communication.inference.ngram_repeat_block_processor import (
+    NGramRepeatBlockProcessor as NGramRepeatBlockProcessor,
+)
+from seamless_communication.inference.translator import (
+    BatchedSpeechOutput as BatchedSpeechOutput,
+)
+from seamless_communication.inference.translator import Modality as Modality
+from seamless_communication.inference.translator import Task as Task
+from seamless_communication.inference.translator import Translator as Translator

+ 5 - 7
src/seamless_communication/models/unity/generator.py → src/seamless_communication/inference/generator.py

@@ -5,12 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple
 
 import torch
-
-from torch import Tensor
-from fairseq2.data import VocabularyInfo
 from fairseq2.data.text import TextTokenizer
 from fairseq2.generation import (
     Seq2SeqGenerator,
@@ -21,11 +18,12 @@ from fairseq2.generation import (
 )
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
 from fairseq2.nn.utils.module import infer_device
+from torch import Tensor
 
 from seamless_communication.models.unity.model import (
     UnitYModel,
-    UnitYX2TModel,
     UnitYT2UModel,
+    UnitYX2TModel,
 )
 from seamless_communication.models.unity.unit_tokenizer import (
     UnitTokenDecoder,
@@ -35,7 +33,7 @@ from seamless_communication.models.unity.unit_tokenizer import (
 
 def remove_consecutive_repeated_ngrams(
     sequence: List[int], min_size: int = 1, max_size: int = 40
-):
+) -> List[int]:
     assert 1 <= min_size <= max_size
     drop_idx = set()  # indices that will be dropped from the sequence
 
@@ -188,7 +186,7 @@ class UnitYGenerator:
             )
         elif input_modality == "text" and self.t2t_generator is None:
             raise ValueError(
-                f"Please set use_text_encoder to True in your model config to encode text."
+                "Please set use_text_encoder to True in your model config to encode text."
             )
         else:
             raise ValueError(f"Unsupported input_modality: {input_modality}")

+ 3 - 2
src/seamless_communication/models/inference/ngram_repeat_block_processor.py → src/seamless_communication/inference/ngram_repeat_block_processor.py

@@ -4,10 +4,11 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from fairseq2.generation import StepProcessor
 from typing import List
-from torch import Tensor
+
 import torch
+from fairseq2.generation import StepProcessor
+from torch import Tensor
 
 
 class NGramRepeatBlockProcessor(StepProcessor):

+ 9 - 9
src/seamless_communication/models/inference/translator.py → src/seamless_communication/inference/translator.py

@@ -3,29 +3,31 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from pathlib import Path
-from torch import Tensor
 from typing import Callable, List, Optional, Tuple, Union, cast
 
-import logging
 import torch
 import torch.nn as nn
-
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater, SequenceData
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.text import TextTokenizer
 from fairseq2.data.typing import StringLike
-from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
+from fairseq2.generation import SequenceGeneratorOptions, SequenceToTextOutput
 from fairseq2.memory import MemoryBlock
 from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
+from torch import Tensor
 
+from seamless_communication.inference.generator import (
+    SequenceToUnitOutput,
+    UnitYGenerator,
+)
 from seamless_communication.models.unity import (
     UnitTokenizer,
-    UnitYGenerator,
     UnitYModel,
     UnitYNART2UModel,
     UnitYT2UModel,
@@ -33,9 +35,7 @@ from seamless_communication.models.unity import (
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
 )
-from seamless_communication.models.unity.generator import SequenceToUnitOutput
-from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
-
+from seamless_communication.models.vocoder import Vocoder, load_vocoder_model
 
 logging.basicConfig(
     level=logging.INFO,
@@ -246,7 +246,7 @@ class Translator(nn.Module):
                     audio = audio.unsqueeze(1)
                 elif audio.dim() == 2 and audio.size(0) < audio.size(1):
                     logger.warning(
-                        f"Transposing audio tensor from (bsz, seq_len) -> (seq_len, bsz)."
+                        "Transposing audio tensor from (bsz, seq_len) -> (seq_len, bsz)."
                     )
                     audio = audio.transpose(0, 1)
 

+ 0 - 14
src/seamless_communication/models/inference/__init__.py

@@ -1,14 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-from seamless_communication.models.inference.ngram_repeat_block_processor import (
-    NGramRepeatBlockProcessor as NGramRepeatBlockProcessor,
-)
-from seamless_communication.models.inference.translator import (
-    BatchedSpeechOutput as BatchedSpeechOutput,
-)
-from seamless_communication.models.inference.translator import Modality as Modality
-from seamless_communication.models.inference.translator import Task as Task
-from seamless_communication.models.inference.translator import Translator as Translator

+ 5 - 5
src/seamless_communication/models/unit_extraction/__init__.py → src/seamless_communication/models/unit_extractor/__init__.py

@@ -4,12 +4,12 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from seamless_communication.models.unit_extraction.unit_extraction import (
-    UnitExtractor as UnitExtractor,
-)
-from seamless_communication.models.unit_extraction.kmeans import (
+from seamless_communication.models.unit_extractor.kmeans import (
     KmeansModel as KmeansModel,
 )
-from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
+from seamless_communication.models.unit_extractor.unit_extractor import (
+    UnitExtractor as UnitExtractor,
+)
+from seamless_communication.models.unit_extractor.wav2vec2_layer_output import (
     Wav2Vec2LayerOutputModel as Wav2Vec2LayerOutputModel,
 )

+ 3 - 3
src/seamless_communication/models/unit_extraction/kmeans.py → src/seamless_communication/models/unit_extractor/kmeans.py

@@ -4,11 +4,11 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import torch
-from torch import Tensor, nn
 import numpy as np
+import torch
+from fairseq2.assets import download_manager
 from fairseq2.typing import Device
-from seamless_communication.assets import download_manager
+from torch import Tensor, nn
 
 
 class KmeansModel(nn.Module):

+ 13 - 16
src/seamless_communication/models/unit_extraction/unit_extraction.py → src/seamless_communication/models/unit_extractor/unit_extractor.py

@@ -4,32 +4,29 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from itertools import groupby
 from pathlib import Path
-from torch import Tensor, nn
-from typing import Union
+from typing import List, Union
 
-import logging
 import torch
 import torch.nn.functional as F
-
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater
 from fairseq2.data.audio import AudioDecoder
 from fairseq2.memory import MemoryBlock
-from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.models.sequence import SequenceBatch
-from fairseq2.models.wav2vec2 import Wav2Vec2Model
+from fairseq2.models.wav2vec2 import Wav2Vec2Model, load_wav2vec2_model
+from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
+from torch import Tensor, nn
 
-from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
-    load_wav2vec2_model,
+from seamless_communication.inference import Translator
+from seamless_communication.models.unit_extractor.kmeans import KmeansModel
+from seamless_communication.models.unit_extractor.wav2vec2_layer_output import (
     Wav2Vec2LayerOutputModel,
 )
-from seamless_communication.models.unit_extraction.kmeans import KmeansModel
-from seamless_communication.models.inference import Translator
-from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
-
+from seamless_communication.models.vocoder import Vocoder, load_vocoder_model
 
 logging.basicConfig(
     level=logging.INFO,
@@ -77,7 +74,7 @@ class UnitExtractor(nn.Module):
                 audio = audio.unsqueeze(1)
             elif audio.dim() == 2 and audio.size(0) < audio.size(1):
                 logger.warning(
-                    f"Transposing audio tensor from (bsz, seq_len) -> (seq_len, bsz)."
+                    "Transposing audio tensor from (bsz, seq_len) -> (seq_len, bsz)."
                 )
                 audio = audio.transpose(0, 1)
 
@@ -93,7 +90,7 @@ class UnitExtractor(nn.Module):
         batch = SequenceBatch(seqs=seqs, padding_mask=padding_mask)
         features = self.model(batch, out_layer_idx).squeeze(0)
         units = self.kmeans_model(features)
-        return units
+        return units  # type: ignore[no-any-return]
 
     @staticmethod
     def resynthesize_audio(
@@ -102,7 +99,7 @@ class UnitExtractor(nn.Module):
         device: Device,
         vocoder_name: str = "vocoder_36langs",
     ) -> Tensor:
-        def reduce_list(lst):
+        def reduce_list(lst: List[Tensor]) -> List[Tensor]:
             return [key for key, _ in groupby(lst)]
 
         reduced_units = reduce_list(units.cpu().tolist())
@@ -112,4 +109,4 @@ class UnitExtractor(nn.Module):
         )
         assert isinstance(vocoder, Vocoder)
         wav = vocoder(reduced_units, src_lang, spkr=-1, dur_prediction=True)
-        return wav
+        return wav  # type: ignore[no-any-return]

+ 10 - 30
src/seamless_communication/models/unit_extraction/wav2vec2_layer_output.py → src/seamless_communication/models/unit_extractor/wav2vec2_layer_output.py

@@ -3,33 +3,21 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-from fairseq2.nn.padding import PaddingMask
-from fairseq2.nn.transformer import TransformerEncoder, TransformerNormOrder
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from fairseq2.models.sequence import SequenceBatch
 from fairseq2.models.wav2vec2 import (
-    Wav2Vec2EncoderConfig,
     Wav2Vec2Config,
-    wav2vec2_arch,
-    Wav2Vec2Model,
-    create_wav2vec2_model,
+    Wav2Vec2EncoderConfig,
     Wav2Vec2Frontend,
+    Wav2Vec2Model,
+    wav2vec2_arch,
 )
-from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
-from fairseq2.models.sequence import SequenceBatch
-
-
-from seamless_communication.assets import asset_store, download_manager
-
-
-import torch
-from typing import Optional
-
+from fairseq2.nn.padding import PaddingMask
+from fairseq2.nn.transformer import TransformerEncoder, TransformerNormOrder
 from torch import Tensor
-import torch.nn as nn
-
-
-wav2vec2_archs = ArchitectureRegistry[Wav2Vec2Config]("wav2vec2")
-wav2vec2_arch = wav2vec2_archs.marker
 
 
 def _encoder_xlsr2_1b_v2() -> Wav2Vec2EncoderConfig:
@@ -87,14 +75,6 @@ def _xlsr2_1b_v2() -> Wav2Vec2Config:
     )
 
 
-load_wav2vec2_model = Wav2Vec2Loader(
-    asset_store,
-    download_manager,
-    create_wav2vec2_model,
-    wav2vec2_archs,
-)
-
-
 class Wav2Vec2LayerOutputModel(nn.Module):
     encoder_frontend: Wav2Vec2Frontend
     encoder: TransformerEncoder

+ 9 - 12
src/seamless_communication/models/unity/__init__.py

@@ -24,10 +24,10 @@ from seamless_communication.models.unity.length_regulator import (
     HardUpsampling as HardUpsampling,
 )
 from seamless_communication.models.unity.length_regulator import (
-    VariancePredictor as VariancePredictor,
+    VarianceAdaptor as VarianceAdaptor,
 )
 from seamless_communication.models.unity.length_regulator import (
-    VarianceAdaptor as VarianceAdaptor,
+    VariancePredictor as VariancePredictor,
 )
 from seamless_communication.models.unity.loader import UnitYLoader as UnitYLoader
 from seamless_communication.models.unity.loader import (
@@ -40,26 +40,26 @@ from seamless_communication.models.unity.loader import (
     load_unity_unit_tokenizer as load_unity_unit_tokenizer,
 )
 from seamless_communication.models.unity.model import UnitYModel as UnitYModel
-from seamless_communication.models.unity.model import UnitYX2TModel as UnitYX2TModel
-from seamless_communication.models.unity.model import UnitYT2UModel as UnitYT2UModel
 from seamless_communication.models.unity.model import (
     UnitYNART2UModel as UnitYNART2UModel,
 )
 from seamless_communication.models.unity.model import UnitYOutput as UnitYOutput
-from seamless_communication.models.unity.nar_decoder_frontend import (
-    NARDecoderFrontend as NARDecoderFrontend,
-)
+from seamless_communication.models.unity.model import UnitYT2UModel as UnitYT2UModel
+from seamless_communication.models.unity.model import UnitYX2TModel as UnitYX2TModel
 from seamless_communication.models.unity.nar_decoder import (
     NARTransformerDecoder as NARTransformerDecoder,
 )
+from seamless_communication.models.unity.nar_decoder_frontend import (
+    NARDecoderFrontend as NARDecoderFrontend,
+)
 from seamless_communication.models.unity.nar_decoder_layer import (
     NARTransformerDecoderLayer as NARTransformerDecoderLayer,
 )
 from seamless_communication.models.unity.t2u_builder import (
-    UnitYT2UBuilder as UnitYT2UBuilder,
+    UnitYNART2UBuilder as UnitYNART2UBuilder,
 )
 from seamless_communication.models.unity.t2u_builder import (
-    UnitYNART2UBuilder as UnitYNART2UBuilder,
+    UnitYT2UBuilder as UnitYT2UBuilder,
 )
 from seamless_communication.models.unity.t2u_builder import (
     UnitYT2UConfig as UnitYT2UConfig,
@@ -82,6 +82,3 @@ from seamless_communication.models.unity.unit_tokenizer import (
 from seamless_communication.models.unity.unit_tokenizer import (
     UnitTokenizer as UnitTokenizer,
 )
-from seamless_communication.models.unity.generator import (
-    UnitYGenerator as UnitYGenerator,
-)

+ 0 - 1
src/seamless_communication/models/unity/adaptor_block.py

@@ -14,7 +14,6 @@ from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.projection import Linear
 from fairseq2.nn.transformer import (
     AttentionMask,
-    EncoderLayerOutputHook,
     FeedForwardNetwork,
     LayerNormFactory,
     MultiheadAttention,

+ 3 - 5
src/seamless_communication/models/unity/builder.py

@@ -5,9 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Union, Optional
+from typing import Optional, Union
 
-from fairseq2.data import VocabularyInfo
 from fairseq2.models.conformer import ConformerBlock, ConformerConvolution
 from fairseq2.models.nllb import NllbBuilder, NllbConfig, nllb_archs
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
@@ -24,7 +23,6 @@ from fairseq2.nn.transformer import (
 )
 from fairseq2.typing import DataType, Device
 
-
 from seamless_communication.models.unity.adaptor_block import (
     UnitYConformerAdaptorLayer,
     UnitYEncoderAdaptor,
@@ -32,15 +30,15 @@ from seamless_communication.models.unity.adaptor_block import (
 )
 from seamless_communication.models.unity.model import UnitYModel
 from seamless_communication.models.unity.t2u_builder import (
-    UnitYT2UBuilder,
     UnitYNART2UBuilder,
+    UnitYT2UBuilder,
     UnitYT2UConfig,
     unity_t2u_archs,
 )
 from seamless_communication.models.wav2vec2_chunk import (
-    wav2vec2_chunk_archs,
     Wav2Vec2ChunkEncoderBuilder,
     Wav2Vec2ChunkEncoderConfig,
+    wav2vec2_chunk_archs,
 )
 
 

+ 6 - 3
src/seamless_communication/models/unity/char_tokenizer.py

@@ -6,7 +6,12 @@
 
 from typing import Optional, Union, final
 
-from fairseq2.assets import AssetStore, AssetDownloadManager, download_manager
+from fairseq2.assets import (
+    AssetDownloadManager,
+    AssetStore,
+    asset_store,
+    download_manager,
+)
 from fairseq2.assets.card import AssetCard
 from fairseq2.data.text import (
     SentencePieceDecoder,
@@ -20,8 +25,6 @@ from fairseq2.data.text import (
 from fairseq2.data.typing import PathLike
 from fairseq2.typing import Device, finaloverride
 
-from seamless_communication.assets import asset_store
-
 
 @final
 class CharTokenizer(TextTokenizer):

+ 5 - 7
src/seamless_communication/models/unity/length_regulator.py

@@ -3,18 +3,16 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import torch
-
-from torch import Tensor
-from torch.nn import Conv1d, Dropout, Module, ReLU, Sequential
-
 from typing import Optional, Tuple
 
-from fairseq2.typing import DataType, Device
-from fairseq2.nn.transformer import create_standard_layer_norm
+import torch
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
 from fairseq2.nn.projection import Linear
+from fairseq2.nn.transformer import create_standard_layer_norm
+from fairseq2.typing import DataType, Device
+from torch import Tensor
+from torch.nn import Conv1d, Dropout, Module, ReLU, Sequential
 
 
 class HardUpsampling(Module):

+ 7 - 8
src/seamless_communication/models/unity/loader.py

@@ -7,10 +7,14 @@
 from typing import Any, Dict, List, Mapping, Union, final
 
 import torch
-from fairseq2.assets import AssetStore, download_manager
+from fairseq2.assets import AssetStore, asset_store, download_manager
 from fairseq2.assets.card import AssetCard
 from fairseq2.models.nllb import NllbConfig
 from fairseq2.models.nllb.loader import NllbTokenizerLoader
+from fairseq2.models.utils.checkpoint_loader import upgrade_fairseq_checkpoint
+from fairseq2.models.utils.model_loader import ModelConfigLoader, ModelLoader
+from overrides import override as finaloverride
+
 from seamless_communication.models.unity.builder import (
     UnitYConfig,
     create_unity_model,
@@ -19,11 +23,6 @@ from seamless_communication.models.unity.builder import (
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
 from seamless_communication.models.unity.model import UnitYModel
 from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer
-from fairseq2.models.utils.checkpoint_loader import upgrade_fairseq_checkpoint
-from fairseq2.models.utils.model_loader import ModelConfigLoader, ModelLoader
-from overrides import override as finaloverride
-
-from seamless_communication.assets import asset_store
 
 
 @final
@@ -71,8 +70,8 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
         keys_to_delete.append(f"{encoder_key}.w2v_encoder.w2v_model.mask_emb")
 
-        keys_to_delete.append(f"decoder.char_upsampler.embed_positions._float_tensor")
-        keys_to_delete.append(f"decoder.char_upsampler.embed_tokens_char.weight")
+        keys_to_delete.append("decoder.char_upsampler.embed_positions._float_tensor")
+        keys_to_delete.append("decoder.char_upsampler.embed_tokens_char.weight")
 
         # Delete AlignmentEncoder keys for inference.
         alignment_encoder_keys = [

+ 4 - 7
src/seamless_communication/models/unity/nar_decoder.py

@@ -6,17 +6,14 @@
 
 from typing import Iterable, Optional, Tuple, final
 
-from torch import Tensor
-from torch.nn import Module
-
 from fairseq2.nn.module_list import ModuleList
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask
-from fairseq2.nn.transformer import (
-    TransformerNormOrder,
-    create_standard_layer_norm,
-)
+from fairseq2.nn.transformer import TransformerNormOrder, create_standard_layer_norm
 from fairseq2.typing import DataType, Device, finaloverride
+from torch import Tensor
+from torch.nn import Module
+
 from seamless_communication.models.unity.nar_decoder_layer import (
     NARTransformerDecoderLayer,
 )

+ 5 - 9
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -4,11 +4,10 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 from typing import List, Optional, Tuple, final
 
-from torch import Tensor
-from torch.nn import Dropout, Module, Parameter
-
+import torch
 from fairseq2.data import VocabularyInfo
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
 from fairseq2.nn.embedding import Embedding
@@ -17,17 +16,14 @@ from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.position_encoder import PositionEncoder
 from fairseq2.nn.transformer import create_standard_layer_norm
 from fairseq2.typing import DataType, Device, finaloverride
+from torch import Tensor
+from torch.nn import Dropout, Module, Parameter
 
-
+from seamless_communication.models.unity.char_tokenizer import CharTokenizer
 from seamless_communication.models.unity.length_regulator import (
     HardUpsampling,
     VarianceAdaptor,
 )
-from seamless_communication.models.unity.char_tokenizer import CharTokenizer
-
-import math
-import torch
-
 
 SPACE = "▁"
 

+ 4 - 5
src/seamless_communication/models/unity/nar_decoder_layer.py

@@ -4,15 +4,14 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Optional, final, Tuple
-
-from torch import Tensor
-from torch.nn import Conv1d, Dropout, Module, ReLU
+from typing import Optional, Tuple, final
 
 from fairseq2.nn.normalization import LayerNorm
-from fairseq2.nn.transformer import MultiheadAttention, create_standard_layer_norm
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
+from fairseq2.nn.transformer import MultiheadAttention, create_standard_layer_norm
 from fairseq2.typing import DataType, Device, finaloverride
+from torch import Tensor
+from torch.nn import Conv1d, Dropout, Module, ReLU
 
 
 @final

+ 14 - 16
src/seamless_communication/models/unity/t2u_builder.py

@@ -6,9 +6,14 @@
 from dataclasses import dataclass
 from typing import Literal, Optional, Union
 
-from fairseq2.assets import download_manager
+from fairseq2.assets import asset_store, download_manager
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import VocabularyInfo
+from fairseq2.models.nllb.loader import NllbTokenizerLoader
+from fairseq2.models.transformer import (
+    TransformerEmbeddingFrontend,
+    TransformerFrontend,
+)
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
@@ -30,25 +35,18 @@ from fairseq2.nn.transformer import (
     create_default_sdpa,
 )
 from fairseq2.typing import DataType, Device
-from fairseq2.models.transformer import (
-    TransformerEmbeddingFrontend,
-    TransformerFrontend,
-)
-from fairseq2.models.nllb.loader import NllbTokenizerLoader
 
-
-from seamless_communication.assets import asset_store
-from seamless_communication.models.unity.nar_decoder import NARTransformerDecoder
-from seamless_communication.models.unity.nar_decoder_layer import (
-    NARTransformerDecoderLayer,
-    Conv1dBlock,
-)
-from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
-from seamless_communication.models.unity.model import UnitYT2UModel, UnitYNART2UModel
 from seamless_communication.models.unity.length_regulator import (
-    VariancePredictor,
     VarianceAdaptor,
+    VariancePredictor,
+)
+from seamless_communication.models.unity.model import UnitYNART2UModel, UnitYT2UModel
+from seamless_communication.models.unity.nar_decoder import NARTransformerDecoder
+from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
+from seamless_communication.models.unity.nar_decoder_layer import (
+    Conv1dBlock,
+    NARTransformerDecoderLayer,
 )
 
 

+ 1 - 1
src/seamless_communication/models/vocoder/codehifigan.py

@@ -9,8 +9,8 @@ import torch
 import torch.nn as nn
 from torch import Tensor
 
-from seamless_communication.models.vocoder.hifigan import Generator
 from seamless_communication.models.unity import VariancePredictor
+from seamless_communication.models.vocoder.hifigan import Generator
 
 
 class CodeGenerator(Generator):

+ 1 - 1
src/seamless_communication/models/vocoder/loader.py

@@ -6,10 +6,10 @@
 
 from typing import Any, Mapping, final
 
+from fairseq2.assets import asset_store, download_manager
 from fairseq2.models.utils.model_loader import ModelLoader
 from overrides import override as finaloverride
 
-from seamless_communication.assets import asset_store, download_manager
 from seamless_communication.models.vocoder.builder import (
     VocoderConfig,
     create_vocoder_model,

+ 3 - 3
src/seamless_communication/models/wav2vec2_chunk/__init__.py

@@ -4,12 +4,12 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from seamless_communication.models.wav2vec2_chunk.builder import (
-    wav2vec2_chunk_archs as wav2vec2_chunk_archs,
-)
 from seamless_communication.models.wav2vec2_chunk.builder import (
     Wav2Vec2ChunkEncoderBuilder as Wav2Vec2ChunkEncoderBuilder,
 )
 from seamless_communication.models.wav2vec2_chunk.builder import (
     Wav2Vec2ChunkEncoderConfig as Wav2Vec2ChunkEncoderConfig,
 )
+from seamless_communication.models.wav2vec2_chunk.builder import (
+    wav2vec2_chunk_archs as wav2vec2_chunk_archs,
+)

+ 2 - 2
src/seamless_communication/models/wav2vec2_chunk/builder.py

@@ -4,16 +4,16 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 from typing import Literal, Optional
 
 from fairseq2.models.conformer import ConformerConvolution
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.w2vbert import w2vbert_archs
 from fairseq2.models.wav2vec2.builder import (
     Wav2Vec2EncoderBuilder,
     Wav2Vec2EncoderConfig,
 )
-from fairseq2.models.w2vbert import w2vbert_archs
 from fairseq2.nn.transformer import SDPA, ShawRelativePositionSDPA
 from fairseq2.typing import DataType, Device
 

+ 2 - 3
src/seamless_communication/models/wav2vec2_chunk/chunk_attention_mask.py

@@ -7,10 +7,9 @@
 from typing import Optional
 
 import torch
-from torch import Tensor
-
-from fairseq2.nn.utils.mask import to_float_mask
 from fairseq2.nn.transformer import AttentionMask, CustomAttentionMask
+from fairseq2.nn.utils.mask import to_float_mask
+from torch import Tensor
 
 
 class ChunkAttentionMaskFactory:

+ 4 - 11
src/seamless_communication/models/wav2vec2_chunk/encoder.py

@@ -6,25 +6,18 @@
 
 from typing import Iterable, Optional, Tuple, final
 
-from torch import Tensor
-from torch.nn import Dropout
-
 from fairseq2.nn.module_list import ModuleList
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask
-
-from fairseq2.nn.transformer import (
-    EncoderLayerOutputHook,
-    TransformerEncoder,
-    TransformerEncoderLayer,
-)
+from fairseq2.nn.transformer import TransformerEncoder, TransformerEncoderLayer
+from fairseq2.typing import finaloverride
+from torch import Tensor
+from torch.nn import Dropout
 
 from seamless_communication.models.wav2vec2_chunk.chunk_attention_mask import (
     ChunkAttentionMaskFactory,
 )
 
-from fairseq2.typing import finaloverride
-
 
 @final
 class ChunkTransformerEncoder(TransformerEncoder):

+ 0 - 0
src/seamless_communication/py.typed


+ 1 - 2
tests/common.py

@@ -8,9 +8,8 @@ from contextlib import contextmanager
 from typing import Any, Generator, List, Union
 
 import torch
-from torch import Tensor
-
 from fairseq2.typing import Device
+from torch import Tensor
 
 # The default device that tests should use. Note that pytest can change it based
 # on the provided command line arguments.

+ 2 - 2
tests/conftest.py

@@ -8,10 +8,10 @@ from argparse import ArgumentTypeError
 from typing import cast
 
 import pytest
-import tests.common
-
 from fairseq2.typing import Device
 
+import tests.common
+
 
 def parse_device_arg(value: str) -> Device:
     try:

+ 0 - 0
tests/integration/inference/__init__.py


+ 3 - 2
tests/integration/models/test_translator.py → tests/integration/inference/test_translator.py

@@ -4,11 +4,12 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import torch
 from typing import Final
 
+import torch
 from fairseq2.typing import Device
-from seamless_communication.models.inference import Translator
+
+from seamless_communication.inference import Translator
 from tests.common import device
 
 # fmt: off

+ 6 - 6
tests/integration/models/test_unit_extraction.py → tests/integration/models/test_unit_extractor.py

@@ -4,22 +4,22 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import torch
-from torch import tensor
 from typing import Final
 
+import torch
 from fairseq2.typing import Device
-from seamless_communication.models.inference import Translator
-from seamless_communication.models.unit_extraction import UnitExtractor
-from tests.common import assert_equal, device
+from torch import tensor
 
+from seamless_communication.inference import Translator
+from seamless_communication.models.unit_extractor import UnitExtractor
+from tests.common import assert_equal, device
 
 # fmt: off
 REF_ENG_UNITS: Final = [8976, 8299, 0, 0, 9692, 5395, 785, 785, 7805, 6193, 2922, 4806, 3362, 3560, 8119, 8119, 4335, 205, 5424, 5424, 5064, 7421, 6547, 9952, 3728, 8544, 3321, 1093, 1443, 7962, 3978, 8063, 5168, 5491, 9133, 9275, 5912, 8729, 5097, 5495, 1650, 5048, 2839, 6756, 5665, 4191, 5205, 5205, 9568, 9568, 5932, 1190, 9339, 5839, 5839, 6244, 5320, 3454, 5216, 721, 6994, 6513, 7754, 3469, 296, 1849, 3254, 3254, 5042, 5042, 3961, 2079, 1907, 1846, 661, 2225, 944, 9295, 4712, 1785, 6060, 8701, 7646, 1355, 2876, 8199, 5901, 8199, 3861, 5153, 6420, 2897, 1389, 334, 6334]
 # fmt: on
 
 
-def test_unit_extraction() -> None:
+def test_unit_extractor() -> None:
     model_name = "seamlessM4T_v2_large"
     english_text = "Hello! I hope you're all doing well."