test_demucs.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
  6. import unittest
  7. from unittest.mock import patch, MagicMock
  8. from seamless_communication.denoise.demucs import Demucs, DenoisingConfig
  9. import torch
  10. from fairseq2.memory import MemoryBlock
  11. class TestDemucs(unittest.TestCase):
  12. def test_init_works(self):
  13. config = DenoisingConfig(model="htdemucs", sample_rate=16000)
  14. demucs = Demucs(denoise_config=config)
  15. self.assertEqual(demucs.denoise_config.model, "htdemucs")
  16. self.assertEqual(demucs.denoise_config.sample_rate, 16000)
  17. @patch("seamless_communication.denoise.demucs.torchaudio.load")
  18. @patch("seamless_communication.denoise.demucs.torchaudio.save")
  19. @patch("seamless_communication.denoise.demucs.Path")
  20. @patch("seamless_communication.denoise.demucs.sp.run")
  21. def test_denoise(self, mock_run, mock_path, mock_load):
  22. mock_run.return_value = MagicMock(returncode=0)
  23. mock_load.return_value = (torch.randn(1, 16000), 16000)
  24. mock_path.return_value.exists.return_value = True
  25. mock_path.return_value.glob.return_value = [MagicMock()]
  26. mock_path.return_value.open.return_value.__enter__.return_value.read.return_value = b""
  27. config = DenoisingConfig(model="htdemucs", sample_rate=16000)
  28. demucs = Demucs(denoise_config=config)
  29. result = demucs.denoise(audio=None)
  30. mock_run.assert_called_once()
  31. mock_load.assert_called_once()
  32. self.assertIsInstance(result, MemoryBlock)