test_demucs.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
  6. import unittest
  7. from unittest.mock import patch, MagicMock
  8. from seamless_communication.denoise.demucs import Demucs, DenoisingConfig
  9. import torch
  10. from fairseq2.memory import MemoryBlock
  11. class TestDemucs(unittest.TestCase):
  12. def test_init_works(self):
  13. config = DenoisingConfig(model="htdemucs", sample_rate=16000)
  14. demucs = Demucs(denoise_config=config)
  15. self.assertEqual(demucs.denoise_config.model, "htdemucs")
  16. self.assertEqual(demucs.denoise_config.sample_rate, 16000)
  17. @patch("seamless_communication.denoise.demucs.torchaudio.load")
  18. @patch("seamless_communication.denoise.demucs.Path")
  19. @patch("seamless_communication.denoise.demucs.sp.run")
  20. def test_denoise(self, mock_run, mock_path, mock_load):
  21. mock_run.return_value = MagicMock(returncode=0)
  22. mock_load.return_value = (torch.randn(1, 16000), 16000)
  23. mock_path.return_value.exists.return_value = True
  24. mock_path.return_value.glob.return_value = [MagicMock()]
  25. mock_path.return_value.open.return_value.__enter__.return_value.read.return_value = b""
  26. config = DenoisingConfig(model="htdemucs", sample_rate=16000)
  27. demucs = Demucs(denoise_config=config)
  28. result = demucs.denoise(audio=None)
  29. mock_run.assert_called_once()
  30. mock_load.assert_called_once()
  31. self.assertIsInstance(result, MemoryBlock)