View source

Silero VAD Agent (#120)

* Silero VAD Agent and pipeline changes

* updating logger

* add prefix_tgt_lang arg for when not using dataloader

---------

Co-authored-by: Anna Sun <13106449+annasun28@users.noreply.github.com>
Abinesh Ramakrishnan 1 year ago
parent
commit
b1027e9858

+ 14 - 5
src/seamless_communication/streaming/agents/mma_m4t_s2t.py

@@ -4,18 +4,17 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from seamless_communication.streaming.agents.detokenizer import (
-    DetokenizerAgent,
+from seamless_communication.streaming.agents.detokenizer import DetokenizerAgent
+from seamless_communication.streaming.agents.offline_w2v_bert_encoder import (
+    OfflineWav2VecBertEncoderAgent,
 )
 from seamless_communication.streaming.agents.online_feature_extractor import (
     OnlineFeatureExtractorAgent,
 )
-from seamless_communication.streaming.agents.offline_w2v_bert_encoder import (
-    OfflineWav2VecBertEncoderAgent,
-)
 from seamless_communication.streaming.agents.online_text_decoder import (
     MMASpeechToTextDecoderAgent,
 )
+from seamless_communication.streaming.agents.silero_vad import SileroVADAgent
 from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline
 from simuleval.utils import entrypoint
 
@@ -37,3 +36,13 @@ class MonotonicM4TS2TAgent(UnitYAgentPipeline):
         OfflineWav2VecBertEncoderAgent,
         MMASpeechToTextDecoderAgent,
     ]
+
+
class MonotonicM4TS2TVADAgent(UnitYAgentPipeline):
    """Monotonic M4T speech-to-text pipeline with a Silero VAD front-end.

    Stages run in order: VAD utterance segmentation -> online feature
    extraction -> offline w2v-BERT encoding -> monotonic text decoding ->
    detokenization.
    """

    pipeline = [
        SileroVADAgent,
        OnlineFeatureExtractorAgent,
        OfflineWav2VecBertEncoderAgent,
        MMASpeechToTextDecoderAgent,
        DetokenizerAgent,
    ]

+ 11 - 8
src/seamless_communication/streaming/agents/online_text_decoder.py

@@ -71,13 +71,15 @@ class OnlineTextDecoderAgent(GenericAgent):
         self.max_len_b: int = args.max_len_b
         self.max_consecutive_writes = self.args.max_consecutive_write
         self.min_starting_wait = args.min_starting_wait
-        self.min_starting_wait_reset = args.min_starting_wait_reset
         self.no_early_stop = args.no_early_stop
 
         self.device = args.device
         self.dtype = args.dtype
         self.eos_idx = text_tokenizer.vocab_info.eos_idx
-        token_encoder = text_tokenizer.create_encoder(lang=args.tgt_lang, mode="target")
+        if hasattr(args, "tgt_lang") and hasattr(args, "prefix_tgt_lang"):
+            assert args.tgt_lang == args.prefix_tgt_lang
+        tgt_lang = getattr(args, "tgt_lang", None) or getattr(args, "prefix_tgt_lang", None)
+        token_encoder = text_tokenizer.create_encoder(lang=tgt_lang, mode="target")
         prefix_indices = token_encoder.prefix_indices
         assert prefix_indices is not None
         self.prefix_indices: List[int] = prefix_indices.tolist()
@@ -111,7 +113,7 @@ class OnlineTextDecoderAgent(GenericAgent):
         parser.add_argument(
             "--min-starting-wait",
             type=int,
-            default=12,
+            default=1,
             help="Minimal starting waiting source steps",
         )
         parser.add_argument(
@@ -123,7 +125,12 @@ class OnlineTextDecoderAgent(GenericAgent):
         parser.add_argument(
             "--no-early-stop",
             action="store_true",
-            default=True,
+            default=False,
+        )
+        parser.add_argument(
+            "--prefix-tgt-lang",
+            type=str,
+            default=None,
         )
 
     def policy(self, states: DecoderAgentStates) -> Action:
@@ -285,10 +292,6 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
             ):
                 if prob == 1.0:
                     pred_indices = []
-                if states.source_len < self.min_starting_wait_reset:
-                    pred_indices = []
-                    if len(states.target_indices) < 3:
-                        states.target_indices = []
                 break
             if (
                 finished

+ 302 - 0
src/seamless_communication/streaming/agents/silero_vad.py

@@ -0,0 +1,302 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import logging
+import queue
+import time
+from argparse import ArgumentParser, Namespace
+from os import SEEK_END
+from typing import Any, List, Optional, Union
+
+import numpy as np
+import torch
+from seamless_communication.streaming.agents.mixins import EarlyStoppingMixin
+from simuleval.agents import AgentStates, SpeechToSpeechAgent
+from simuleval.agents.actions import Action, ReadAction, WriteAction
+from simuleval.data.segments import EmptySegment, Segment, SpeechSegment
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+SPEECH_PROB_THRESHOLD = 0.6
+
+
class SileroVADStates(EarlyStoppingMixin, AgentStates):
    """Streaming state for the Silero VAD agent.

    Scores incoming audio with the pretrained Silero VAD model, accumulates
    speech/silence durations, and buffers audio into fixed-size chunks on an
    internal queue. Once enough contiguous silence is observed, the current
    utterance is closed (``source_finished``) and later audio is parked on
    ``next_input_queue`` until the pipeline resets.
    """

    def __init__(self, args: Namespace) -> None:
        # Load the pretrained Silero VAD model plus its helper utilities via
        # torch.hub (downloaded on first use, then served from the hub cache).
        self.model, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=False,
            onnx=False,
        )

        (
            self.get_speech_timestamps,
            self.save_audio,
            self.read_audio,
            self.VADIterator,
            self.collect_chunks,
        ) = utils
        # Segmentation thresholds / sizes, all taken from CLI args
        # (see SileroVADAgent.add_args for defaults and units).
        self.silence_limit_ms = args.silence_limit_ms
        self.speech_soft_limit_ms = args.speech_soft_limit_ms
        self.window_size_samples = args.window_size_samples
        self.chunk_size_samples = args.chunk_size_samples
        self.sample_rate = args.sample_rate
        self.debug = args.debug
        # Optional file handle for dumping received audio; only written to
        # when set externally (see debug_write_wav).
        self.test_input_segments_wav = None
        self.debug_log(args)
        # input_queue holds chunks of the current utterance; next_input_queue
        # buffers speech that arrives after the current source has finished,
        # to be promoted on the next reset (see clear_queues).
        self.input_queue: queue.Queue[Segment] = queue.Queue()
        self.next_input_queue: queue.Queue[Segment] = queue.Queue()
        super().__init__()

    def clear_queues(self) -> None:
        """Drop any unconsumed chunks and promote buffered next-utterance audio."""
        while not self.input_queue.empty():
            self.input_queue.get_nowait()
            self.input_queue.task_done()
        # move everything from next_input_queue to input_queue
        while not self.next_input_queue.empty():
            chunk = self.next_input_queue.get_nowait()
            self.next_input_queue.task_done()
            self.input_queue.put_nowait(chunk)

    def reset(self) -> None:
        """Reset all per-utterance state, including the VAD model's RNN state."""
        super().reset()
        # TODO: in seamless_server, report latency for each new segment
        self.first_input_ts: Optional[float] = None
        self.silence_acc_ms = 0
        self.speech_acc_ms = 0
        # Partial chunk being filled up to chunk_size_samples before queueing.
        self.input_chunk: np.ndarray[Any, np.dtype[np.int16]] = np.empty(
            0, dtype=np.int16
        )
        # True until the first speech sample is buffered; used to ignore
        # leading silence in update_source.
        self.is_fresh_state = True
        self.clear_queues()
        self.model.reset_states()

    def reset_early(self) -> None:
        """
        Don't reset state before EOS
        """
        pass

    def get_speech_prob_from_np_float32(
        self, segment: np.ndarray[Any, np.dtype[np.float32]]
    ) -> List[Any]:
        """Return per-window speech probabilities for *segment*.

        The segment is scored in consecutive windows of
        ``window_size_samples``; a trailing partial window is discarded.
        """
        t = torch.from_numpy(segment)
        speech_probs = []
        for i in range(0, len(t), self.window_size_samples):
            chunk = t[i : i + self.window_size_samples]
            if len(chunk) < self.window_size_samples:
                break
            speech_prob = self.model(chunk, self.sample_rate).item()
            speech_probs.append(speech_prob)
        return speech_probs

    def debug_log(self, m: Any) -> None:
        # Log only when --debug was given.
        if self.debug:
            logger.info(m)

    def process_speech(
        self,
        segment: Union[np.ndarray[Any, np.dtype[np.float32]], Segment],
        tgt_lang: Optional[str] = None,
    ) -> None:
        """
        Process a full or partial speech chunk
        """
        # NOTE(review): this local name shadows the imported `queue` module
        # within this method — consider renaming (e.g. `target_queue`).
        queue = self.input_queue
        if self.source_finished:
            # current source is finished, but next speech starts to come in already
            self.debug_log("use next_input_queue")
            queue = self.next_input_queue

        # NOTE: halve (rather than fully clear) the accumulated silence on
        # speech, so that an utterance that has already accumulated silence
        # keeps a head start toward the segmentation threshold.
        self.silence_acc_ms = self.silence_acc_ms // 2

        if self.first_input_ts is None:
            self.first_input_ts = time.time() * 1000

        while len(segment) > 0:
            # add chunks to states.buffer
            i = self.chunk_size_samples - len(self.input_chunk)
            self.input_chunk = np.concatenate((self.input_chunk, segment[:i]))
            segment = segment[i:]
            self.is_fresh_state = False
            if len(self.input_chunk) == self.chunk_size_samples:
                queue.put_nowait(
                    SpeechSegment(
                        content=self.input_chunk, finished=False, tgt_lang=tgt_lang
                    )
                )
                self.input_chunk = np.empty(0, dtype=np.int16)

    def check_silence_acc(self, tgt_lang: Optional[str] = None) -> None:
        """Close the current utterance if enough silence has accumulated.

        After ``speech_soft_limit_ms`` of speech the silence limit is halved,
        making long utterances segment more aggressively.
        """
        silence_limit_ms = self.silence_limit_ms
        if self.speech_acc_ms >= self.speech_soft_limit_ms:
            self.debug_log("increase speech threshold")
            silence_limit_ms = self.silence_limit_ms // 2
        self.debug_log(f"silence_acc_ms: {self.silence_acc_ms}")
        if self.silence_acc_ms >= silence_limit_ms:
            self.debug_log("=== end of segment")
            # source utterance finished
            self.silence_acc_ms = 0
            self.speech_acc_ms = 0
            if self.input_chunk.size > 0:
                # flush partial input_chunk
                self.input_queue.put_nowait(
                    SpeechSegment(
                        content=self.input_chunk, tgt_lang=tgt_lang, finished=True
                    )
                )
                self.input_chunk = np.empty(0, dtype=np.int16)
            self.input_queue.put_nowait(EmptySegment(finished=True))
            self.source_finished = True

    def update_source(
        self, segment: Union[np.ndarray[Any, np.dtype[np.float32]], Segment]
    ) -> None:
        """
        Default value for the segment in the update_source method is a segment
        Class, for some reason this interface didn't align with other interfaces
        Adding this change here to support both np.ndarray and Segment class
        """
        tgt_lang = None
        if isinstance(segment, SpeechSegment):
            self.sample_rate = segment.sample_rate
            if hasattr(segment, "tgt_lang") and segment.tgt_lang is not None:
                tgt_lang = segment.tgt_lang
            if isinstance(segment.content, np.ndarray):
                # Copy and coerce to float32, the dtype the VAD model expects.
                segment = np.array(segment.content, dtype=np.float32)
            else:
                segment = segment.content
        speech_probs = self.get_speech_prob_from_np_float32(segment)
        chunk_size_ms = len(segment) * 1000 / self.sample_rate
        window_size_ms = self.window_size_samples * 1000 / self.sample_rate
        if all(i <= SPEECH_PROB_THRESHOLD for i in speech_probs):
            # Entirely silent chunk.
            if self.source_finished:
                return
            self.debug_log("got silent chunk")
            # Leading silence before any speech (is_fresh_state) is ignored.
            if not self.is_fresh_state:
                self.silence_acc_ms += chunk_size_ms
                self.check_silence_acc(tgt_lang)
            return
        elif speech_probs[-1] <= SPEECH_PROB_THRESHOLD:
            self.debug_log("=== start of silence chunk")
            # beginning = speech, end = silence
            # pass to process_speech and accumulate silence
            self.speech_acc_ms += chunk_size_ms
            self.process_speech(segment, tgt_lang)
            # accumulate contiguous silence
            for i in range(len(speech_probs) - 1, -1, -1):
                if speech_probs[i] > SPEECH_PROB_THRESHOLD:
                    break
                self.silence_acc_ms += window_size_ms
            self.check_silence_acc(tgt_lang)
        elif speech_probs[0] <= SPEECH_PROB_THRESHOLD:
            self.debug_log("=== start of speech chunk")
            # beginning = silence, end = speech
            # accumulate silence , pass next to process_speech
            for i in range(0, len(speech_probs)):
                if speech_probs[i] > SPEECH_PROB_THRESHOLD:
                    break
                self.silence_acc_ms += window_size_ms
            self.check_silence_acc(tgt_lang)
            self.speech_acc_ms += chunk_size_ms
            self.process_speech(segment, tgt_lang)
        else:
            # Speech throughout (possibly with interior silence windows).
            self.speech_acc_ms += chunk_size_ms
            self.debug_log("======== got speech chunk")
            self.process_speech(segment, tgt_lang)

    def debug_write_wav(self, chunk: np.ndarray[Any, Any]) -> None:
        # Append raw samples to the debug wav handle, if one was attached.
        if self.test_input_segments_wav is not None:
            self.test_input_segments_wav.seek(0, SEEK_END)
            self.test_input_segments_wav.write(chunk)
+
+
class SileroVADAgent(SpeechToSpeechAgent):
    """Pipeline front-end that uses Silero VAD to segment a continuous audio
    stream into utterances.

    Incoming audio is scored and chunked by :class:`SileroVADStates`; this
    agent's policy drains the state's queue and forwards the buffered audio
    as a SpeechSegment, marked ``finished`` when the VAD detected an
    end-of-utterance.
    """

    def __init__(self, args: Namespace) -> None:
        super().__init__(args)
        self.chunk_size_samples = args.chunk_size_samples
        self.args = args

    @staticmethod
    def add_args(parser: ArgumentParser) -> None:
        """Register VAD segmentation options on *parser*."""
        parser.add_argument(
            "--window-size-samples",
            default=512,  # sampling_rate // 1000 * 32 => 32 ms at 16000 sample rate
            type=int,
            help="Window size for passing samples to VAD",
        )
        parser.add_argument(
            "--chunk-size-samples",
            default=5120,  # sampling_rate // 1000 * 320 => 320 ms at 16000 sample rate
            type=int,
            help="Chunk size for passing samples to model",
        )
        parser.add_argument(
            "--silence-limit-ms",
            default=700,
            type=int,
            help="send EOS to the input_queue after this amount of silence",
        )
        parser.add_argument(
            "--speech-soft-limit-ms",
            default=12_000,  # after 12s of speech, halve the silence limit
            type=int,
            help="after this amount of speech, decrease the speech threshold (segment more aggressively)",
        )
        parser.add_argument(
            "--debug",
            default=False,
            # BUGFIX: type=bool is an argparse trap — bool("False") is True,
            # so any explicit value would have enabled debug. Parse the string
            # instead; "--debug True" and "--debug False" both behave as expected.
            type=lambda s: str(s).lower() in ("true", "1", "yes"),
            help="Enable debug logs",
        )

    def build_states(self) -> SileroVADStates:
        """Create a fresh per-stream VAD state object."""
        return SileroVADStates(self.args)

    def policy(self, states: SileroVADStates) -> Action:
        """Drain ``states.input_queue`` and emit the buffered audio.

        Returns a ReadAction when nothing is ready yet, otherwise a
        WriteAction carrying a SpeechSegment (or, as a safeguard, a final
        EmptySegment when the source finished with an empty queue).
        """
        states.debug_log(
            f"queue size: {states.input_queue.qsize()}, input_chunk size: {len(states.input_chunk)}"
        )
        content: np.ndarray[Any, Any] = np.empty(0, dtype=np.int16)
        is_finished = states.source_finished
        tgt_lang = None
        while not states.input_queue.empty():
            chunk = states.input_queue.get_nowait()
            states.input_queue.task_done()
            # The first queued chunk's language tags the whole segment.
            if tgt_lang is None:
                tgt_lang = chunk.tgt_lang
            content = np.concatenate((content, chunk.content))

        states.debug_write_wav(content)
        if is_finished:
            # Append 1s of silence to the debug wav so utterance boundaries
            # are audible when listening back.
            states.debug_write_wav(np.zeros(16000))

        if len(content) == 0:  # empty queue
            if not states.source_finished:
                return ReadAction()
            else:
                # NOTE: this should never happen, this logic is a safeguard
                segment = EmptySegment(finished=True)
        else:
            segment = SpeechSegment(
                content=content.tolist(),
                finished=is_finished,
                tgt_lang=tgt_lang,
            )

        return WriteAction(segment, finished=is_finished)

    @classmethod
    def from_args(cls, args: Namespace, **kwargs: None) -> SileroVADAgent:
        """Alternate constructor used by the simuleval pipeline builder."""
        return cls(args)