|
@@ -6,7 +6,9 @@
|
|
from __future__ import annotations
|
|
from __future__ import annotations
|
|
|
|
|
|
import logging
|
|
import logging
|
|
|
|
+from pathlib import Path
|
|
import queue
|
|
import queue
|
|
|
|
+import random
|
|
import time
|
|
import time
|
|
from argparse import ArgumentParser, Namespace
|
|
from argparse import ArgumentParser, Namespace
|
|
from os import SEEK_END
|
|
from os import SEEK_END
|
|
@@ -14,6 +16,7 @@ from typing import Any, List, Optional, Union
|
|
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
import torch
|
|
import torch
|
|
|
|
+import soundfile
|
|
from seamless_communication.streaming.agents.mixins import EarlyStoppingMixin
|
|
from seamless_communication.streaming.agents.mixins import EarlyStoppingMixin
|
|
from simuleval.agents import AgentStates, SpeechToSpeechAgent
|
|
from simuleval.agents import AgentStates, SpeechToSpeechAgent
|
|
from simuleval.agents.actions import Action, ReadAction, WriteAction
|
|
from simuleval.agents.actions import Action, ReadAction, WriteAction
|
|
@@ -78,6 +81,7 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
|
|
self.is_fresh_state = True
|
|
self.is_fresh_state = True
|
|
self.clear_queues()
|
|
self.clear_queues()
|
|
self.model.reset_states()
|
|
self.model.reset_states()
|
|
|
|
+ self.consecutive_silence_decay_count = 0
|
|
|
|
|
|
def reset_early(self) -> None:
|
|
def reset_early(self) -> None:
|
|
"""
|
|
"""
|
|
@@ -90,6 +94,7 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
|
|
) -> List[Any]:
|
|
) -> List[Any]:
|
|
t = torch.from_numpy(segment)
|
|
t = torch.from_numpy(segment)
|
|
speech_probs = []
|
|
speech_probs = []
|
|
|
|
+ # TODO: run self.model in batch?
|
|
for i in range(0, len(t), self.window_size_samples):
|
|
for i in range(0, len(t), self.window_size_samples):
|
|
chunk = t[i : i + self.window_size_samples]
|
|
chunk = t[i : i + self.window_size_samples]
|
|
if len(chunk) < self.window_size_samples:
|
|
if len(chunk) < self.window_size_samples:
|
|
@@ -116,11 +121,6 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
|
|
self.debug_log("use next_input_queue")
|
|
self.debug_log("use next_input_queue")
|
|
queue = self.next_input_queue
|
|
queue = self.next_input_queue
|
|
|
|
|
|
- # NOTE: we don't reset silence_acc_ms here so that once an utterance
|
|
|
|
- # becomes longer (accumulating more silence), it has a higher chance
|
|
|
|
- # of being segmented.
|
|
|
|
- self.silence_acc_ms = self.silence_acc_ms // 2
|
|
|
|
-
|
|
|
|
if self.first_input_ts is None:
|
|
if self.first_input_ts is None:
|
|
self.first_input_ts = time.time() * 1000
|
|
self.first_input_ts = time.time() * 1000
|
|
|
|
|
|
@@ -159,6 +159,12 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
|
|
self.input_chunk = np.empty(0, dtype=np.int16)
|
|
self.input_chunk = np.empty(0, dtype=np.int16)
|
|
self.input_queue.put_nowait(EmptySegment(finished=True))
|
|
self.input_queue.put_nowait(EmptySegment(finished=True))
|
|
self.source_finished = True
|
|
self.source_finished = True
|
|
|
|
+ self.debug_write_wav(np.empty(0, dtype=np.int16), finished=True)
|
|
|
|
+
|
|
|
|
+ def decay_silence_acc_ms(self):
|
|
|
|
+ if self.consecutive_silence_decay_count <= 2:
|
|
|
|
+ self.silence_acc_ms = self.silence_acc_ms // 2
|
|
|
|
+ self.consecutive_silence_decay_count += 1
|
|
|
|
|
|
def update_source(
|
|
def update_source(
|
|
self, segment: Union[np.ndarray[Any, np.dtype[np.float32]], Segment]
|
|
self, segment: Union[np.ndarray[Any, np.dtype[np.float32]], Segment]
|
|
@@ -180,6 +186,7 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
|
|
speech_probs = self.get_speech_prob_from_np_float32(segment)
|
|
speech_probs = self.get_speech_prob_from_np_float32(segment)
|
|
chunk_size_ms = len(segment) * 1000 / self.sample_rate
|
|
chunk_size_ms = len(segment) * 1000 / self.sample_rate
|
|
window_size_ms = self.window_size_samples * 1000 / self.sample_rate
|
|
window_size_ms = self.window_size_samples * 1000 / self.sample_rate
|
|
|
|
+ consecutive_silence_decay = False
|
|
if all(i <= SPEECH_PROB_THRESHOLD for i in speech_probs):
|
|
if all(i <= SPEECH_PROB_THRESHOLD for i in speech_probs):
|
|
if self.source_finished:
|
|
if self.source_finished:
|
|
return
|
|
return
|
|
@@ -193,6 +200,8 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
|
|
# beginning = speech, end = silence
|
|
# beginning = speech, end = silence
|
|
# pass to process_speech and accumulate silence
|
|
# pass to process_speech and accumulate silence
|
|
self.speech_acc_ms += chunk_size_ms
|
|
self.speech_acc_ms += chunk_size_ms
|
|
|
|
+ consecutive_silence_decay = True
|
|
|
|
+ self.decay_silence_acc_ms()
|
|
self.process_speech(segment, tgt_lang)
|
|
self.process_speech(segment, tgt_lang)
|
|
# accumulate contiguous silence
|
|
# accumulate contiguous silence
|
|
for i in range(len(speech_probs) - 1, -1, -1):
|
|
for i in range(len(speech_probs) - 1, -1, -1):
|
|
@@ -208,18 +217,37 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
|
|
if speech_probs[i] > SPEECH_PROB_THRESHOLD:
|
|
if speech_probs[i] > SPEECH_PROB_THRESHOLD:
|
|
break
|
|
break
|
|
self.silence_acc_ms += window_size_ms
|
|
self.silence_acc_ms += window_size_ms
|
|
|
|
+ # try not to split right before speech
|
|
|
|
+ self.silence_acc_ms = self.silence_acc_ms // 2
|
|
self.check_silence_acc(tgt_lang)
|
|
self.check_silence_acc(tgt_lang)
|
|
self.speech_acc_ms += chunk_size_ms
|
|
self.speech_acc_ms += chunk_size_ms
|
|
self.process_speech(segment, tgt_lang)
|
|
self.process_speech(segment, tgt_lang)
|
|
else:
|
|
else:
|
|
self.speech_acc_ms += chunk_size_ms
|
|
self.speech_acc_ms += chunk_size_ms
|
|
self.debug_log("======== got speech chunk")
|
|
self.debug_log("======== got speech chunk")
|
|
|
|
+ consecutive_silence_decay = True
|
|
|
|
+ self.decay_silence_acc_ms()
|
|
self.process_speech(segment, tgt_lang)
|
|
self.process_speech(segment, tgt_lang)
|
|
|
|
+ if not consecutive_silence_decay:
|
|
|
|
+ self.consecutive_silence_decay_count = 0
|
|
|
|
|
|
- def debug_write_wav(self, chunk: np.ndarray[Any, Any]) -> None:
|
|
|
|
|
|
+ def debug_write_wav(
|
|
|
|
+ self, chunk: np.ndarray[Any, Any], finished: bool = False
|
|
|
|
+ ) -> None:
|
|
if self.test_input_segments_wav is not None:
|
|
if self.test_input_segments_wav is not None:
|
|
self.test_input_segments_wav.seek(0, SEEK_END)
|
|
self.test_input_segments_wav.seek(0, SEEK_END)
|
|
self.test_input_segments_wav.write(chunk)
|
|
self.test_input_segments_wav.write(chunk)
|
|
|
|
+ if finished:
|
|
|
|
+ MODEL_SAMPLE_RATE = 16_000
|
|
|
|
+ debug_ts = f"{time.time()}_{random.randint(1000, 9999)}"
|
|
|
|
+ self.test_input_segments_wav = soundfile.SoundFile(
|
|
|
|
+ Path(self.test_input_segments_wav.name).parent
|
|
|
|
+ / f"{debug_ts}_test_input_segments.wav",
|
|
|
|
+ mode="w+",
|
|
|
|
+ format="WAV",
|
|
|
|
+ samplerate=MODEL_SAMPLE_RATE,
|
|
|
|
+ channels=1,
|
|
|
|
+ )
|
|
|
|
|
|
|
|
|
|
class SileroVADAgent(SpeechToSpeechAgent):
|
|
class SileroVADAgent(SpeechToSpeechAgent):
|
|
@@ -279,8 +307,6 @@ class SileroVADAgent(SpeechToSpeechAgent):
|
|
content = np.concatenate((content, chunk.content))
|
|
content = np.concatenate((content, chunk.content))
|
|
|
|
|
|
states.debug_write_wav(content)
|
|
states.debug_write_wav(content)
|
|
- if is_finished:
|
|
|
|
- states.debug_write_wav(np.zeros(16000))
|
|
|
|
|
|
|
|
if len(content) == 0: # empty queue
|
|
if len(content) == 0: # empty queue
|
|
if not states.source_finished:
|
|
if not states.source_finished:
|