Ver Fonte

[streaming][demo] Increase speech threshold for start of speech (#187)

* Increase speech threshold for start of speech

* Add --init-speech-prob flag to configure the start-of-speech threshold increase
Anna Sun há 1 ano atrás
pai
commit
e1fce00f10
1 ficheiro alterado com 16 adições e 5 exclusões
  1. 16 5
      src/seamless_communication/streaming/agents/silero_vad.py

+ 16 - 5
src/seamless_communication/streaming/agents/silero_vad.py

@@ -55,6 +55,7 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):  # type: ignore
         self.window_size_samples = args.window_size_samples
         self.chunk_size_samples = args.chunk_size_samples
         self.sample_rate = args.sample_rate
+        self.init_speech_prob = args.init_speech_prob
         self.debug = args.debug
         self.test_input_segments_wav = None
         self.debug_log(args)
@@ -190,7 +191,11 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):  # type: ignore
         chunk_size_ms = len(segment) * 1000 / self.sample_rate
         window_size_ms = self.window_size_samples * 1000 / self.sample_rate
         consecutive_silence_decay = False
-        if all(i <= SPEECH_PROB_THRESHOLD for i in speech_probs):
+        if self.is_fresh_state and self.init_speech_prob > 0:
+            threshold = SPEECH_PROB_THRESHOLD + self.init_speech_prob
+        else:
+            threshold = SPEECH_PROB_THRESHOLD
+        if all(i <= threshold for i in speech_probs):
             if self.source_finished:
                 return
             self.debug_log("got silent chunk")
@@ -198,7 +203,7 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):  # type: ignore
                 self.silence_acc_ms += chunk_size_ms
                 self.check_silence_acc(tgt_lang)
             return
-        elif speech_probs[-1] <= SPEECH_PROB_THRESHOLD:
+        elif speech_probs[-1] <= threshold:
             self.debug_log("=== start of silence chunk")
             # beginning = speech, end = silence
             # pass to process_speech and accumulate silence
@@ -208,16 +213,16 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):  # type: ignore
             self.process_speech(segment, tgt_lang)
             # accumulate contiguous silence
             for i in range(len(speech_probs) - 1, -1, -1):
-                if speech_probs[i] > SPEECH_PROB_THRESHOLD:
+                if speech_probs[i] > threshold:
                     break
                 self.silence_acc_ms += window_size_ms
             self.check_silence_acc(tgt_lang)
-        elif speech_probs[0] <= SPEECH_PROB_THRESHOLD:
+        elif speech_probs[0] <= threshold:
             self.debug_log("=== start of speech chunk")
             # beginning = silence, end = speech
             # accumulate silence , pass next to process_speech
             for i in range(0, len(speech_probs)):
-                if speech_probs[i] > SPEECH_PROB_THRESHOLD:
+                if speech_probs[i] > threshold:
                     break
                 self.silence_acc_ms += window_size_ms
             # try not to split right before speech
@@ -285,6 +290,12 @@ class SileroVADAgent(SpeechToSpeechAgent):  # type: ignore
             type=int,
             help="after this amount of speech, decrease the speech threshold (segment more aggressively)",
         )
+        parser.add_argument(
+            "--init-speech-prob",
+            default=0.15,
+            type=float,
+            help="Increase the initial speech probability threshold by this much at the start of speech",
+        )
         parser.add_argument(
             "--debug",
             default=False,