소스 검색

Translator integration tests and unity unittests. (#63)

* Add translator integ tests, unity unit tests.

* Delete s2t_m4t_v2, x2t_m4t_v2 since we can derive them from seamlessM4T_large_v2.

* Fix bugs in unit_tokenizer, extend unittests for nar_decoder cases.

* Add integ test for unit_extraction.

* Add conftest to be able to run tests on a GPU with pytest --device, specify device in tokenizer.create_encoder() calls in unit_tokenizer tests.
Kaushik Ram Sadagopan 1 년 전
부모
커밋
f4dffda0f8

+ 0 - 10
src/seamless_communication/assets/cards/s2t_m4t_v2.yaml

@@ -1,10 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-name: s2t_m4t_v2
-base: unity_nllb-100
-model_arch: s2t_base_v2
-checkpoint: "file://large_experiments/seamless/ust/elbayadm/multitasking_models/m4t_v2_s2t.pt"

+ 0 - 10
src/seamless_communication/assets/cards/x2t_m4t_v2.yaml

@@ -1,10 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-name: x2t_m4t_v2
-base: unity_nllb-100
-model_arch: x2t_base_v2
-checkpoint: "file://large_experiments/seamless/ust/elbayadm/multitasking_models/m4t_v2_x2t.pt"

+ 1 - 1
src/seamless_communication/models/inference/__init__.py

@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 from seamless_communication.models.inference.ngram_repeat_block_processor import (
-    NGramRepeatBlockProcessor,
+    NGramRepeatBlockProcessor as NGramRepeatBlockProcessor,
 )
 from seamless_communication.models.inference.translator import (
     BatchedSpeechOutput as BatchedSpeechOutput,

+ 20 - 1
src/seamless_communication/models/inference/translator.py

@@ -9,6 +9,7 @@ from pathlib import Path
 from torch import Tensor
 from typing import Callable, List, Optional, Tuple, Union, cast
 
+import logging
 import torch
 import torch.nn as nn
 
@@ -22,7 +23,6 @@ from fairseq2.memory import MemoryBlock
 from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 
-
 from seamless_communication.models.unity import (
     UnitTokenizer,
     UnitYGenerator,
@@ -37,6 +37,14 @@ from seamless_communication.models.unity.generator import SequenceToUnitOutput
 from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
 
 
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
 class Task(Enum):
     S2ST = auto()
     S2TT = auto()
@@ -231,6 +239,17 @@ class Translator(nn.Module):
                     block = MemoryBlock(fb.read())
                 decoded_audio = self.decode_audio(block)
             else:
+                assert (
+                    audio.dim() <= 2
+                ), "The audio tensor can't be more than 2 dimensions."
+                if audio.dim() == 1:
+                    audio = audio.unsqueeze(1)
+                elif audio.dim() == 2 and audio.size(0) < audio.size(1):
+                    logger.warning(
+                        f"Transposing audio tensor from (bsz, seq_len) -> (seq_len, bsz)."
+                    )
+                    audio = audio.transpose(0, 1)
+
                 decoded_audio = {
                     "waveform": audio,
                     "sample_rate": sample_rate,

+ 22 - 4
src/seamless_communication/models/unit_extraction/unit_extraction.py

@@ -4,14 +4,15 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Union
+from itertools import groupby
 from pathlib import Path
+from torch import Tensor, nn
+from typing import Union
+
+import logging
 import torch
 import torch.nn.functional as F
 
-from itertools import groupby
-from torch import Tensor, nn
-
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater
 from fairseq2.data.audio import AudioDecoder
@@ -30,6 +31,14 @@ from seamless_communication.models.inference import Translator
 from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
 
 
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
 class UnitExtractor(nn.Module):
     """Unit Extractor which converts raw audio into units."""
 
@@ -63,6 +72,15 @@ class UnitExtractor(nn.Module):
                 block = MemoryBlock(fb.read())
             decoded_audio = self.decode_audio(block)
         else:
+            assert audio.dim() <= 2, "The audio tensor can't be more than 2 dimensions."
+            if audio.dim() == 1:
+                audio = audio.unsqueeze(1)
+            elif audio.dim() == 2 and audio.size(0) < audio.size(1):
+                logger.warning(
+                    f"Transposing audio tensor from (bsz, seq_len) -> (seq_len, bsz)."
+                )
+                audio = audio.transpose(0, 1)
+
             decoded_audio = {
                 "waveform": audio,
                 "sample_rate": sample_rate,

+ 0 - 50
src/seamless_communication/models/unity/builder.py

@@ -166,56 +166,6 @@ def _base_v2() -> UnitYConfig:
     )
 
 
-@unity_arch("x2t_base_v2")
-def _x2t_base_v2() -> UnitYConfig:
-    w2v2_chunk_encoder_config = wav2vec2_chunk_archs.get_config("600m")
-
-    mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
-
-    mt_model_config.vocabulary_size = 256102  # NLLB-100
-
-    mt_model_config.max_seq_len = 4096
-
-    return UnitYConfig(
-        model_dim=1024,
-        w2v2_encoder_config=w2v2_chunk_encoder_config,
-        mt_model_config=mt_model_config,
-        t2u_config=None,
-        use_text_encoder=True,
-        use_conformer_adaptor=False,
-        num_adaptor_layers=1,
-        adaptor_kernel_size=8,
-        adaptor_stride=8,
-        adaptor_layer_norm=True,
-        adaptor_dropout_p=0.0,
-    )
-
-
-@unity_arch("s2t_base_v2")
-def _s2t_base_v2() -> UnitYConfig:
-    w2v2_chunk_encoder_config = wav2vec2_chunk_archs.get_config("600m")
-
-    mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
-
-    mt_model_config.vocabulary_size = 256102  # NLLB-100
-
-    mt_model_config.max_seq_len = 4096
-
-    return UnitYConfig(
-        model_dim=1024,
-        w2v2_encoder_config=w2v2_chunk_encoder_config,
-        mt_model_config=mt_model_config,
-        t2u_config=None,
-        use_text_encoder=False,
-        use_conformer_adaptor=False,
-        num_adaptor_layers=1,
-        adaptor_kernel_size=8,
-        adaptor_stride=8,
-        adaptor_layer_norm=True,
-        adaptor_dropout_p=0.0,
-    )
-
-
 class UnitYBuilder:
     """Builds modules of a UnitY model.
 

+ 30 - 17
src/seamless_communication/models/unity/unit_tokenizer.py

@@ -57,9 +57,9 @@ class UnitTokenizer:
         try:
             return (
                 self.num_units
-                + (self.lang_symbol_repititions - 1) * len(self.langs)
+                + (self.lang_symbol_repititions - 1) * (len(self.langs) + 1)
                 + self.lang_map[lang]
-                + 5
+                + 4
             )
         except KeyError:
             langs = ", ".join(self.langs)
@@ -73,8 +73,8 @@ class UnitTokenizer:
         relative_idx = (
             idx
             - self.num_units
-            - (self.lang_symbol_repititions - 1) * len(self.langs)
-            - 5
+            - (self.lang_symbol_repititions - 1) * (len(self.langs) + 1)
+            - 4
         )
 
         if relative_idx < 0 or relative_idx >= len(self.langs):
@@ -92,7 +92,7 @@ class UnitTokenizer:
         :param lang:
             The language of generated token indices.
         """
-        return UnitTokenEncoder(self, lang, device)
+        return UnitTokenEncoder(self, lang, self.is_nar_decoder, device=device)
 
     def create_decoder(self) -> "UnitTokenDecoder":
         """Create a token decoder."""
@@ -106,10 +106,14 @@ class UnitTokenEncoder:
     eos_idx: int
     unk_idx: int
     lang_idx: int
-    prefix_indices: Tensor
+    prefix_indices: Optional[Tensor]
 
     def __init__(
-        self, tokenizer: UnitTokenizer, lang: str, device: Optional[Device] = None
+        self,
+        tokenizer: UnitTokenizer,
+        lang: str,
+        is_nar_decoder: bool,
+        device: Optional[Device] = None,
     ) -> None:
         """
         :param tokenizer:
@@ -125,6 +129,7 @@ class UnitTokenEncoder:
             )
 
         self.tokenizer = tokenizer
+        self.is_nar_decoder = is_nar_decoder
 
         assert tokenizer.vocab_info.eos_idx is not None
         assert tokenizer.vocab_info.unk_idx is not None
@@ -137,10 +142,13 @@ class UnitTokenEncoder:
         if device is None:
             device = Device("cpu")
 
-        # We always start sequences with EOS, followed by the language token.
-        self.prefix_indices = torch.tensor(
-            [self.eos_idx, self.lang_idx], device=device, dtype=torch.int64
-        )
+        if not self.is_nar_decoder:
+            # We always start sequences with EOS, followed by the language token.
+            self.prefix_indices = torch.tensor(
+                [self.eos_idx, self.lang_idx], device=device, dtype=torch.int64
+            )
+        else:
+            self.prefix_indices = None
 
     def __call__(self, units: Tensor) -> Tensor:
         """Encode ``units`` to token indices.
@@ -156,13 +164,18 @@ class UnitTokenEncoder:
         """
         batch_size = units.size(0)
 
-        token_indices = torch.cat(
-            [self.prefix_indices.clone().expand(batch_size, -1), units.detach()], dim=1
-        )
+        if self.prefix_indices is not None:
+            token_indices = torch.cat(
+                [self.prefix_indices.clone().expand(batch_size, -1), units.detach()],
+                dim=1,
+            )
 
-        # Ensure that non-symbol indices larger than `num_units` are replaced
-        # with UNK.
-        seqs = token_indices[:, 2:]
+            # Ensure that non-symbol indices larger than `num_units` are replaced
+            # with UNK.
+            seqs = token_indices[:, 2:]
+        else:
+            token_indices = units.clone().detach()
+            seqs = token_indices
 
         # Add offset for control symbols.
         seqs += 4

+ 9 - 0
tests/__init__.py

@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+
+pytest.register_assert_rewrite("tests.common")

+ 62 - 0
tests/common.py

@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from contextlib import contextmanager
+from typing import Any, Generator, List, Union
+
+import torch
+from torch import Tensor
+
+from fairseq2.typing import Device
+
+# The default device that tests should use. Note that pytest can change it based
+# on the provided command line arguments.
+device = Device("cpu")
+
+
+def assert_close(a: Tensor, b: Union[Tensor, List[Any]]) -> None:
+    """Assert that ``a`` and ``b`` are element-wise equal within a tolerance."""
+    if not isinstance(b, Tensor):
+        b = torch.tensor(b, device=device, dtype=a.dtype)
+
+    torch.testing.assert_close(a, b)  # type: ignore[attr-defined]
+
+
+def assert_equal(a: Tensor, b: Union[Tensor, List[Any]]) -> None:
+    """Assert that ``a`` and ``b`` are element-wise equal."""
+    if not isinstance(b, Tensor):
+        b = torch.tensor(b, device=device, dtype=a.dtype)
+
+    torch.testing.assert_close(a, b, rtol=0, atol=0)  # type: ignore[attr-defined]
+
+
+def has_no_inf(a: Tensor) -> bool:
+    """Return ``True`` if ``a`` has no positive or negative infinite element."""
+    return not torch.any(torch.isinf(a))
+
+
+def has_no_nan(a: Tensor) -> bool:
+    """Return ``True`` if ``a`` has no NaN element."""
+    return not torch.any(torch.isnan(a))
+
+
+@contextmanager
+def tmp_rng_seed(device: Device, seed: int = 0) -> Generator[None, None, None]:
+    """Set a temporary manual RNG seed.
+
+    The RNG is reset to its original state once the block is exited.
+    """
+    device = Device(device)
+
+    if device.type == "cuda":
+        devices = [device]
+    else:
+        devices = []
+
+    with torch.random.fork_rng(devices):
+        torch.manual_seed(seed)
+
+        yield

+ 33 - 0
tests/conftest.py

@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from argparse import ArgumentTypeError
+from typing import cast
+
+import pytest
+import tests.common
+
+from fairseq2.typing import Device
+
+
+def parse_device_arg(value: str) -> Device:
+    try:
+        return Device(value)
+    except RuntimeError:
+        raise ArgumentTypeError(f"'{value}' is not a valid device name.")
+
+
+def pytest_addoption(parser: pytest.Parser) -> None:
+    # fmt: off
+    parser.addoption(
+        "--device", default="cpu", type=parse_device_arg,
+        help="device on which to run tests (default: %(default)s)",
+    )
+    # fmt: on
+
+
+def pytest_sessionstart(session: pytest.Session) -> None:
+    tests.common.device = cast(Device, session.config.getoption("device"))

+ 5 - 0
tests/integration/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.

+ 5 - 0
tests/integration/models/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.

+ 104 - 0
tests/integration/models/test_translator.py

@@ -0,0 +1,104 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from typing import Final
+
+from fairseq2.typing import Device
+from seamless_communication.models.inference import Translator
+from tests.common import device
+
+# fmt: off
+ENG_SENTENCE:     Final = "On Monday, scientists from the Stanford University School of Medicine announced the invention of a new diagnostic tool that can sort cells by type: a tiny printable chip that can be manufactured using standard inkjet printers for possibly about one U.S. cent each."
+DEU_SENTENCE:     Final = "Am Montag kündigten Wissenschaftler der Stanford University School of Medicine die Erfindung eines neuen Diagnosewerkzeugs an, das Zellen nach Typ sortieren kann: ein winziger druckbarer Chip, der mit Standard-Tintenstrahldruckern für etwa einen US-Cent hergestellt werden kann."
+DEU_SENTENCE_V2:  Final = "Am Montag kündigten Wissenschaftler der Stanford University School of Medicine die Erfindung eines neuen diagnostischen Werkzeugs an, das Zellen nach Typ sortieren kann: ein winziger druckbarer Chip, der mit Standard-Tintenstrahldrucker für möglicherweise etwa einen US-Cent pro Stück hergestellt werden kann."
+# fmt: on
+
+
+def test_seamless_m4t_large_t2tt() -> None:
+    model_name = "seamlessM4T_large"
+    src_lang = "eng"
+    tgt_lang = "deu"
+
+    if device == Device("cpu"):
+        dtype = torch.float32
+    else:
+        dtype = torch.float16
+
+    translator = Translator(model_name, "vocoder_36langs", device, dtype=dtype)
+    text_output, _ = translator.predict(
+        ENG_SENTENCE,
+        "t2tt",
+        tgt_lang,
+        src_lang=src_lang,
+    )
+    assert text_output[0] == DEU_SENTENCE, f"'{text_output[0]}' is not '{DEU_SENTENCE}'"
+
+
+def test_seamless_m4t_v2_large_t2tt() -> None:
+    model_name = "seamlessM4T_v2_large"
+    src_lang = "eng"
+    tgt_lang = "deu"
+
+    if device == Device("cpu"):
+        dtype = torch.float32
+    else:
+        dtype = torch.float16
+
+    translator = Translator(model_name, "vocoder_commercial", device, dtype=dtype)
+    text_output, _ = translator.predict(
+        ENG_SENTENCE,
+        "t2tt",
+        tgt_lang,
+        src_lang=src_lang,
+    )
+    assert (
+        text_output[0] == DEU_SENTENCE_V2
+    ), f"'{text_output[0]}' is not '{DEU_SENTENCE_V2}'"
+
+
+def test_seamless_m4t_v2_large_multiple_tasks() -> None:
+    model_name = "seamlessM4T_v2_large"
+    english_text = "Hello! I hope you're all doing well."
+    ref_spanish_text = "Hola, espero que todos estéis haciendo bien."
+    ref_spanish_asr_text = "Hola, espero que todos estéis haciendo bien."
+
+    if device == Device("cpu"):
+        dtype = torch.float32
+    else:
+        dtype = torch.float16
+
+    translator = Translator(model_name, "vocoder_commercial", device, dtype=dtype)
+
+    # Generate english speech for the english text.
+    _, english_speech_output = translator.predict(
+        english_text,
+        "t2st",
+        "eng",
+        src_lang="eng",
+    )
+    assert english_speech_output is not None
+
+    # Translate english speech to spanish speech.
+    spanish_text_output, spanish_speech_output = translator.predict(
+        english_speech_output.audio_wavs[0][0],
+        "s2st",
+        "spa",
+    )
+    assert spanish_speech_output is not None
+    assert (
+        spanish_text_output[0] == ref_spanish_text
+    ), f"'{spanish_text_output[0]}' is not '{ref_spanish_text}'"
+
+    # Run ASR on the spanish speech.
+    spanish_asr_text_output, _ = translator.predict(
+        spanish_speech_output.audio_wavs[0][0],
+        "asr",
+        "spa",
+    )
+    assert (
+        spanish_asr_text_output[0] == ref_spanish_asr_text
+    ), f"{spanish_asr_text_output[0]} is not {ref_spanish_asr_text}'"

+ 49 - 0
tests/integration/models/test_unit_extraction.py

@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import tensor
+from typing import Final
+
+from fairseq2.typing import Device
+from seamless_communication.models.inference import Translator
+from seamless_communication.models.unit_extraction import UnitExtractor
+from tests.common import assert_equal, device
+
+
+# fmt: off
+REF_ENG_UNITS: Final = [8976, 8299,    0,    0, 9692, 5395,  785,  785, 7805, 6193, 2922, 4806, 3362, 3560, 9007, 8119, 8119,  205, 5424, 5424, 5064, 7421, 6547, 9952, 3728, 8544, 3321, 1093, 1443, 7962, 3978, 9631, 5168, 5491, 9133, 9275, 5912, 8729, 5097, 5495, 1650, 5048, 3752, 6756,  963, 5665, 4191, 5205, 5205, 9568, 5092, 5932, 1190, 9339, 5839, 5839, 6244, 5320, 3454, 5216, 721, 6994, 6513, 7754, 3469,  296, 1849, 3254, 3254, 5042, 5042, 3961, 2079, 1907, 1846,  661, 2225,  944, 9295, 4712, 1785, 6060, 8701, 7646, 1355, 2876, 8199, 5901, 8199, 3861, 5153, 6420, 2897, 1389,  334, 6334]
+# fmt: on
+
+
+def test_unit_extraction() -> None:
+    model_name = "seamlessM4T_v2_large"
+    english_text = "Hello! I hope you're all doing well."
+
+    if device == Device("cpu"):
+        dtype = torch.float32
+    else:
+        dtype = torch.float16
+
+    translator = Translator(model_name, "vocoder_commercial", device, dtype=dtype)
+    unit_extractor = UnitExtractor(
+        "xlsr2_1b_v2",
+        "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
+        device=device,
+    )
+
+    # Generate english speech for the english text.
+    _, speech_output = translator.predict(
+        english_text,
+        "t2st",
+        "eng",
+        src_lang="eng",
+    )
+    assert speech_output is not None
+
+    units = unit_extractor.predict(speech_output.audio_wavs[0][0], 34)
+
+    assert_equal(units, tensor(REF_ENG_UNITS, device=device, dtype=torch.int64))

+ 5 - 0
tests/unit/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.

+ 5 - 0
tests/unit/models/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.

+ 5 - 0
tests/unit/models/unity/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.

+ 238 - 0
tests/unit/models/unity/test_unity.py

@@ -0,0 +1,238 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+
+from seamless_communication.models.unity import UnitTokenizer
+from tests.common import assert_equal, device
+
+
+class TestUnitTokenizer:
+    def test_init_works(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        assert tokenizer.num_units == 100
+
+        assert tokenizer.lang_map == {"eng": 0, "deu": 1, "fra": 2}
+
+        assert tokenizer.vocab_info.size == 112
+
+    def test_lang_to_index_works(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        assert tokenizer.lang_to_index("eng") == 108
+        assert tokenizer.lang_to_index("deu") == 109
+        assert tokenizer.lang_to_index("fra") == 110
+
+    def test_lang_to_index_works_nar_decoder(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100,
+            langs=["eng", "deu", "fra"],
+            model_arch="seamlessM4T_large_v2",
+        )
+        assert tokenizer.vocab_info.size == 108
+
+        assert tokenizer.lang_to_index("eng") == 104
+        assert tokenizer.lang_to_index("deu") == 105
+        assert tokenizer.lang_to_index("fra") == 106
+
+    def test_lang_to_index_raises_error_when_lang_is_not_supported(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r"^`lang` must be one of the supported languages, but is 'foo' instead\. Supported languages: eng, deu, fra$",
+        ):
+            tokenizer.lang_to_index("foo")
+
+    def test_index_to_lang_works(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        assert tokenizer.index_to_lang(108) == "eng"
+        assert tokenizer.index_to_lang(109) == "deu"
+        assert tokenizer.index_to_lang(110) == "fra"
+
+    def test_index_to_lang_works_nar_decoder(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100,
+            langs=["eng", "deu", "fra"],
+            model_arch="seamlessM4T_large_v2",
+        )
+
+        assert tokenizer.index_to_lang(104) == "eng"
+        assert tokenizer.index_to_lang(105) == "deu"
+        assert tokenizer.index_to_lang(106) == "fra"
+
+    def test_vocab_control_symbols(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        assert tokenizer.vocab_info.bos_idx == 0
+        assert tokenizer.vocab_info.pad_idx == 1
+        assert tokenizer.vocab_info.eos_idx == 2
+        assert tokenizer.vocab_info.unk_idx == 3
+
+    def test_index_to_lang_raises_error_when_idx_is_out_of_range(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r"^`idx` must correspond to one of the supported language symbol indices \(0 to 2\), but is 1234 instead\.$",
+        ):
+            tokenizer.index_to_lang(1234)
+
+
+class TestUnitEncoder:
+    def test_init_raises_error_when_lang_is_not_supported(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r"^`lang` must be one of the supported languages\, but is 'xyz' instead\. Supported languages: eng, deu, fra$",
+        ):
+            tokenizer.create_encoder(lang="xyz", device=device)
+
+    def test_call_works(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        prefix = torch.tensor([2, 109], device=device, dtype=torch.int64)
+
+        encoder = tokenizer.create_encoder(lang="deu", device=device)
+
+        # Empty units.
+        units = torch.ones((1, 0), device=device, dtype=torch.int64)
+
+        assert_equal(encoder(units), prefix.expand(1, -1))
+
+        # Batched units.
+        units = torch.ones((6, 4), device=device, dtype=torch.int64)
+
+        assert_equal(
+            encoder(units), torch.cat([prefix.expand(6, -1), units + 4], dim=1)
+        )
+
+    def test_call_works_nar_decoder(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100,
+            langs=["eng", "deu", "fra"],
+            model_arch="seamlessM4T_large_v2",
+        )
+
+        encoder = tokenizer.create_encoder(lang="deu", device=device)
+
+        # Empty units.
+        units = torch.ones((1, 0), device=device, dtype=torch.int64)
+
+        assert_equal(encoder(units), units)
+
+        # Batched units.
+        units = torch.ones((6, 4), device=device, dtype=torch.int64)
+
+        assert_equal(encoder(units), units + 4)
+
+    def test_call_works_when_units_have_unks(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        encoder = tokenizer.create_encoder(lang="deu", device=device)
+
+        units = torch.ones((6, 4), device=device, dtype=torch.int64)
+
+        units[1, 3] = 100
+        units[2, 1] = 101
+
+        token_indices = encoder(units)
+
+        assert token_indices[1, 5].item() == tokenizer.vocab_info.unk_idx
+        assert token_indices[2, 3].item() == tokenizer.vocab_info.unk_idx
+
+    def test_call_works_when_units_have_unks_nar_decoder(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100,
+            langs=["eng", "deu", "fra"],
+            model_arch="seamlessM4T_large_v2",
+        )
+
+        encoder = tokenizer.create_encoder(lang="deu", device=device)
+
+        units = torch.ones((6, 4), device=device, dtype=torch.int64)
+
+        units[1, 3] = 100
+        units[2, 1] = 101
+
+        token_indices = encoder(units)
+
+        assert token_indices[1, 3].item() == tokenizer.vocab_info.unk_idx
+        assert token_indices[2, 1].item() == tokenizer.vocab_info.unk_idx
+
+
+class TestUnitDecoder:
+    def test_call_works(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        encoder = tokenizer.create_encoder(lang="deu", device=device)
+        decoder = tokenizer.create_decoder()
+
+        assert tokenizer.vocab_info.eos_idx is not None
+        assert tokenizer.vocab_info.pad_idx is not None
+
+        units1 = torch.ones((6, 4), device=device, dtype=torch.int64)
+
+        encoded_units = encoder(units1)
+
+        encoded_units[2, 2] = tokenizer.vocab_info.eos_idx
+
+        units2 = decoder(encoded_units)
+
+        units1[2, 2] = tokenizer.vocab_info.pad_idx
+
+        prefix = torch.tensor([109], device=device, dtype=torch.int64)
+
+        assert_equal(torch.cat([prefix.expand(6, -1), units1], dim=1), units2)
+
+    def test_call_works_nar_decoder(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100,
+            langs=["eng", "deu", "fra"],
+            model_arch="seamlessM4T_large_v2",
+        )
+
+        encoder = tokenizer.create_encoder(lang="deu", device=device)
+        decoder = tokenizer.create_decoder()
+
+        assert tokenizer.vocab_info.eos_idx is not None
+        assert tokenizer.vocab_info.pad_idx is not None
+
+        units1 = torch.ones((6, 4), device=device, dtype=torch.int64)
+
+        encoded_units = encoder(units1)
+
+        encoded_units[2, 2] = tokenizer.vocab_info.eos_idx
+
+        units2 = decoder(encoded_units)
+
+        units1[2, 2] = tokenizer.vocab_info.pad_idx
+
+        assert_equal(units1, units2)