Segment audio with Silero VAD and pipeline with Transcriber (#406)

* segmentation class using silero

* pipeline segmenter with transcriber

* transcribe segments progress

* fix arg error

* pipeline segmenter with transcriber

* unit test

* implement pdac algorithm and remove segment_speech logic

* fix segment len

* fix pdac algorithm

* remove unused param

* change min segment size

* add threshold filter

* change default pause len

* new functions for silero_vad.py and fix transcribe()

* remove fbank conversion

* convert all segments to fbank

* fix errors

* fix unit test

* fix demucs test

* fix demucs test
Alisha Maddy 1 year ago
parent
commit
6073b25982

+ 45 - 14
src/seamless_communication/inference/transcriber.py

@@ -19,6 +19,7 @@ from fairseq2.typing import DataType, Device
 
 import numpy as np
 from scipy.signal import medfilt2d
+from argparse import Namespace
 
 import torch
 import torch.nn as nn
@@ -30,6 +31,7 @@ from seamless_communication.models.unity import (
     load_unity_text_tokenizer,
 )
 from seamless_communication.denoise.demucs import Demucs, DenoisingConfig
+from seamless_communication.segment.silero_vad import SileroVADSegmenter
 
 
 class EncDecAttentionsCollect(AttentionWeightHook):
@@ -293,6 +295,8 @@ class Transcriber(nn.Module):
         sample_rate: int = 16000,
         denoise: bool = False,
         denoise_config: Optional[DenoisingConfig] = None,
+        chunk_size_sec: int = 20,
+        pause_length_sec: float = 1.0,
         **sequence_generator_options: Dict,
     ) -> Transcription:
         """
@@ -306,6 +310,12 @@ class Transcriber(nn.Module):
             Sample rate of the audio Tensor.
         :param filter_width:
             Window size to pad weights tensor.
+        :param chunk_size_sec:
+            Maximum chunk length in seconds when
+            segmenting long audio.
+        :param pause_length_sec:
+            Minimum silence duration in seconds at which
+            long audio may be split into chunks.
         :params **sequence_generator_options:
             See BeamSearchSeq2SeqGenerator.
         :params denoise:
@@ -321,9 +331,9 @@ class Transcriber(nn.Module):
             decoded_audio = self.denoise_audio(audio, denoise_config)
         else:            
             if isinstance(audio, str):
                 with Path(audio).open("rb") as fb:
                     block = MemoryBlock(fb.read())
                 decoded_audio = self.decode_audio(block)
             else:
                 decoded_audio = {
                     "waveform": audio,
@@ -331,16 +341,37 @@ class Transcriber(nn.Module):
                     "format": -1,
                 }
 
-        src = self.convert_to_fbank(decoded_audio)["fbank"]
+        length_seconds = (
+            decoded_audio["waveform"].size(0) / decoded_audio["sample_rate"]
+        )

-        length_seconds = (
-            decoded_audio["waveform"].size(0) / decoded_audio["sample_rate"]
-        )
+        waveform_2d = decoded_audio.get("waveform")
+        waveform_1d = waveform_2d.view(-1)
+        segmenter = SileroVADSegmenter(
+            sample_rate=sample_rate,
+            chunk_size_sec=chunk_size_sec,
+            pause_length=pause_length_sec,
+        )

-        return self.run_inference(
-            src,
-            src_lang,
-            length_seconds,
-            filter_width,
-            sequence_generator_options,
-        )
+        if length_seconds > chunk_size_sec:
+            src_segments = segmenter.segment_long_input(waveform_1d)
+        else:
+            src_segments = [(0, waveform_1d.size(0))]
+
+        transcriptions = []
+        for start, end in src_segments:
+            segment = waveform_2d[start:end, :]
+            src_segment = self.convert_to_fbank(
+                {
+                    "waveform": segment,
+                    "sample_rate": decoded_audio.get("sample_rate"),
+                    "format": decoded_audio.get("format"),
+                }
+            )["fbank"]
+            length_seconds_segment = segment.size(0) / sample_rate
+            transcription_segment = self.run_inference(
+                src_segment,
+                src_lang,
+                length_seconds_segment,
+                filter_width,
+                sequence_generator_options,
+            )
+            transcriptions.append(str(transcription_segment))
+
+        return " ".join(transcriptions)
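Taken together, the new transcribe() flow is: decode, optionally denoise, segment with Silero VAD when the input exceeds chunk_size_sec, then transcribe each segment and join the results. A minimal usage sketch follows; the model card name and the input path are illustrative assumptions, not part of this diff, and src_lang is passed positionally since only the new keyword arguments are shown above:

from seamless_communication.inference.transcriber import Transcriber

# Hypothetical model name and audio path; chunk_size_sec and
# pause_length_sec are this commit's defaults, shown explicitly.
transcriber = Transcriber("seamlessM4T_medium")
text = transcriber.transcribe(
    "long_recording.wav",
    "eng",
    chunk_size_sec=20,      # segment inputs longer than 20 s
    pause_length_sec=1.0,   # split at silences of at least 1 s
)
print(text)  # per-segment transcriptions joined with spaces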

+ 0 - 0
src/seamless_communication/segment/__init__.py


+ 288 - 0
src/seamless_communication/segment/silero_vad.py

@@ -0,0 +1,288 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from argparse import Namespace
+import torch
+import typing as tp
+import numpy as np
+import warnings
+
+SAMPLING_RATE = 16000
+
+class SileroVADSegmenter:  # type: ignore
+    def __init__(
+        self,
+        sample_rate: int = SAMPLING_RATE,
+        chunk_size_sec: int = 10,
+        pause_length: float = 0.5,
+    ) -> None:
+        self.model, _ = torch.hub.load(
+            repo_or_dir="snakers4/silero-vad",
+            model="silero_vad",
+            force_reload=False,
+            onnx=False,
+        )
+        self.sample_rate = sample_rate
+        self.chunk_size_sec = chunk_size_sec
+        self.pause_length = pause_length
+
+    def segment_long_input(self, audio: torch.Tensor) -> tp.List[tp.List[int]]:
+        """
+        Split long input into chunks using speech timestamps.
+        """
+        max_segment_length_samples = self.chunk_size_sec * self.sample_rate
+        pause_length_samples = self.pause_length * self.sample_rate
+
+        speech_timestamps = self.get_speech_timestamps(
+            audio, self.model, sampling_rate=self.sample_rate
+        )
+
+        segments = []
+        current_segment = []
+
+        for segment in speech_timestamps:
+            start_samples = segment[0]
+            end_samples = segment[1]
+
+            if current_segment and (
+                end_samples - current_segment[0] > max_segment_length_samples
+                or start_samples - current_segment[1] > pause_length_samples
+            ):
+                segments.append(current_segment)
+                current_segment = []
+
+            if not current_segment:
+                current_segment = [start_samples, end_samples]
+            else:
+                current_segment[1] = end_samples
+        if current_segment:
+            segments.append(current_segment) 
+
+        return segments
+   
+    def get_speech_timestamps(
+        self,
+        audio: torch.Tensor,
+        model,
+        sampling_rate: int = SAMPLING_RATE,
+        min_speech_duration_ms: int = 500,
+        window_size_samples: int = 1536,
+    ) -> tp.List[tp.Tuple[int, int]]:
+        """
+        Get speech timestamps based on the speech probabilities.
+        """
+        probs, _ = self.get_speech_probs(
+            audio=audio,
+            model=model,
+            sampling_rate=sampling_rate,
+            window_size_samples=window_size_samples,
+        )
+
+        max_segment_length_samples = self.chunk_size_sec * self.sample_rate
+        min_segment_length_samples = min_speech_duration_ms / 1000 * sampling_rate
+
+        segments = self.pdac(
+            probs=probs,
+            max_segment_length=max_segment_length_samples,
+            min_segment_length=min_segment_length_samples,
+            window_size_samples=window_size_samples,
+        )
+
+        speech_timestamps = [(seg.start, seg.end) for seg in segments]
+
+        return speech_timestamps
+    
+    def recursive_split(
+        self,
+        sgm: Segment,
+        segments: tp.List[Segment],
+        max_segment_length: float,
+        min_segment_length: float,
+        window_size_samples: int,
+        threshold: float,
+    ) -> None:
+        """
+        Recursively split a segment at its lowest-probability windows until
+        every resulting piece is shorter than max_segment_length.
+        """
+        if sgm.duration < max_segment_length:
+            segments.append(sgm)
+        else:
+            j = 0
+            sorted_indices = np.argsort(sgm.probs)
+            while j < len(sorted_indices):
+                split_idx = sorted_indices[j]
+                sgm_a, sgm_b = self.split(
+                    sgm, split_idx, window_size_samples, threshold
+                )
+                if (
+                    sgm_a.duration > min_segment_length
+                    and sgm_b.duration > min_segment_length
+                ):
+                    self.recursive_split(
+                        sgm_a, segments, max_segment_length,
+                        min_segment_length, window_size_samples, threshold,
+                    )
+                    self.recursive_split(
+                        sgm_b, segments, max_segment_length,
+                        min_segment_length, window_size_samples, threshold,
+                    )
+                    break
+                j += 1
+            else:
+                # No candidate split produced two halves above
+                # min_segment_length; recurse into whichever halves of the
+                # last attempted split are long enough.
+                if sgm_a.duration > min_segment_length:
+                    self.recursive_split(
+                        sgm_a, segments, max_segment_length,
+                        min_segment_length, window_size_samples, threshold,
+                    )
+                if sgm_b.duration > min_segment_length:
+                    self.recursive_split(
+                        sgm_b, segments, max_segment_length,
+                        min_segment_length, window_size_samples, threshold,
+                    )
+
+    def pdac(
+        self,
+        probs: np.ndarray,
+        max_segment_length: float,
+        min_segment_length: float,
+        window_size_samples: int,
+    ) -> tp.List[Segment]:
+        """
+        Recursively splits segments based on speech threshold and duration. 
+        """
+        segments = []
+        sgm = Segment(0, len(probs)*window_size_samples, probs)
+
+        self.recursive_split(
+            sgm, segments, max_segment_length,
+            min_segment_length, window_size_samples, threshold=0.5,
+        )
+        
+        return segments
+
+    def trim(
+        self,
+        sgm: Segment,
+        threshold: float,
+        window_size_samples: int,
+    ) -> Segment:
+        """
+        Drop leading and trailing windows whose speech probability is below
+        threshold.
+        """
+        included_indices = np.where(sgm.probs >= threshold)[0]
+        
+        if not len(included_indices):
+            return Segment(sgm.start, sgm.start, np.empty([0]))
+        
+        i = included_indices[0] * window_size_samples
+        j = (included_indices[-1] + 1) * window_size_samples
+
+        sgm = Segment(
+            sgm.start + i,
+            sgm.start + j,
+            sgm.probs[included_indices[0] : included_indices[-1] + 1],
+        )
+
+        return sgm
+
+    def split(
+        self,
+        sgm: Segment,
+        split_idx: int,
+        window_size_samples: int,
+        threshold: float,
+    ) -> tp.Tuple[Segment, Segment]:
+        """
+        Splits segment into two segments based on the split index.
+        """
+        probs_a = sgm.probs[:split_idx]
+        sgm_a = Segment(sgm.start, sgm.start + (len(probs_a)*window_size_samples), probs_a)
+
+        probs_b = sgm.probs[split_idx + 1 :]
+        sgm_b = Segment(sgm_a.end + 1, sgm.end, probs_b)
+
+        sgm_a = self.trim(sgm_a, threshold, window_size_samples)
+        sgm_b = self.trim(sgm_b, threshold, window_size_samples)
+
+        return sgm_a, sgm_b
+    
+    @staticmethod
+    def resample_audio(wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+        """
+        Downsample audio from sample_rate to the model's SAMPLING_RATE by
+        keeping every (sample_rate / SAMPLING_RATE)-th sample.
+        """
+        assert sample_rate >= SAMPLING_RATE
+        if sample_rate == SAMPLING_RATE:
+            return wav
+
+        tgt_frames = wav.shape[-1] * SAMPLING_RATE // sample_rate
+        coeff = sample_rate / SAMPLING_RATE
+        indices = (torch.arange(tgt_frames) * coeff).to(torch.int32)
+        return wav[..., indices]  # index the last dim so 1-D audio works too
+    
+    @staticmethod
+    def get_speech_probs(
+        audio: torch.Tensor,
+        model,
+        sampling_rate: int = SAMPLING_RATE,
+        window_size_samples: int = 1536,
+    ) -> tp.Tuple[np.ndarray, int]:
+        """
+        Compute per-window speech probabilities by sliding a window over the audio.
+        """
+        if not torch.is_tensor(audio):
+            try:
+                audio = torch.Tensor(audio)
+            except Exception:
+                raise TypeError("Audio cannot be cast to a tensor. Cast it manually.")
+
+        if len(audio.shape) > 1:
+            for _ in range(audio.ndim):  # trying to squeeze empty dimensions
+                audio = audio.squeeze(0)
+            assert (
+                audio.ndim == 1
+            ), "More than one dimension in audio. Are you trying to process audio with 2 channels?"
+
+        audio = SileroVADSegmenter.resample_audio(audio, sampling_rate)
+
+        if sampling_rate == 8000 and window_size_samples > 768:
+            warnings.warn(
+                "window_size_samples is too big for an 8000 sampling_rate! "
+                "Better set window_size_samples to 256, 512 or 768."
+            )
+        if window_size_samples not in [256, 512, 768, 1024, 1536]:
+            warnings.warn(
+                "Unusual window_size_samples! Supported window_size_samples:\n"
+                " - [512, 1024, 1536] for a 16000 sampling_rate\n"
+                " - [256, 512, 768] for an 8000 sampling_rate"
+            )
+
+        model.reset_states()
+
+        audio_length_samples = len(audio)
+
+        speech_probs = []
+        for current_start_sample in range(0, audio_length_samples, window_size_samples):
+            chunk = audio[
+                current_start_sample : current_start_sample + window_size_samples
+            ]
+            if len(chunk) < window_size_samples:
+                chunk = torch.nn.functional.pad(
+                    chunk, (0, int(window_size_samples - len(chunk)))
+                )
+            if next(model.parameters()).is_cuda:
+                chunk = chunk.cuda()
+            speech_prob = model(chunk, sampling_rate).item()
+            speech_probs.append(speech_prob)
+
+        return np.array(speech_probs), audio_length_samples
+    
+class Segment:
+    """A span of audio, in samples, with its per-window speech probabilities."""
+
+    def __init__(self, start: int, end: int, probs: np.ndarray):
+        self.start = start
+        self.end = end
+        self.probs = probs
+        self.duration = float(end - start)
+    
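For standalone use, SileroVADSegmenter can be driven directly; note that __init__ fetches the Silero model via torch.hub on first use. A minimal sketch, assuming a 16 kHz mono file loaded with torchaudio (the file name is hypothetical):

import torchaudio
from seamless_communication.segment.silero_vad import SileroVADSegmenter

wav, sr = torchaudio.load("speech.wav")  # hypothetical 16 kHz mono input
segmenter = SileroVADSegmenter(sample_rate=sr, chunk_size_sec=20, pause_length=1.0)

# segment_long_input expects a 1-D waveform and returns [start, end] pairs
# in sample offsets.
for start, end in segmenter.segment_long_input(wav.view(-1)):
    print(f"segment: {start / sr:.2f}s - {end / sr:.2f}s")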

+ 0 - 1
tests/unit/denoise/test_demucs.py

@@ -18,7 +18,6 @@ class TestDemucs(unittest.TestCase):
         self.assertEqual(demucs.denoise_config.sample_rate, 16000)
 
     @patch("seamless_communication.denoise.demucs.torchaudio.load")
-    @patch("seamless_communication.denoise.demucs.torchaudio.save")
     @patch("seamless_communication.denoise.demucs.Path")
     @patch("seamless_communication.denoise.demucs.sp.run")
     def test_denoise(self, mock_run, mock_path, mock_load):

+ 0 - 0
tests/unit/segment/__init__.py


+ 59 - 0
tests/unit/segment/test_silero_vad.py

@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+import unittest
+from argparse import Namespace
+from unittest.mock import Mock
+from seamless_communication.segment.silero_vad import SileroVADSegmenter, Segment
+import numpy as np
+
+
+class TestSileroVADSegmenter(unittest.TestCase):
+    def test_init_works(self):
+        segmenter = SileroVADSegmenter(
+          sample_rate=16000, 
+          chunk_size_sec=10, 
+          pause_length=0.5)
+        self.assertEqual(segmenter.sample_rate, 16000)
+        self.assertEqual(segmenter.chunk_size_sec, 10)
+        self.assertEqual(segmenter.pause_length, 0.5)
+
+
+    def test_segment_long_input(self):
+        segmenter = SileroVADSegmenter(
+          sample_rate=16000,
+          chunk_size_sec=10,
+          pause_length=0.5)
+        segmenter.get_speech_timestamps = Mock(
+          return_value=[(0, 10000), (20000, 30000)])
+        segments = segmenter.segment_long_input(audio=None)
+        expected_segments = [[0, 10000], [20000, 30000]]
+        self.assertEqual(segments, expected_segments)
+
+
+    def test_recursive_split(self):
+        segmenter = SileroVADSegmenter(
+          sample_rate=16000, 
+          chunk_size_sec=10,
+          pause_length=0.5)
+        sgm = Segment(0, 10000, np.random.rand(10000))
+        segments = []
+        max_segment_length = 5000
+        min_segment_length = 1000
+        window_size_samples = 100
+        threshold = .5
+
+        segmenter.recursive_split(
+          sgm, 
+          segments, 
+          max_segment_length, 
+          min_segment_length, 
+          window_size_samples, 
+          threshold)
+
+        self.assertTrue(all(seg.duration < max_segment_length for seg in segments))
+        self.assertTrue(all(seg.duration > min_segment_length for seg in segments))
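Beyond the unit tests above, a quick way to see pdac's behavior is to feed it a synthetic probability curve with a low-probability dip; it should prefer to split there. A rough sketch, with arbitrary numbers that are not part of this commit:

import numpy as np
from seamless_communication.segment.silero_vad import SileroVADSegmenter

segmenter = SileroVADSegmenter(sample_rate=16000, chunk_size_sec=1, pause_length=0.5)

# 100 high-probability windows, a 10-window dip, then 100 more; pdac splits
# at the lowest-probability windows first and trims sub-threshold edges.
probs = np.concatenate([np.full(100, 0.9), np.full(10, 0.05), np.full(100, 0.9)])
segments = segmenter.pdac(
    probs=probs,
    max_segment_length=16000,   # 1 s of samples at 16 kHz
    min_segment_length=1600,    # 0.1 s
    window_size_samples=512,
)
for seg in segments:
    print(seg.start, seg.end, seg.duration)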