@@ -0,0 +1,238 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+
+from seamless_communication.models.unity import UnitTokenizer
+from tests.common import assert_equal, device
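+
+# Note: the expected indices in these tests follow the vocabulary layout
+# implied by the assertions themselves (an inference from this file, not a
+# documented spec): 4 control symbols (bos=0, pad=1, eos=2, unk=3), then the
+# units offset by 4, then the language symbols. For "seamlessM4T_large" the
+# language block appears to be stored twice (each copy with one extra symbol,
+# presumably a mask), giving 4 + 100 + 2 * (3 + 1) = 112; for
+# "seamlessM4T_large_v2" it is stored once, giving 4 + 100 + (3 + 1) = 108.
+
+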
+class TestUnitTokenizer:
+    def test_init_works(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        assert tokenizer.num_units == 100
+
+        assert tokenizer.lang_map == {"eng": 0, "deu": 1, "fra": 2}
+
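+        # 112 = 4 control symbols + 100 units + 2 * (3 langs + 1); see the
+        # layout note at the top of the file.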
+        assert tokenizer.vocab_info.size == 112
+
+    def test_lang_to_index_works(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
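+        # With the duplicated language block, "eng" starts at 4 + 100 + 4 = 108.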
+        assert tokenizer.lang_to_index("eng") == 108
+        assert tokenizer.lang_to_index("deu") == 109
+        assert tokenizer.lang_to_index("fra") == 110
+
+    def test_lang_to_index_works_nar_decoder(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100,
+            langs=["eng", "deu", "fra"],
+            model_arch="seamlessM4T_large_v2",
+        )
+        assert tokenizer.vocab_info.size == 108
+
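+        # With a single language block, "eng" sits right after the units:
+        # 4 + 100 = 104.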
+        assert tokenizer.lang_to_index("eng") == 104
+        assert tokenizer.lang_to_index("deu") == 105
+        assert tokenizer.lang_to_index("fra") == 106
+
+    def test_lang_to_index_raises_error_when_lang_is_not_supported(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r"^`lang` must be one of the supported languages, but is 'foo' instead\. Supported languages: eng, deu, fra$",
+        ):
+            tokenizer.lang_to_index("foo")
+
+    def test_index_to_lang_works(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        assert tokenizer.index_to_lang(108) == "eng"
+        assert tokenizer.index_to_lang(109) == "deu"
+        assert tokenizer.index_to_lang(110) == "fra"
+
+    def test_index_to_lang_works_nar_decoder(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100,
+            langs=["eng", "deu", "fra"],
+            model_arch="seamlessM4T_large_v2",
+        )
+
+        assert tokenizer.index_to_lang(104) == "eng"
+        assert tokenizer.index_to_lang(105) == "deu"
+        assert tokenizer.index_to_lang(106) == "fra"
+
+    def test_vocab_control_symbols(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        assert tokenizer.vocab_info.bos_idx == 0
+        assert tokenizer.vocab_info.pad_idx == 1
+        assert tokenizer.vocab_info.eos_idx == 2
+        assert tokenizer.vocab_info.unk_idx == 3
+
+    def test_index_to_lang_raises_error_when_idx_is_out_of_range(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r"^`idx` must correspond to one of the supported language symbol indices \(0 to 2\), but is 1234 instead\.$",
+        ):
+            tokenizer.index_to_lang(1234)
+
+
+class TestUnitEncoder:
+    def test_init_raises_error_when_lang_is_not_supported(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r"^`lang` must be one of the supported languages, but is 'xyz' instead\. Supported languages: eng, deu, fra$",
+        ):
+            tokenizer.create_encoder(lang="xyz", device=device)
+
+    def test_call_works(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
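+        # Expected prefix is [eos, lang]: eos_idx == 2 and, under the layout
+        # note above, lang_to_index("deu") == 109.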
+        prefix = torch.tensor([2, 109], device=device, dtype=torch.int64)
+
+        encoder = tokenizer.create_encoder(lang="deu", device=device)
+
+        # Empty units.
+        units = torch.ones((1, 0), device=device, dtype=torch.int64)
+
+        assert_equal(encoder(units), prefix.expand(1, -1))
+
+        # Batched units.
+        units = torch.ones((6, 4), device=device, dtype=torch.int64)
+
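+        # The units themselves are shifted past the 4 control symbols.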
+        assert_equal(
+            encoder(units), torch.cat([prefix.expand(6, -1), units + 4], dim=1)
+        )
+
+    def test_call_works_nar_decoder(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100,
+            langs=["eng", "deu", "fra"],
+            model_arch="seamlessM4T_large_v2",
+        )
+
+        encoder = tokenizer.create_encoder(lang="deu", device=device)
+
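+        # The v2 (NAR) encoder apparently prepends no [eos, lang] prefix; it
+        # only offsets the units.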
+        # Empty units.
+        units = torch.ones((1, 0), device=device, dtype=torch.int64)
+
+        assert_equal(encoder(units), units)
+
+        # Batched units.
+        units = torch.ones((6, 4), device=device, dtype=torch.int64)
+
+        assert_equal(encoder(units), units + 4)
+
+    def test_call_works_when_units_have_unks(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        encoder = tokenizer.create_encoder(lang="deu", device=device)
+
+        units = torch.ones((6, 4), device=device, dtype=torch.int64)
+
+        units[1, 3] = 100
+        units[2, 1] = 101
+
+        token_indices = encoder(units)
+
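+        # Units outside [0, num_units) should come back as unk; the column
+        # indices shift by 2 for the [eos, lang] prefix.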
+        assert token_indices[1, 5].item() == tokenizer.vocab_info.unk_idx
+        assert token_indices[2, 3].item() == tokenizer.vocab_info.unk_idx
+
+    def test_call_works_when_units_have_unks_nar_decoder(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100,
+            langs=["eng", "deu", "fra"],
+            model_arch="seamlessM4T_large_v2",
+        )
+
+        encoder = tokenizer.create_encoder(lang="deu", device=device)
+
+        units = torch.ones((6, 4), device=device, dtype=torch.int64)
+
+        units[1, 3] = 100
+        units[2, 1] = 101
+
+        token_indices = encoder(units)
+
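+        # No prefix under v2, so the unk positions are unchanged.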
+        assert token_indices[1, 3].item() == tokenizer.vocab_info.unk_idx
+        assert token_indices[2, 1].item() == tokenizer.vocab_info.unk_idx
+
+
+class TestUnitDecoder:
+    def test_call_works(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
+        )
+
+        encoder = tokenizer.create_encoder(lang="deu", device=device)
+        decoder = tokenizer.create_decoder()
+
+        assert tokenizer.vocab_info.eos_idx is not None
+        assert tokenizer.vocab_info.pad_idx is not None
+
+        units1 = torch.ones((6, 4), device=device, dtype=torch.int64)
+
+        encoded_units = encoder(units1)
+
+        encoded_units[2, 2] = tokenizer.vocab_info.eos_idx
+
+        units2 = decoder(encoded_units)
+
+        units1[2, 2] = tokenizer.vocab_info.pad_idx
+
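+        # Round-trip expectation: the leading eos is dropped, the lang symbol
+        # (109) is kept, and the eos written above comes back as pad.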
+        prefix = torch.tensor([109], device=device, dtype=torch.int64)
+
+        assert_equal(torch.cat([prefix.expand(6, -1), units1], dim=1), units2)
+
+    def test_call_works_nar_decoder(self) -> None:
+        tokenizer = UnitTokenizer(
+            num_units=100,
+            langs=["eng", "deu", "fra"],
+            model_arch="seamlessM4T_large_v2",
+        )
+
+        encoder = tokenizer.create_encoder(lang="deu", device=device)
+        decoder = tokenizer.create_decoder()
+
+        assert tokenizer.vocab_info.eos_idx is not None
+        assert tokenizer.vocab_info.pad_idx is not None
+
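+        # With no prefix under v2, decoding should give back exactly the
+        # units, with the written eos returned as pad.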
+        units1 = torch.ones((6, 4), device=device, dtype=torch.int64)
+
+        encoded_units = encoder(units1)
+
+        encoded_units[2, 2] = tokenizer.vocab_info.eos_idx
+
+        units2 = decoder(encoded_units)
+
+        units1[2, 2] = tokenizer.vocab_info.pad_idx
+
+        assert_equal(units1, units2)