# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

import pytest
import torch

from seamless_communication.models.unity import UnitTokenizer

from tests.common import assert_equal, device
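
# Note: these tests pin down the unit vocabulary layout (inferred from the
# expected values asserted below, not from the implementation itself): four
# control symbols (bos=0, pad=1, eos=2, unk=3) come first, then the discrete
# units offset by 4, then the language symbols.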


class TestUnitTokenizer:
    def test_init_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        assert tokenizer.num_units == 100
        assert tokenizer.lang_map == {"eng": 0, "deu": 1, "fra": 2}
        assert tokenizer.vocab_info.size == 112
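
    # In the v1 layout the language symbols start at index 108 rather than at
    # 104 (100 units + 4 control symbols), so the vocabulary appears to
    # reserve extra entries around them; hence the size of 112 asserted above.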
    def test_lang_to_index_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        assert tokenizer.lang_to_index("eng") == 108
        assert tokenizer.lang_to_index("deu") == 109
        assert tokenizer.lang_to_index("fra") == 110
    def test_lang_to_index_works_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        assert tokenizer.vocab_info.size == 108
        assert tokenizer.lang_to_index("eng") == 104
        assert tokenizer.lang_to_index("deu") == 105
        assert tokenizer.lang_to_index("fra") == 106

    def test_lang_to_index_raises_error_when_lang_is_not_supported(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        with pytest.raises(
            ValueError,
            match=r"^`lang` must be one of the supported languages, but is 'foo' instead\. Supported languages: eng, deu, fra$",
        ):
            tokenizer.lang_to_index("foo")

    def test_index_to_lang_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        assert tokenizer.index_to_lang(108) == "eng"
        assert tokenizer.index_to_lang(109) == "deu"
        assert tokenizer.index_to_lang(110) == "fra"

    def test_index_to_lang_works_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        assert tokenizer.index_to_lang(104) == "eng"
        assert tokenizer.index_to_lang(105) == "deu"
        assert tokenizer.index_to_lang(106) == "fra"

    def test_vocab_control_symbols(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        assert tokenizer.vocab_info.bos_idx == 0
        assert tokenizer.vocab_info.pad_idx == 1
        assert tokenizer.vocab_info.eos_idx == 2
        assert tokenizer.vocab_info.unk_idx == 3

    def test_index_to_lang_raises_error_when_idx_is_out_of_range(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        with pytest.raises(
            ValueError,
            match=r"^`idx` must correspond to one of the supported language symbol indices \(0 to 2\), but is 1234 instead\.$",
        ):
            tokenizer.index_to_lang(1234)


class TestUnitEncoder:
    def test_init_raises_error_when_lang_is_not_supported(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        with pytest.raises(
            ValueError,
            match=r"^`lang` must be one of the supported languages, but is 'xyz' instead\. Supported languages: eng, deu, fra$",
        ):
            tokenizer.create_encoder(lang="xyz", device=device)
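
    # The v1 encoder prefixes each sequence with <eos> and the language
    # symbol (here 2 and 109 for "deu") and shifts all unit IDs past the
    # 4 control symbols.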
    def test_call_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        # <eos> followed by the "deu" language symbol.
        prefix = torch.tensor([2, 109], device=device, dtype=torch.int64)

        encoder = tokenizer.create_encoder(lang="deu", device=device)

        # Empty units: only the prefix is returned.
        units = torch.ones((1, 0), device=device, dtype=torch.int64)
        assert_equal(encoder(units), prefix.expand(1, -1))

        # Batched units: the prefix followed by the offset unit IDs.
        units = torch.ones((6, 4), device=device, dtype=torch.int64)
        assert_equal(
            encoder(units), torch.cat([prefix.expand(6, -1), units + 4], dim=1)
        )
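
    # The v2 (NAR decoder) encoder emits no prefix; it only applies the
    # control-symbol offset of 4 to the unit IDs.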
    def test_call_works_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)

        # Empty units: the input is returned unchanged.
        units = torch.ones((1, 0), device=device, dtype=torch.int64)
        assert_equal(encoder(units), units)

        # Batched units: only the offset is applied.
        units = torch.ones((6, 4), device=device, dtype=torch.int64)
        assert_equal(encoder(units), units + 4)
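
    # Unit IDs at or above num_units cannot be represented and are encoded
    # as <unk>. In the v1 case the checked positions are shifted by the
    # 2-token prefix.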
    def test_call_works_when_units_have_unks(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)

        units = torch.ones((6, 4), device=device, dtype=torch.int64)

        # Both IDs are outside the valid range of 0-99.
        units[1, 3] = 100
        units[2, 1] = 101

        token_indices = encoder(units)

        assert token_indices[1, 5].item() == tokenizer.vocab_info.unk_idx
        assert token_indices[2, 3].item() == tokenizer.vocab_info.unk_idx

    def test_call_works_when_units_have_unks_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)

        units = torch.ones((6, 4), device=device, dtype=torch.int64)
        units[1, 3] = 100
        units[2, 1] = 101

        token_indices = encoder(units)

        # Without a prefix, the <unk> positions match the input positions.
        assert token_indices[1, 3].item() == tokenizer.vocab_info.unk_idx
        assert token_indices[2, 1].item() == tokenizer.vocab_info.unk_idx


class TestUnitDecoder:
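    # Decoding reverses the encoding: the unit offset is removed and any
    # <eos> inside the sequence comes back as <pad>. The v1 decoder keeps
    # the language symbol as a one-token prefix.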
    def test_call_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)
        decoder = tokenizer.create_decoder()

        assert tokenizer.vocab_info.eos_idx is not None
        assert tokenizer.vocab_info.pad_idx is not None

        units1 = torch.ones((6, 4), device=device, dtype=torch.int64)

        encoded_units = encoder(units1)

        # Inject an <eos> into the middle of one sequence.
        encoded_units[2, 2] = tokenizer.vocab_info.eos_idx

        units2 = decoder(encoded_units)

        # The injected <eos> must come back as <pad>.
        units1[2, 2] = tokenizer.vocab_info.pad_idx

        prefix = torch.tensor([109], device=device, dtype=torch.int64)

        assert_equal(torch.cat([prefix.expand(6, -1), units1], dim=1), units2)

    def test_call_works_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)
        decoder = tokenizer.create_decoder()

        assert tokenizer.vocab_info.eos_idx is not None
        assert tokenizer.vocab_info.pad_idx is not None

        units1 = torch.ones((6, 4), device=device, dtype=torch.int64)

        encoded_units = encoder(units1)
        encoded_units[2, 2] = tokenizer.vocab_info.eos_idx

        units2 = decoder(encoded_units)

        units1[2, 2] = tokenizer.vocab_info.pad_idx

        # Without a language prefix, the round trip reproduces the input.
        assert_equal(units1, units2)