test_pretssel_vocoder.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import tempfile
  7. from urllib.request import urlretrieve
  8. import torch
  9. import torchaudio
  10. from seamless_communication.models.vocoder.loader import load_mel_vocoder_model
  11. from tests.common import assert_close, device
  12. def test_pretssel_vocoder() -> None:
  13. n_mel_bins = 80
  14. sample_rate = 16_000
  15. vocoder = load_mel_vocoder_model(
  16. "vocoder_mel", device=device, dtype=torch.float32
  17. )
  18. url = "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"
  19. with tempfile.NamedTemporaryFile() as f:
  20. urlretrieve(url, f.name)
  21. _wav, _sr = torchaudio.load(f.name)
  22. wav = torchaudio.sox_effects.apply_effects_tensor(
  23. _wav, _sr, [["rate", f"{sample_rate}"], ["channels", "1"]]
  24. )[0].to(device=device)
  25. feat = torchaudio.compliance.kaldi.fbank(
  26. wav * (2**15), num_mel_bins=n_mel_bins, sample_frequency=sample_rate
  27. )
  28. with torch.no_grad():
  29. wav_hat = vocoder(feat).t()
  30. feat_hat = torchaudio.compliance.kaldi.fbank(
  31. wav_hat * (2**15), num_mel_bins=n_mel_bins, sample_frequency=sample_rate
  32. )
  33. assert_close(feat_hat, feat[: feat_hat.shape[0], :], atol=0.0, rtol=5.0)
  34. if __name__ == "__main__":
  35. test_pretssel_vocoder()