test_pretssel_vocoder.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from typing import cast
  7. import torch
  8. from fairseq2.data.audio import AudioDecoderOutput, WaveformToFbankInput
  9. from seamless_communication.models.vocoder.loader import load_mel_vocoder_model
  10. from tests.common import (
  11. assert_close,
  12. convert_to_collated_fbank,
  13. device,
  14. get_default_dtype,
  15. )
  16. def test_pretssel_vocoder(example_rate16k_audio: AudioDecoderOutput) -> None:
  17. sample_rate = 16_000
  18. dtype = get_default_dtype()
  19. audio_dict = example_rate16k_audio
  20. feat = convert_to_collated_fbank(audio_dict, dtype=dtype)["seqs"][0]
  21. vocoder = load_mel_vocoder_model("vocoder_mel", device=device, dtype=torch.float32)
  22. vocoder.eval()
  23. with torch.inference_mode():
  24. wav_hat = vocoder(feat).view(1, -1)
  25. audio_hat = {"sample_rate": sample_rate, "waveform": wav_hat}
  26. audio_hat_dict = cast(WaveformToFbankInput, audio_hat)
  27. feat_hat = convert_to_collated_fbank(audio_hat_dict, dtype=dtype)["seqs"][0]
  28. assert_close(feat_hat, feat[: feat_hat.shape[0], :], atol=0.0, rtol=5.0)