test_unit_extraction.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import torch
  7. from torch import tensor
  8. from typing import Final
  9. from fairseq2.typing import Device
  10. from seamless_communication.models.inference import Translator
  11. from seamless_communication.models.unit_extraction import UnitExtractor
  12. from tests.common import assert_equal, device
# fmt: off
# Golden reference: the expected sequence of discrete speech units
# (k-means cluster IDs from the 10k-codebook referenced in the test below)
# for the English test utterance. Kept on one line under fmt: off/on so the
# formatter does not explode the literal.
REF_ENG_UNITS: Final = [8976, 8299, 0, 0, 9692, 5395, 785, 785, 7805, 6193, 2922, 4806, 3362, 3560, 9007, 8119, 8119, 205, 5424, 5424, 5064, 7421, 6547, 9952, 3728, 8544, 3321, 1093, 1443, 7962, 3978, 9631, 5168, 5491, 9133, 9275, 5912, 8729, 5097, 5495, 1650, 5048, 3752, 6756, 963, 5665, 4191, 5205, 5205, 9568, 5092, 5932, 1190, 9339, 5839, 5839, 6244, 5320, 3454, 5216, 721, 6994, 6513, 7754, 3469, 296, 1849, 3254, 3254, 5042, 5042, 3961, 2079, 1907, 1846, 661, 2225, 944, 9295, 4712, 1785, 6060, 8701, 7646, 1355, 2876, 8199, 5901, 8199, 3861, 5153, 6420, 2897, 1389, 334, 6334]
# fmt: on
  16. def test_unit_extraction() -> None:
  17. model_name = "seamlessM4T_v2_large"
  18. english_text = "Hello! I hope you're all doing well."
  19. if device == Device("cpu"):
  20. dtype = torch.float32
  21. else:
  22. dtype = torch.float16
  23. translator = Translator(model_name, "vocoder_commercial", device, dtype=dtype)
  24. unit_extractor = UnitExtractor(
  25. "xlsr2_1b_v2",
  26. "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
  27. device=device,
  28. )
  29. # Generate english speech for the english text.
  30. _, speech_output = translator.predict(
  31. english_text,
  32. "t2st",
  33. "eng",
  34. src_lang="eng",
  35. )
  36. assert speech_output is not None
  37. units = unit_extractor.predict(speech_output.audio_wavs[0][0], 34)
  38. assert_equal(units, tensor(REF_ENG_UNITS, device=device, dtype=torch.int64))