# test_unit_extractor.py
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from typing import Final
  7. import torch
  8. from torch import tensor
  9. from fairseq2.typing import Device
  10. from seamless_communication.inference import Translator
  11. from seamless_communication.models.unit_extractor import UnitExtractor
  12. from tests.common import assert_equal
  13. # fmt: off
  14. REF_ENG_UNITS: Final = [8976, 8299, 0, 0, 9692, 5395, 785, 785, 7805, 6193, 2922, 4806, 3362, 3560, 8119, 8119, 4335, 205, 5424, 5424, 5064, 7421, 6547, 9952, 3728, 8544, 3321, 1093, 1443, 7962, 3978, 8063, 5168, 5491, 9133, 9275, 5912, 8729, 5097, 5495, 1650, 5048, 2839, 6756, 5665, 4191, 5205, 5205, 9568, 9568, 5932, 1190, 9339, 5839, 5839, 6244, 5320, 3454, 5216, 721, 6994, 6513, 7754, 3469, 296, 1849, 3254, 3254, 5042, 5042, 3961, 2079, 1907, 1846, 661, 2225, 944, 9295, 4712, 1785, 6060, 8701, 7646, 1355, 2876, 8199, 5901, 8199, 3861, 5153, 6420, 2897, 1389, 334, 6334]
  15. # fmt: on
  16. def test_unit_extractor() -> None:
  17. model_name = "seamlessM4T_v2_large"
  18. english_text = "Hello! I hope you're all doing well."
  19. # We can't test on the GPU since the output is non-deterministic.
  20. device = Device("cpu")
  21. dtype = torch.float32
  22. translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)
  23. # Generate english speech for the english text.
  24. _, speech_output = translator.predict(
  25. english_text,
  26. "t2st",
  27. "eng",
  28. src_lang="eng",
  29. )
  30. assert speech_output is not None
  31. unit_extractor = UnitExtractor(
  32. "xlsr2_1b_v2",
  33. "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
  34. device=device,
  35. dtype=torch.float32,
  36. )
  37. units = unit_extractor.predict(speech_output.audio_wavs[0][0], 34)
  38. assert_equal(units, tensor(REF_ENG_UNITS, device=device, dtype=torch.int64))