浏览代码

Denoise audio with Demucs and pipeline with Transcriber (#441)

* segmentation class using silero

* pipeline segmenter with transcriber

* transcibe segments progress

* fix arg error

* pipeline segmenter with transcriber

* unit test

* implement pdac algorithm and remove segment_speech logic

* fix segment len

* fix pdac algorithm

* remove unused param

* change min segment size

* add threshold filter

* change default pause len

* demucs class

* pipeline with transcriber

* fix issues

* add dependency

* write process output to file instead of console

* create dataclass DenoisingConfig

* hide denoise logic in transcriber

* fix type error

* unit test

* fix args

* fix dir name

* fix unit test

* remove demucs from setup.py

* nits

* fix cleanup
Alisha Maddy 1 年之前
父节点
当前提交
ca20aa8a1e

+ 0 - 0
src/seamless_communication/denoise/__init__.py


+ 113 - 0
src/seamless_communication/denoise/demucs.py

@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+import subprocess as sp
+import tempfile
+from typing import Union
+from torch import Tensor
+import torchaudio
+from fairseq2.memory import MemoryBlock
+from dataclasses import dataclass
+from typing import Optional
+import os
+
+SAMPLING_RATE = 16000
+
+@dataclass
+class DenoisingConfig:
+    def __init__(
+            self,
+            filter_width: int = 3,
+            model="htdemucs", 
+            sample_rate=SAMPLING_RATE,
+            two_stems=None,
+            float32=False,
+            int24=False):
+        self.filter_width = filter_width
+        self.model = model
+        self.sample_rate = sample_rate
+        self.two_stems = two_stems
+        self.float32 = float32
+        self.int24 = int24
+
+class Demucs():
+    def __init__(
+            self, 
+            denoise_config: Optional[DenoisingConfig]):
+        self.denoise_config = denoise_config
+        self.temp_files = []
+
+    def run_command_with_temp_file(self, cmd):
+        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp:
+            self.temp_files.append(temp.name)
+            result = sp.run(cmd, stdout=temp, stderr=temp, text=True)
+            # If there was an error, print the content of the file
+            if result.returncode != 0:
+                temp.seek(0)
+                print(temp.read())
+
+    def cleanup_temp_files(self):
+        for temp_file in self.temp_files:
+            try:
+                os.remove(temp_file)  
+            except Exception as e:
+                print(f"Failed to remove temporary file: {temp_file}. Error: {e}")
+
+    def denoise(self, audio: Union[str, Tensor]):
+
+        if self.denoise_config is None:
+          self.denoise_config = DenoisingConfig()
+
+        if isinstance(audio, Tensor):
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
+                self.temp_files.append(temp_wav.name)
+                torchaudio.save(temp_wav.name, audio, self.denoise_config.sample_rate)
+                audio = temp_wav.name
+
+        if not Path(audio).exists():
+            print("Input file does not exist.")
+            return None
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            cmd = ["python3", "-m", "demucs.separate", "-o", temp_dir, "-n", self.denoise_config.model]
+            if self.denoise_config.float32:
+                cmd += ["--float32"]
+            if self.denoise_config.int24:
+                cmd += ["--int24"]
+            if self.denoise_config.two_stems is not None:
+                cmd += [f"--two-stems={self.denoise_config.two_stems}"]
+
+            audio_path = Path(audio)
+            audio_name = audio_path.stem
+            audio = [str(audio)]
+
+            print("Executing command:", " ".join(cmd))
+            self.run_command_with_temp_file(cmd + audio)
+
+            separated_files = list(Path(temp_dir + "/htdemucs/" + audio_name).glob("*vocals.wav*"))
+            
+            if not separated_files:
+                print("Separated vocals file not found.")
+                return None
+
+            waveform, sample_rate = torchaudio.load(separated_files[0])
+
+            if waveform.shape[0] > 1:
+                waveform = waveform.mean(dim=0, keepdim=True)
+            if sample_rate != 16000:
+                resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+                waveform = resampler(waveform)
+                sample_rate = 16000
+
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav2:
+                torchaudio.save(temp_wav2.name, waveform, sample_rate=sample_rate)
+                block = MemoryBlock(temp_wav2.read())
+
+            self.cleanup_temp_files()
+
+            return block
+        

+ 32 - 11
src/seamless_communication/inference/transcriber.py

@@ -4,7 +4,7 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Tuple, Union, Optional
 
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater
@@ -29,6 +29,7 @@ from seamless_communication.models.unity import (
     load_unity_model,
     load_unity_text_tokenizer,
 )
+from seamless_communication.denoise.demucs import Demucs, DenoisingConfig
 
 
 class EncDecAttentionsCollect(AttentionWeightHook):
@@ -272,6 +273,16 @@ class Transcriber(nn.Module):
             step_scores=step_scores,
         )
         return Transcription(stats)
+    
+    def denoise_audio(
+            self, 
+            audio: Union[str, Tensor], 
+            denoise_config: Optional[DenoisingConfig]
+            ) -> Dict:
+        demucs = Demucs(
+            denoise_config=denoise_config)
+        audio = demucs.denoise(audio)
+        return self.decode_audio(audio)
 
     @torch.inference_mode()
     def transcribe(
@@ -280,6 +291,8 @@ class Transcriber(nn.Module):
         src_lang: str,
         filter_width: int = 3,
         sample_rate: int = 16000,
+        denoise: bool = False,
+        denoise_config: Optional[DenoisingConfig] = None,
         **sequence_generator_options: Dict,
     ) -> Transcription:
         """
@@ -295,20 +308,28 @@ class Transcriber(nn.Module):
             Window size to pad weights tensor.
         :params **sequence_generator_options:
             See BeamSearchSeq2SeqGenerator.
+        :params denoise:
+            Whether to denoise the audio.
+        :params denoise_config:
+            Configuration for denoising.
 
         :returns:
             - List of Tokens with timestamps.
         """
-        if isinstance(audio, str):
-            with Path(audio).open("rb") as fb:
-                block = MemoryBlock(fb.read())
-            decoded_audio = self.decode_audio(block)
-        else:
-            decoded_audio = {
-                "waveform": audio,
-                "sample_rate": sample_rate,
-                "format": -1,
-            }
+
+        if denoise:
+            decoded_audio = self.denoise_audio(audio, denoise_config)
+        else:            
+            if isinstance(audio, str):
+                with Path(audio).open("rb") as fb:
+                    block = MemoryBlock(fb.read())
+                decoded_audio = self.decode_audio(block)
+            else:
+                decoded_audio = {
+                    "waveform": audio,
+                    "sample_rate": sample_rate,
+                    "format": -1,
+                }
 
         src = self.convert_to_fbank(decoded_audio)["fbank"]
 

+ 0 - 0
tests/unit/denoise/__init__.py


+ 38 - 0
tests/unit/denoise/test_demucs.py

@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+import unittest
+from unittest.mock import patch, MagicMock
+from seamless_communication.denoise.demucs import Demucs, DenoisingConfig
+import torch
+from fairseq2.memory import MemoryBlock
+
+class TestDemucs(unittest.TestCase):
+    def test_init_works(self):
+        config = DenoisingConfig(model="htdemucs", sample_rate=16000)
+        demucs = Demucs(denoise_config=config)
+        self.assertEqual(demucs.denoise_config.model, "htdemucs")
+        self.assertEqual(demucs.denoise_config.sample_rate, 16000)
+
+    @patch("seamless_communication.denoise.demucs.torchaudio.load")
+    @patch("seamless_communication.denoise.demucs.torchaudio.save")
+    @patch("seamless_communication.denoise.demucs.Path")
+    @patch("seamless_communication.denoise.demucs.sp.run")
+    def test_denoise(self, mock_run, mock_path, mock_load):
+
+        mock_run.return_value = MagicMock(returncode=0)
+        mock_load.return_value = (torch.randn(1, 16000), 16000)
+        mock_path.return_value.exists.return_value = True
+        mock_path.return_value.glob.return_value = [MagicMock()]
+        mock_path.return_value.open.return_value.__enter__.return_value.read.return_value = b""
+        config = DenoisingConfig(model="htdemucs", sample_rate=16000)
+        demucs = Demucs(denoise_config=config)
+        result = demucs.denoise(audio=None)
+
+        mock_run.assert_called_once()
+        mock_load.assert_called_once()
+        self.assertIsInstance(result, MemoryBlock)
+