# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from seamless_communication.models.unity import UnitTokenizer
from tests.common import assert_equal, device


class TestUnitTokenizer:
    def test_init_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        assert tokenizer.num_units == 100
        assert tokenizer.lang_map == {"eng": 0, "deu": 1, "fra": 2}
        assert tokenizer.vocab_info.size == 112

    def test_lang_to_index_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        assert tokenizer.lang_to_index("eng") == 108
        assert tokenizer.lang_to_index("deu") == 109
        assert tokenizer.lang_to_index("fra") == 110

    def test_lang_to_index_works_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        assert tokenizer.vocab_info.size == 108

        assert tokenizer.lang_to_index("eng") == 104
        assert tokenizer.lang_to_index("deu") == 105
        assert tokenizer.lang_to_index("fra") == 106

    def test_lang_to_index_raises_error_when_lang_is_not_supported(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        with pytest.raises(
            ValueError,
            match=r"^`lang` must be one of the supported languages, but is 'foo' instead\. Supported languages: eng, deu, fra$",
        ):
            tokenizer.lang_to_index("foo")

    def test_index_to_lang_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        assert tokenizer.index_to_lang(108) == "eng"
        assert tokenizer.index_to_lang(109) == "deu"
        assert tokenizer.index_to_lang(110) == "fra"

    def test_index_to_lang_works_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        assert tokenizer.index_to_lang(104) == "eng"
        assert tokenizer.index_to_lang(105) == "deu"
        assert tokenizer.index_to_lang(106) == "fra"

    def test_vocab_control_symbols(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        assert tokenizer.vocab_info.bos_idx == 0
        assert tokenizer.vocab_info.pad_idx == 1
        assert tokenizer.vocab_info.eos_idx == 2
        assert tokenizer.vocab_info.unk_idx == 3

    def test_index_to_lang_raises_error_when_idx_is_out_of_range(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        with pytest.raises(
            ValueError,
            match=r"^`idx` must correspond to one of the supported language symbol indices \(0 to 2\), but is 1234 instead\.$",
        ):
            tokenizer.index_to_lang(1234)


class TestUnitEncoder:
    def test_init_raises_error_when_lang_is_not_supported(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        with pytest.raises(
            ValueError,
            match=r"^`lang` must be one of the supported languages, but is 'xyz' instead\. Supported languages: eng, deu, fra$",
        ):
            tokenizer.create_encoder(lang="xyz", device=device)

    def test_call_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        prefix = torch.tensor([2, 109], device=device, dtype=torch.int64)

        encoder = tokenizer.create_encoder(lang="deu", device=device)

        # Empty units.
        units = torch.ones((1, 0), device=device, dtype=torch.int64)

        assert_equal(encoder(units), prefix.expand(1, -1))

        # Batched units.
        units = torch.ones((6, 4), device=device, dtype=torch.int64)

        assert_equal(
            encoder(units), torch.cat([prefix.expand(6, -1), units + 4], dim=1)
        )

    def test_call_works_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)

        # Empty units.
        units = torch.ones((1, 0), device=device, dtype=torch.int64)

        assert_equal(encoder(units), units)

        # Batched units.
        units = torch.ones((6, 4), device=device, dtype=torch.int64)

        assert_equal(encoder(units), units + 4)
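
    # The two unk tests below follow from the layout asserted above: control
    # symbols occupy indices 0-3 and unit ids are shifted by +4, so ids at or
    # above `num_units` have no unit slot and are mapped to `unk_idx`.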
    def test_call_works_when_units_have_unks(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)

        units = torch.ones((6, 4), device=device, dtype=torch.int64)

        units[1, 3] = 100
        units[2, 1] = 101

        token_indices = encoder(units)

        assert token_indices[1, 5].item() == tokenizer.vocab_info.unk_idx
        assert token_indices[2, 3].item() == tokenizer.vocab_info.unk_idx

    def test_call_works_when_units_have_unks_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)

        units = torch.ones((6, 4), device=device, dtype=torch.int64)

        units[1, 3] = 100
        units[2, 1] = 101

        token_indices = encoder(units)

        assert token_indices[1, 3].item() == tokenizer.vocab_info.unk_idx
        assert token_indices[2, 1].item() == tokenizer.vocab_info.unk_idx


class TestUnitDecoder:
    def test_call_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)
        decoder = tokenizer.create_decoder()

        assert tokenizer.vocab_info.eos_idx is not None
        assert tokenizer.vocab_info.pad_idx is not None

        units1 = torch.ones((6, 4), device=device, dtype=torch.int64)

        encoded_units = encoder(units1)

        encoded_units[2, 2] = tokenizer.vocab_info.eos_idx

        units2 = decoder(encoded_units)

        units1[2, 2] = tokenizer.vocab_info.pad_idx

        prefix = torch.tensor([109], device=device, dtype=torch.int64)

        assert_equal(torch.cat([prefix.expand(6, -1), units1], dim=1), units2)

    def test_call_works_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)
        decoder = tokenizer.create_decoder()

        assert tokenizer.vocab_info.eos_idx is not None
        assert tokenizer.vocab_info.pad_idx is not None

        units1 = torch.ones((6, 4), device=device, dtype=torch.int64)

        encoded_units = encoder(units1)

        encoded_units[2, 2] = tokenizer.vocab_info.eos_idx

        units2 = decoder(encoded_units)

        units1[2, 2] = tokenizer.vocab_info.pad_idx

        assert_equal(units1, units2)
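

# A minimal round-trip sketch of the API exercised above, for reference. It is
# illustrative rather than part of the test suite (the leading underscore keeps
# pytest from collecting it) and assumes only the behavior the tests assert:
# the encoder emits `[eos, lang, units + 4]`, and the decoder drops the eos
# symbol and the +4 offset again while keeping the language column.
def _roundtrip_sketch() -> None:
    tokenizer = UnitTokenizer(
        num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
    )

    encoder = tokenizer.create_encoder(lang="eng", device=device)
    decoder = tokenizer.create_decoder()

    # Raw unit ids in [0, num_units).
    units = torch.randint(0, 100, (2, 8), device=device, dtype=torch.int64)

    decoded = decoder(encoder(units))

    # Column 0 is the language symbol; the remaining columns recover the
    # original units.
    assert decoded[0, 0].item() == tokenizer.lang_to_index("eng")
    assert_equal(decoded[:, 1:], units)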