Parcourir la source

Add strip silence to dataloader (#152)

* Add strip silence to dataloader

* Adding explicit default

* remove accidental print
Abinesh Ramakrishnan il y a 1 an
Parent
commit
2a3170ba34
1 fichiers modifiés avec 52 ajouts et 0 suppressions
  1. 52 0
      src/seamless_communication/streaming/dataloaders/s2tt.py

+ 52 - 0
src/seamless_communication/streaming/dataloaders/s2tt.py

@@ -6,6 +6,7 @@
 
 from __future__ import annotations
 
+import logging
 import subprocess
 from argparse import ArgumentParser, Namespace
 from dataclasses import dataclass
@@ -22,6 +23,13 @@ from simuleval.data.dataloader import register_dataloader
 from simuleval.data.dataloader.dataloader import IterableDataloader
 from simuleval.data.dataloader.s2t_dataloader import SpeechToTextDataloader
 
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class SoundFileInfo:
@@ -37,6 +45,34 @@ def count_lines(filename: Path) -> int:
     return int(result.stdout.decode().split()[0]) - 1
 
 
+class SileroVADSilenceRemover:
+    def __init__(self, sample_rate: int = 16000) -> None:
+        self.sample_rate = sample_rate
+        self.model, self.utils = torch.hub.load(
+            repo_or_dir="snakers4/silero-vad",
+            model="silero_vad",
+            # force_reload=True,
+            onnx=False,
+        )
+
+    def __call__(self, sample_list: List[float]) -> List[float]:
+        (
+            get_speech_timestamps,
+            save_audio,
+            read_audio,
+            VADIterator,
+            collect_chunks,
+        ) = self.utils
+        speech_timestamps = get_speech_timestamps(
+            sample_list, self.model, sampling_rate=self.sample_rate
+        )
+        if len(speech_timestamps) == 0:
+            return sample_list
+        speech_start_time = speech_timestamps[0]["start"]
+        speech_end_time = speech_timestamps[-1]["end"]
+        return sample_list[int(speech_start_time) : int(speech_end_time)]
+
+
 @register_dataloader("fairseq2_s2tt")
 class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader):  # type: ignore
     def __init__(self, data_pipeline: DataPipeline, args: Namespace) -> None:
@@ -49,6 +85,12 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
         self.data_pipeline = data_pipeline
         self.data_itr = iter(self.data_pipeline)
         self.cur_index = self.start_index - 1
+        self.silence_remover = None
+        if self.args.strip_silence:
+            logger.warn(
+                "Stripping silence in the beginning and the end of audio with SileroVAD."
+            )
+            self.silence_remover = SileroVADSilenceRemover()
 
     def __iter__(self) -> SimulEvalSpeechToTextDataloader:
         return self
@@ -74,6 +116,10 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
         source: List[float] = (
             self.item["audio"]["data"]["waveform"]["seqs"].squeeze().tolist()
         )
+
+        if self.silence_remover is not None:
+            source = self.silence_remover(source)
+
         return source
 
     def get_target(self, index: Optional[int] = None) -> str:
@@ -166,3 +212,9 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
         parser.add_argument(
             "--tgt-lang", type=str, help="Target language to translate/transcribe into."
         )
+        parser.add_argument(
+            "--strip-silence",
+            action="store_true",
+            default=False,
+            help="Strip silence in the beginning and the end of audio.",
+        )