test_expressivity.py 5.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from typing import Final
  7. import torch
  8. from fairseq2.data.audio import AudioDecoderOutput
  9. from torch import tensor
  10. from seamless_communication.inference import Translator
  11. from seamless_communication.inference.pretssel_generator import PretsselGenerator
  12. from seamless_communication.models.unit_extractor import UnitExtractor
  13. from tests.common import (
  14. assert_unit_close,
  15. convert_to_collated_fbank,
  16. device,
  17. )
  18. # fmt: off
  19. REF_UNITS: Final = [8976, 6589, 6589, 5736, 7542, 6515, 1240, 8335, 2381, 1076, 1076, 3380, 4085, 8207, 7957, 4446, 2641, 2544, 5552, 5529, 6319, 2779, 2890, 2890, 3229, 3303, 9751, 1979, 664, 1859, 1302, 528, 1303, 9543, 5770, 3532, 1286, 1286, 1727, 9287, 5248, 5586, 594, 3385, 2613, 1717, 7529, 7634, 931, 1602, 4512, 850, 2748, 5056, 1086, 2320, 2320, 9320, 3223, 5592, 1122, 419, 24, 4126, 5200, 2712, 9549, 8676, 8676, 3443, 7598, 7598, 2200, 2745, 1215, 118, 3840, 2703, 1616, 8788, 1240, 3349, 4890, 2756, 166, 9574, 9773, 5887, 2516, 9332, 6092, 3377, 4334, 3127, 3127, 3127, 944, 3089, 5947, 6572, 6572, 7561, 4358, 4358, 4358, 8124, 5549, 9275, 82, 8830, 8830, 5949, 22, 6729, 6878, 3817, 1871, 6092, 1441, 3127, 3928, 8254, 7984, 1116, 2796, 1806, 3710, 797, 9269, 576, 576, 2020, 137, 6624, 3815, 8690, 3634, 6036, 3530, 8719, 3458, 138, 8745, 5233, 2235, 8580, 8580, 6831, 2709, 7136, 9693, 3437, 3437, 3238, 4368, 2321, 2321, 391, 391, 4976, 8622, 6722, 3864, 9113, 9113, 7222, 7222, 7937, 999, 1286, 1286, 7789, 9396, 9603, 6690, 5233, 2235, 618, 8830, 6954, 3668, 4302, 596, 1934, 2886, 2704, 9097, 4161, 458, 4147, 9245, 9245, 3127, 3127, 944, 9676, 9676, 3468, 270, 270, 4608, 5549, 4182, 102, 8568, 1286, 1286, 5087, 817, 4153, 207, 207, 3763, 6415, 5188, 6010, 554, 753, 9953, 5104, 3828, 1879, 995, 9683, 6932, 3644, 2683, 9335, 183, 5525, 7023, 9568, 6222, 6315, 676, 3443, 6971, 2084, 999, 1286, 1286, 9620, 9620, 1048, 5577, 9328, 4963, 1364, 8328, 4573, 4573, 7917, 7917, 560, 2020, 4923, 137, 9542, 5832, 9775, 4780, 9400, 2745, 2745, 8984, 628, 8834, 6932, 3817, 8312, 5393, 458, 4147, 9191, 2225, 2759, 8980, 2351, 193, 1476, 9347, 3063, 2076, 3641, 1614, 9832, 3554, 8197, 5589, 5589, 7306, 184, 1708, 2954, 2954, 3485, 3485, 7665, 8909, 5405, 3590, 3590, 3446, 6442, 6442, 2802, 5549, 3791]
  20. REF_WAVE_EXTRACTED_UNITS: Final = [8976, 2066, 3800, 2357, 2357, 8080, 9479, 2181, 311, 7241, 5301, 9666, 9925, 940, 9479, 9479, 9479, 3151, 9666, 9925, 2937, 9479, 9479, 3043, 9666, 9189, 9189, 4821, 2937, 2357, 9479, 9479, 9666, 9666, 9666, 9666, 9666, 9666, 9479, 1369, 247, 5025, 5574, 940, 2937, 9479, 9479, 9666, 9666, 9666, 9666, 9666, 5025, 9666, 9666, 9666, 9666, 9666, 9925, 9666, 9479, 9479, 9666, 9666, 9479, 9666, 9479, 9666, 1589, 9666, 9362, 940, 2937, 2937, 9479, 9479, 8063, 9666, 9925, 2937, 9479, 9479, 9666, 9666, 9666, 2130, 4978, 1589, 5574, 5574, 9925, 2937, 9479, 515, 2379, 9666, 9666, 9666, 1589, 4978, 9532, 225, 225, 225, 1251, 225, 3978, 3800, 6343, 1840, 8080, 9666, 9479, 5514, 9666, 6606, 940, 2937, 9479, 9479, 9479, 9666, 9666, 9666, 9666, 9479, 9479, 9666, 9666, 9666, 940, 8080, 9479, 9479, 9479, 9666, 9666, 9666, 9479, 9479, 515, 247, 5025, 5574, 940, 9536, 9479, 9479, 9666, 9666, 9666, 9666, 9666, 1369, 9666, 1653, 4978, 530, 1589, 5574, 940, 940, 9479, 9479, 9479, 9666, 9666, 9925, 9666, 2937, 9479, 8770, 515, 9666, 2130, 5574, 5574, 940, 2937, 9479, 9479, 9479, 8770, 1369, 9580, 1589, 1589, 5574, 940, 2937, 9479, 5634, 9666, 9479, 9202, 1351, 8193, 4660, 4660, 4660, 1463, 1251, 2130, 5574, 1840, 2937, 9479, 9479, 515, 2066, 1653, 7962, 530, 1589, 9666, 940, 940, 9479, 9479, 9666, 9666, 9479, 515, 515, 2720, 8819, 530, 9666, 8063, 940, 2937, 9666, 9666, 9666, 9479, 9666, 9666, 2379, 9925, 2937, 9479, 9479, 1351, 8193, 1589, 9666, 1589, 5574, 940, 2937, 9479, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9479, 1369, 7810, 9666, 1589, 8063, 5574, 940, 940, 2937, 9479, 9479, 9666, 9479, 311, 1369, 1589, 1589, 5574, 940, 2937, 9479, 7848, 3511, 1589, 1795, 5574, 940, 940, 5786, 2003, 8857, 8193, 8193, 1653, 979, 8471, 8471, 1275, 1885, 225, 225, 4199]
  21. # fmt: on
  22. def test_seamless_expressivity(example_rate16k_audio: AudioDecoderOutput) -> None:
  23. # this model is seeing non-deterministic behavior (fp32 is better)
  24. dtype = torch.float32
  25. audio_dict = example_rate16k_audio
  26. feat = convert_to_collated_fbank(audio_dict, dtype=dtype)
  27. unity_model_name = "seamless_expressivity"
  28. vocoder_model_name = "vocoder_mel"
  29. pretssel_model_name = "pretssel_v1"
  30. target_lang = "fra"
  31. translator = Translator(unity_model_name, None, device, dtype=dtype)
  32. _, speech_output = translator.predict(
  33. feat,
  34. "s2st",
  35. target_lang,
  36. prosody_encoder_input=feat,
  37. )
  38. assert speech_output is not None
  39. units = tensor(speech_output.units[0], device=device, dtype=torch.int64)
  40. # same target units
  41. assert_unit_close(units, REF_UNITS)
  42. pretssel_generator = PretsselGenerator(
  43. unity_model_name,
  44. vocoder_model_name,
  45. pretssel_model_name,
  46. device=device,
  47. dtype=dtype,
  48. )
  49. # same target mel_spectrogram
  50. speech_output = pretssel_generator.predict(
  51. speech_output.units,
  52. tgt_lang=target_lang,
  53. prosody_encoder_input=feat,
  54. )
  55. # UnitExtrator only operates in fp32
  56. waveform = speech_output.audio_wavs[0][0].float()
  57. unit_extractor = UnitExtractor(
  58. "xlsr2_1b_v2",
  59. "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
  60. device=device,
  61. )
  62. units = unit_extractor.predict(waveform, 34)
  63. assert_unit_close(units, REF_WAVE_EXTRACTED_UNITS)