12345678910111213141516171819202122232425262728293031323334353637 |
- # Copyright (c) Meta Platforms, Inc. and affiliates
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # MIT_LICENSE file in the root directory of this source tree.
- import unittest
- from unittest.mock import patch, MagicMock
- from seamless_communication.denoise.demucs import Demucs, DenoisingConfig
- import torch
- from fairseq2.memory import MemoryBlock
class TestDemucs(unittest.TestCase):
    """Tests for ``Demucs`` from ``seamless_communication.denoise.demucs``.

    Every external call made during ``denoise`` that this file knows about
    (``sp.run``, ``Path`` and ``torchaudio.load`` as referenced from the
    ``demucs`` module) is replaced by a mock in ``test_denoise``.
    """

    def test_init_works(self):
        """The config passed to the constructor is exposed as ``denoise_config``."""
        denoiser = Demucs(
            denoise_config=DenoisingConfig(model="htdemucs", sample_rate=16000)
        )

        # Checked in the same order as the original assertions.
        expected_fields = (("model", "htdemucs"), ("sample_rate", 16000))
        for field_name, expected_value in expected_fields:
            self.assertEqual(
                getattr(denoiser.denoise_config, field_name), expected_value
            )

    # Decorators are applied bottom-up, so the injected mocks arrive in the
    # order: sp.run, Path, torchaudio.load.
    @patch("seamless_communication.denoise.demucs.torchaudio.load")
    @patch("seamless_communication.denoise.demucs.Path")
    @patch("seamless_communication.denoise.demucs.sp.run")
    def test_denoise(self, mock_run, mock_path, mock_load):
        """``denoise`` calls ``sp.run`` and ``torchaudio.load`` once each and
        returns a ``MemoryBlock`` when all three collaborators are mocked."""
        # The mocked subprocess call reports a zero return code.
        completed_process = MagicMock(returncode=0)
        mock_run.return_value = completed_process

        # torchaudio.load yields a (tensor, sample_rate) pair.
        fake_waveform = torch.randn(1, 16000)
        mock_load.return_value = (fake_waveform, 16000)

        # Any Path(...) instance built inside the module under test resolves to
        # this single mock: it "exists", globs to one entry, and reads as
        # empty bytes when opened as a context manager.
        path_instance = mock_path.return_value
        path_instance.exists.return_value = True
        path_instance.glob.return_value = [MagicMock()]
        opened_file = path_instance.open.return_value.__enter__.return_value
        opened_file.read.return_value = b""

        settings = DenoisingConfig(model="htdemucs", sample_rate=16000)
        denoiser = Demucs(denoise_config=settings)
        output = denoiser.denoise(audio=None)

        mock_run.assert_called_once()
        mock_load.assert_called_once()
        self.assertIsInstance(output, MemoryBlock)
-
|