test_unity2_aligner.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from typing import Final
  7. import torch
  8. from fairseq2.typing import Device
  9. from torch import tensor
  10. from tests.common import assert_equal, device
  11. from seamless_communication.models.aligner.alignment_extractor import AlignmentExtractor
  12. from fairseq2.data.audio import (
  13. AudioDecoder,
  14. AudioDecoderOutput
  15. )
  16. from fairseq2.memory import MemoryBlock
  17. from urllib.request import urlretrieve
  18. import tempfile
  19. from tests.common import assert_equal, device
  20. REF_TEXT = "the examination and testimony of the experts enabled the commision to conclude that five shots may have been fired"
  21. REF_DURATIONS: Final = [[ 1, 1, 2, 1, 1, 5, 5, 6, 4, 3, 2, 3, 4, 4, 2, 2, 2, 1,
  22. 1, 1, 3, 3, 3, 4, 3, 3, 4, 3, 4, 3, 2, 2, 1, 1, 1, 1,
  23. 2, 4, 6, 5, 4, 3, 4, 5, 5, 16, 6, 3, 5, 5, 3, 3, 1, 2,
  24. 1, 1, 1, 2, 3, 2, 3, 1, 3, 3, 3, 2, 2, 4, 2, 2, 2, 3,
  25. 2, 4, 5, 4, 5, 8, 3, 17, 2, 2, 3, 2, 5, 4, 6, 3, 1, 1,
  26. 4, 4, 3, 5, 3, 3, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 1,
  27. 2, 6, 4, 5, 9, 5, 1, 12]]
  28. def test_aligner(example_rate16k_audio: AudioDecoderOutput) -> None:
  29. aligner_name = "nar_t2u_aligner"
  30. unit_extractor_name = "xlsr2_1b_v2"
  31. unit_extractor_output_layer_n = 35
  32. unit_extractor_kmeans_uri = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
  33. extractor = AlignmentExtractor(
  34. aligner_name,
  35. unit_extractor_name,
  36. unit_extractor_output_layer_n,
  37. unit_extractor_kmeans_uri,
  38. device=device
  39. )
  40. audio = example_rate16k_audio["waveform"].mean(1) # averaging mono to get [Time] shape required by aligner
  41. alignment_durations, _, _ = extractor.extract_alignment(audio, REF_TEXT, plot=False, add_trailing_silence=True)
  42. assert_equal(alignment_durations, tensor(REF_DURATIONS, device=device, dtype=torch.int64))