123456789101112131415161718192021222324252627282930313233343536373839404142 |
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import cast

import torch

from fairseq2.data.audio import AudioDecoderOutput, WaveformToFbankInput
from seamless_communication.models.vocoder.loader import load_mel_vocoder_model
from tests.common import (
    assert_close,
    convert_to_collated_fbank,
    device,
    get_default_dtype,
)


def test_pretssel_vocoder(example_rate16k_audio: AudioDecoderOutput) -> None:
    """Round-trip the mel vocoder: fbank -> waveform -> fbank.

    Encodes the 16 kHz fixture audio to collated fbank features, runs the
    pretrained mel vocoder to synthesize a waveform, re-extracts fbank
    features from that waveform, and checks they are loosely close to the
    originals (rtol=5.0 — a reconstruction sanity check, not an exact match).
    """
    sample_rate = 16_000
    dtype = get_default_dtype()

    # Reference features from the fixture audio ("seqs" is batched; take item 0).
    feat = convert_to_collated_fbank(example_rate16k_audio, dtype=dtype)["seqs"][0]

    vocoder = load_mel_vocoder_model("vocoder_mel", device=device, dtype=dtype)
    vocoder.eval()

    # Synthesize a waveform from the features; flatten to a (1, n_samples) batch.
    with torch.inference_mode():
        wav_hat = vocoder(feat).view(1, -1)

    # Re-encode the synthesized waveform to fbank features for comparison.
    audio_hat_dict = cast(
        WaveformToFbankInput,
        {"sample_rate": sample_rate, "waveform": wav_hat},
    )
    feat_hat = convert_to_collated_fbank(audio_hat_dict, dtype=dtype)["seqs"][0]

    # The re-extracted features may be slightly shorter; truncate the reference
    # to match before the (deliberately loose) closeness check.
    assert_close(feat_hat, feat[: feat_hat.shape[0], :], atol=0.0, rtol=5.0)
|