# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from seamless_communication.models.unity import UnitTokenizer
from tests.common import assert_equal, device
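
# These tests pin down the UnitTokenizer vocabulary layout as exercised by the
# assertions below: the four control symbols (bos, pad, eos, unk) occupy
# indices 0-3, raw unit IDs are offset past them, and the language tokens sit
# near the end of the vocabulary.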


class TestUnitTokenizer:
    def test_init_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        assert tokenizer.num_units == 100
        assert tokenizer.lang_map == {"eng": 0, "deu": 1, "fra": 2}
        assert tokenizer.vocab_info.size == 112

    def test_lang_to_index_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        assert tokenizer.lang_to_index("eng") == 108
        assert tokenizer.lang_to_index("deu") == 109
        assert tokenizer.lang_to_index("fra") == 110

    def test_lang_to_index_works_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        assert tokenizer.vocab_info.size == 108
        assert tokenizer.lang_to_index("eng") == 104
        assert tokenizer.lang_to_index("deu") == 105
        assert tokenizer.lang_to_index("fra") == 106

    def test_lang_to_index_raises_error_when_lang_is_not_supported(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        with pytest.raises(
            ValueError,
            match=r"^`lang` must be one of the supported languages, but is 'foo' instead\. Supported languages: eng, deu, fra$",
        ):
            tokenizer.lang_to_index("foo")

    def test_index_to_lang_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        assert tokenizer.index_to_lang(108) == "eng"
        assert tokenizer.index_to_lang(109) == "deu"
        assert tokenizer.index_to_lang(110) == "fra"

    def test_index_to_lang_works_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        assert tokenizer.index_to_lang(104) == "eng"
        assert tokenizer.index_to_lang(105) == "deu"
        assert tokenizer.index_to_lang(106) == "fra"

    def test_vocab_control_symbols(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        assert tokenizer.vocab_info.bos_idx == 0
        assert tokenizer.vocab_info.pad_idx == 1
        assert tokenizer.vocab_info.eos_idx == 2
        assert tokenizer.vocab_info.unk_idx == 3

    def test_index_to_lang_raises_error_when_idx_is_out_of_range(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        with pytest.raises(
            ValueError,
            match=r"^`idx` must correspond to one of the supported language symbol indices \(0 to 2\), but is 1234 instead\.$",
        ):
            tokenizer.index_to_lang(1234)


class TestUnitEncoder:
    def test_init_raises_error_when_lang_is_not_supported(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        with pytest.raises(
            ValueError,
            match=r"^`lang` must be one of the supported languages, but is 'xyz' instead\. Supported languages: eng, deu, fra$",
        ):
            tokenizer.create_encoder(lang="xyz", device=device)

    def test_call_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        # The encoder prefixes every sequence with eos (2) followed by the
        # language token of "deu" (109).
        prefix = torch.tensor([2, 109], device=device, dtype=torch.int64)

        encoder = tokenizer.create_encoder(lang="deu", device=device)

        # Empty units.
        units = torch.ones((1, 0), device=device, dtype=torch.int64)

        assert_equal(encoder(units), prefix.expand(1, -1))

        # Batched units. Unit IDs are shifted by 4 to skip the control symbols.
        units = torch.ones((6, 4), device=device, dtype=torch.int64)

        assert_equal(
            encoder(units), torch.cat([prefix.expand(6, -1), units + 4], dim=1)
        )

    def test_call_works_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)

        # Empty units.
        units = torch.ones((1, 0), device=device, dtype=torch.int64)

        assert_equal(encoder(units), units)

        # Batched units. The NAR variant adds no eos/lang prefix; only the
        # control-symbol offset is applied.
        units = torch.ones((6, 4), device=device, dtype=torch.int64)

        assert_equal(encoder(units), units + 4)

    def test_call_works_when_units_have_unks(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)

        units = torch.ones((6, 4), device=device, dtype=torch.int64)

        # Unit IDs outside the valid range (>= num_units) must map to unk.
        units[1, 3] = 100
        units[2, 1] = 101

        token_indices = encoder(units)

        # Positions are shifted by the two-token [eos, lang] prefix.
        assert token_indices[1, 5].item() == tokenizer.vocab_info.unk_idx
        assert token_indices[2, 3].item() == tokenizer.vocab_info.unk_idx

    def test_call_works_when_units_have_unks_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)

        units = torch.ones((6, 4), device=device, dtype=torch.int64)

        units[1, 3] = 100
        units[2, 1] = 101

        token_indices = encoder(units)

        # No prefix here, so positions line up with the raw units.
        assert token_indices[1, 3].item() == tokenizer.vocab_info.unk_idx
        assert token_indices[2, 1].item() == tokenizer.vocab_info.unk_idx


class TestUnitDecoder:
    def test_call_works(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100, langs=["eng", "deu", "fra"], model_arch="seamlessM4T_large"
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)
        decoder = tokenizer.create_decoder()

        assert tokenizer.vocab_info.eos_idx is not None
        assert tokenizer.vocab_info.pad_idx is not None

        units1 = torch.ones((6, 4), device=device, dtype=torch.int64)

        encoded_units = encoder(units1)

        # Simulate an early end-of-sequence in one row.
        encoded_units[2, 2] = tokenizer.vocab_info.eos_idx

        units2 = decoder(encoded_units)

        # The decoder maps eos back to pad and strips the leading eos, but
        # keeps the language token (109) at the front of each row.
        units1[2, 2] = tokenizer.vocab_info.pad_idx

        prefix = torch.tensor([109], device=device, dtype=torch.int64)

        assert_equal(torch.cat([prefix.expand(6, -1), units1], dim=1), units2)

    def test_call_works_nar_decoder(self) -> None:
        tokenizer = UnitTokenizer(
            num_units=100,
            langs=["eng", "deu", "fra"],
            model_arch="seamlessM4T_large_v2",
        )

        encoder = tokenizer.create_encoder(lang="deu", device=device)
        decoder = tokenizer.create_decoder()

        assert tokenizer.vocab_info.eos_idx is not None
        assert tokenizer.vocab_info.pad_idx is not None

        units1 = torch.ones((6, 4), device=device, dtype=torch.int64)

        encoded_units = encoder(units1)

        encoded_units[2, 2] = tokenizer.vocab_info.eos_idx

        units2 = decoder(encoded_units)

        # Without a prefix, the round trip only maps eos back to pad.
        units1[2, 2] = tokenizer.vocab_info.pad_idx

        assert_equal(units1, units2)