# test_unity2_aligner.py
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
  6. from typing import Final
  7. import torch
  8. from torch import tensor
  9. from fairseq2.data.audio import AudioDecoderOutput
  10. from seamless_communication.models.aligner.alignment_extractor import AlignmentExtractor
  11. from tests.common import assert_equal, device, get_default_dtype
  12. REF_TEXT = "the examination and testimony of the experts enabled the commision to conclude that five shots may have been fired"
  13. # fmt: off
  14. REF_DURATIONS_FP16: Final = [[ 1, 1, 2, 1, 1, 5, 5, 6, 4, 3, 2, 3, 4, 4, 2, 2, 2, 1,
  15. 1, 1, 3, 3, 3, 4, 3, 3, 3, 4, 4, 3, 2, 2, 1, 1, 1, 1,
  16. 2, 4, 6, 5, 4, 3, 4, 5, 5, 16, 6, 3, 5, 5, 3, 3, 1, 2,
  17. 1, 1, 1, 2, 3, 2, 3, 1, 3, 3, 3, 2, 2, 4, 2, 2, 2, 3,
  18. 2, 4, 5, 4, 5, 8, 3, 17, 2, 2, 3, 2, 5, 4, 6, 3, 1, 1,
  19. 4, 4, 3, 5, 3, 3, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 1,
  20. 2, 6, 4, 5, 9, 5, 1, 12]]
  21. # fmt: on
  22. # fmt: off
  23. REF_DURATIONS_FP32: Final = [[ 1, 1, 2, 1, 1, 5, 5, 6, 4, 3, 2, 3, 4, 4, 2, 2, 2, 1,
  24. 1, 1, 3, 3, 3, 4, 3, 3, 4, 3, 4, 3, 2, 2, 1, 1, 1, 1,
  25. 2, 4, 6, 5, 4, 3, 4, 5, 5, 16, 6, 3, 5, 5, 3, 3, 1, 2,
  26. 1, 1, 1, 2, 3, 2, 3, 1, 3, 3, 3, 2, 2, 4, 2, 2, 2, 3,
  27. 2, 4, 5, 4, 5, 8, 3, 17, 2, 2, 3, 2, 5, 4, 6, 3, 1, 1,
  28. 4, 4, 3, 5, 3, 3, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 1,
  29. 2, 6, 4, 5, 9, 5, 1, 12]]
  30. # fmt: on
  31. def test_aligner(example_rate16k_audio: AudioDecoderOutput) -> None:
  32. aligner_name = "nar_t2u_aligner"
  33. unit_extractor_name = "xlsr2_1b_v2"
  34. unit_extractor_output_layer_n = 35
  35. unit_extractor_kmeans_uri = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
  36. dtype = get_default_dtype()
  37. if dtype == torch.float32:
  38. ref_tensor = REF_DURATIONS_FP32
  39. else:
  40. ref_tensor = REF_DURATIONS_FP16
  41. audio = example_rate16k_audio["waveform"].mean(
  42. 1
  43. ) # averaging mono to get [Time] shape required by aligner
  44. extractor = AlignmentExtractor(
  45. aligner_name,
  46. unit_extractor_name,
  47. unit_extractor_output_layer_n,
  48. unit_extractor_kmeans_uri,
  49. device=device,
  50. dtype=dtype,
  51. )
  52. alignment_durations, _, _ = extractor.extract_alignment(
  53. audio, REF_TEXT, plot=False, add_trailing_silence=True
  54. )
  55. assert_equal(
  56. alignment_durations, tensor(ref_tensor, device=device, dtype=torch.int64)
  57. )