test_unit_extractor.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from typing import Final
  7. import torch
  8. from fairseq2.typing import Device
  9. from torch import tensor
  10. from seamless_communication.inference import Translator
  11. from seamless_communication.models.unit_extractor import UnitExtractor
  12. from tests.common import assert_equal, device, get_default_dtype
  13. # fmt: off
  14. REF_ENG_UNITS: Final = [8976, 8299, 0, 0, 9692, 5395, 785, 785, 7805, 6193, 2922, 4806, 3362, 3560, 8119, 8119, 4335, 205, 5424, 5424, 5064, 7421, 6547, 9952, 3728, 8544, 3321, 1093, 1443, 7962, 3978, 8063, 5168, 5491, 9133, 9275, 5912, 8729, 5097, 5495, 1650, 5048, 2839, 6756, 5665, 4191, 5205, 5205, 9568, 9568, 5932, 1190, 9339, 5839, 5839, 6244, 5320, 3454, 5216, 721, 6994, 6513, 7754, 3469, 296, 1849, 3254, 3254, 5042, 5042, 3961, 2079, 1907, 1846, 661, 2225, 944, 9295, 4712, 1785, 6060, 8701, 7646, 1355, 2876, 8199, 5901, 8199, 3861, 5153, 6420, 2897, 1389, 334, 6334]
  15. # fmt: on
  16. def test_unit_extractor() -> None:
  17. model_name = "seamlessM4T_v2_large"
  18. english_text = "Hello! I hope you're all doing well."
  19. dtype = get_default_dtype()
  20. translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)
  21. unit_extractor = UnitExtractor(
  22. "xlsr2_1b_v2",
  23. "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
  24. device=device,
  25. )
  26. # Generate english speech for the english text.
  27. _, speech_output = translator.predict(
  28. english_text,
  29. "t2st",
  30. "eng",
  31. src_lang="eng",
  32. )
  33. assert speech_output is not None
  34. units = unit_extractor.predict(speech_output.audio_wavs[0][0], 34)
  35. assert_equal(units, tensor(REF_ENG_UNITS, device=device, dtype=torch.int64))