|
@@ -6,6 +6,7 @@
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
+import logging
|
|
|
import subprocess
|
|
|
from argparse import ArgumentParser, Namespace
|
|
|
from dataclasses import dataclass
|
|
@@ -22,6 +23,13 @@ from simuleval.data.dataloader import register_dataloader
|
|
|
from simuleval.data.dataloader.dataloader import IterableDataloader
|
|
|
from simuleval.data.dataloader.s2t_dataloader import SpeechToTextDataloader
|
|
|
|
|
|
+logging.basicConfig(
|
|
|
+ level=logging.INFO,
|
|
|
+ format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
|
|
|
+)
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
|
|
|
@dataclass
|
|
|
class SoundFileInfo:
|
|
@@ -37,6 +45,34 @@ def count_lines(filename: Path) -> int:
|
|
|
return int(result.stdout.decode().split()[0]) - 1
|
|
|
|
|
|
|
|
|
+class SileroVADSilenceRemover:
|
|
|
+ def __init__(self, sample_rate: int = 16000) -> None:
|
|
|
+ self.sample_rate = sample_rate
|
|
|
+ self.model, self.utils = torch.hub.load(
|
|
|
+ repo_or_dir="snakers4/silero-vad",
|
|
|
+ model="silero_vad",
|
|
|
+ # force_reload=True,
|
|
|
+ onnx=False,
|
|
|
+ )
|
|
|
+
|
|
|
+ def __call__(self, sample_list: List[float]) -> List[float]:
|
|
|
+ (
|
|
|
+ get_speech_timestamps,
|
|
|
+ save_audio,
|
|
|
+ read_audio,
|
|
|
+ VADIterator,
|
|
|
+ collect_chunks,
|
|
|
+ ) = self.utils
|
|
|
+ speech_timestamps = get_speech_timestamps(
|
|
|
+ sample_list, self.model, sampling_rate=self.sample_rate
|
|
|
+ )
|
|
|
+ if len(speech_timestamps) == 0:
|
|
|
+ return sample_list
|
|
|
+ speech_start_time = speech_timestamps[0]["start"]
|
|
|
+ speech_end_time = speech_timestamps[-1]["end"]
|
|
|
+ return sample_list[int(speech_start_time) : int(speech_end_time)]
|
|
|
+
|
|
|
+
|
|
|
@register_dataloader("fairseq2_s2tt")
|
|
|
class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader): # type: ignore
|
|
|
def __init__(self, data_pipeline: DataPipeline, args: Namespace) -> None:
|
|
@@ -49,6 +85,12 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
|
|
|
self.data_pipeline = data_pipeline
|
|
|
self.data_itr = iter(self.data_pipeline)
|
|
|
self.cur_index = self.start_index - 1
|
|
|
+ self.silence_remover = None
|
|
|
+ if self.args.strip_silence:
|
|
|
+ logger.warn(
|
|
|
+ "Stripping silence in the beginning and the end of audio with SileroVAD."
|
|
|
+ )
|
|
|
+ self.silence_remover = SileroVADSilenceRemover()
|
|
|
|
|
|
def __iter__(self) -> SimulEvalSpeechToTextDataloader:
|
|
|
return self
|
|
@@ -74,6 +116,10 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
|
|
|
source: List[float] = (
|
|
|
self.item["audio"]["data"]["waveform"]["seqs"].squeeze().tolist()
|
|
|
)
|
|
|
+
|
|
|
+ if self.silence_remover is not None:
|
|
|
+ source = self.silence_remover(source)
|
|
|
+
|
|
|
return source
|
|
|
|
|
|
def get_target(self, index: Optional[int] = None) -> str:
|
|
@@ -166,3 +212,9 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
|
|
|
parser.add_argument(
|
|
|
"--tgt-lang", type=str, help="Target language to translate/transcribe into."
|
|
|
)
|
|
|
+ parser.add_argument(
|
|
|
+ "--strip-silence",
|
|
|
+ action="store_true",
|
|
|
+ default=False,
|
|
|
+ help="Strip silence in the beginning and the end of audio.",
|
|
|
+ )
|