Browse Source

[streaming] Port changes for streaming demo (#130)

* port recent changes from fairseq1 for streaming demo

* lint + annotations

* fix eval with tgt-lang only
Anna Sun 1 năm trước cách đây
mục cha
commit
2ccf28ad24

+ 5 - 1
src/seamless_communication/streaming/agents/offline_w2v_bert_encoder.py

@@ -62,10 +62,14 @@ class OfflineWav2VecBertEncoderAgent(SpeechToSpeechAgent):
         The policy for encoder is always write
         only if the input is too short
         """
-        if len(states.source) < self.min_input_length or (
+        if (
             self.min_starting_wait is not None
             and len(states.source) < self.min_starting_wait
+            and not states.source_finished
         ):
+            return ReadAction()
+
+        if len(states.source) < self.min_input_length:
             if states.source_finished:
                 return WriteAction({}, finished=states.source_finished)
             else:

+ 64 - 21
src/seamless_communication/streaming/agents/online_text_decoder.py

@@ -6,7 +6,7 @@
 from __future__ import annotations
 
 from argparse import ArgumentParser, Namespace
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Set, Tuple
 
 import torch
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
@@ -27,6 +27,7 @@ class DecoderAgentStates(AgentStates):
         self.source_len = 0
         self.target_indices: List[int] = []
         self.tgt_lang = None
+        self.ngram_block_count = 0
         super().reset()
 
     def update_source(self, segment: Segment) -> None:
@@ -76,9 +77,11 @@ class OnlineTextDecoderAgent(GenericAgent):
         self.device = args.device
         self.dtype = args.dtype
         self.eos_idx = text_tokenizer.vocab_info.eos_idx
-        if hasattr(args, "tgt_lang") and hasattr(args, "prefix_tgt_lang"):
+        if getattr(args, "tgt_lang", None) and getattr(args, "prefix_tgt_lang", None):
             assert args.tgt_lang == args.prefix_tgt_lang
-        tgt_lang = getattr(args, "tgt_lang", None) or getattr(args, "prefix_tgt_lang", None)
+        tgt_lang = getattr(args, "tgt_lang", None) or getattr(
+            args, "prefix_tgt_lang", None
+        )
         token_encoder = text_tokenizer.create_encoder(lang=tgt_lang, mode="target")
         prefix_indices = token_encoder.prefix_indices
         assert prefix_indices is not None
@@ -116,12 +119,6 @@ class OnlineTextDecoderAgent(GenericAgent):
             default=1,
             help="Minimal starting waiting source steps",
         )
-        parser.add_argument(
-            "--min-starting-wait-reset",
-            type=int,
-            default=0,
-            help="Minimal starting waiting source steps",
-        )
         parser.add_argument(
             "--no-early-stop",
             action="store_true",
@@ -157,6 +154,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
 
         self.decision_threshold = args.decision_threshold
         self.decision_method = args.decision_method
+        self.block_ngrams = args.block_ngrams
         self.p_choose_start_layer = args.p_choose_start_layer
 
     @staticmethod
@@ -181,6 +179,10 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
             default=0,
             help="Encoder layer from which p_choose should be considered for selection.",
         )
+        parser.add_argument(
+            "--block-ngrams",
+            action="store_true",
+        )
 
     @classmethod
     def from_args(
@@ -224,6 +226,10 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
         )
 
         logits = self.model.project(decoder_output)
+        if self.block_ngrams and states.source_finished:
+            all_indices = states.target_indices + pred_indices
+            blocked_indices = all_indices[-4:]
+            logits[:, :, blocked_indices] = float("-inf")
 
         index = int(logits[0, -1].argmax().item())
         _, tgt_len, src_len = p_choose.size()
@@ -250,6 +256,46 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
             tgt_lang=states.tgt_lang,
         )
 
+    def get_blocked_ngrams(self, target_indices: List[int]):
+        # TODO: make it configurable and use itertools
+        if not self.block_ngrams:
+            return None
+        blocked_ngrams = set()
+        if len(target_indices) >= 4:
+            blocked_ngrams.add(str(target_indices[-4:]))
+            blocked_ngrams.add(str(target_indices[-4:-2]))
+            blocked_ngrams.add(str(target_indices[-4:-1]))
+        if len(target_indices) >= 3:
+            blocked_ngrams.add(str(target_indices[-3:]))
+            blocked_ngrams.add(str(target_indices[-3:-1]))
+        if len(target_indices) >= 2:
+            blocked_ngrams.add(str(target_indices[-2:]))
+        return blocked_ngrams
+
+    def maybe_block_ngrams(
+        self,
+        states: DecoderAgentStates,
+        pred_indices: List[int],
+        blocked_ngrams: Set[int],
+        index: int,
+    ):
+        """
+        This check is used to force a READ decision when n-gram repeat
+        happens before source_finished
+        """
+        if not self.block_ngrams or states.source_finished:
+            return False
+        all_indices = states.target_indices + pred_indices + [index]
+        for n in [3, 2]:  # TODO: make it configurable
+            if len(all_indices) >= n and states.ngram_block_count <= 4:
+                if str(all_indices[-n:]) in blocked_ngrams:
+                    states.ngram_block_count += 1
+                    pred_indices[:] = pred_indices[: -(n - 1)]
+                    # decoder_features_out = decoder_features_out[:, : -(n - 1)]
+                    return True
+                blocked_ngrams.add(str(all_indices[-n:]))
+        return False
+
     @torch.inference_mode()
     def policy(self, states: DecoderAgentStates) -> Action:
         if len(states.source) == 0:
@@ -272,6 +318,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
         index = None
         prob = None
         finished = False
+        blocked_ngrams = self.get_blocked_ngrams(states.target_indices)
 
         while (
             len(states.target_indices + pred_indices) < self.max_len(states)
@@ -281,18 +328,17 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
 
             if (
                 self.no_early_stop
-                and prob < self.decision_threshold
-                and not states.source_finished
-            ):
-                break
-            if (
-                self.no_early_stop
-                and index == self.eos_idx
                 and not states.source_finished
+                and (prob < self.decision_threshold or index == self.eos_idx)
             ):
                 if prob == 1.0:
                     pred_indices = []
                 break
+            block_ngram = self.maybe_block_ngrams(
+                states, pred_indices, blocked_ngrams, index
+            )  # TODO: add back decoder_features_out processing for unity2
+            if block_ngram:
+                break
             if (
                 finished
                 or index == self.eos_idx
@@ -301,11 +347,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
                 finished = True
                 break
 
-            if (
-                not self.no_early_stop
-                and prob < self.decision_threshold
-                and not states.source_finished
-            ):
+            if prob < self.decision_threshold and not states.source_finished:
                 break
 
             pred_indices.append(index)
@@ -322,6 +364,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
             finished = finished or len(
                 states.target_indices + pred_indices
             ) > self.max_len(states)
+            states.ngram_block_count = 0
             return WriteAction(
                 self.postprocess(states, torch.tensor(pred_indices), finished),
                 finished=finished,

+ 34 - 8
src/seamless_communication/streaming/agents/silero_vad.py

@@ -6,7 +6,9 @@
 from __future__ import annotations
 
 import logging
+from pathlib import Path
 import queue
+import random
 import time
 from argparse import ArgumentParser, Namespace
 from os import SEEK_END
@@ -14,6 +16,7 @@ from typing import Any, List, Optional, Union
 
 import numpy as np
 import torch
+import soundfile
 from seamless_communication.streaming.agents.mixins import EarlyStoppingMixin
 from simuleval.agents import AgentStates, SpeechToSpeechAgent
 from simuleval.agents.actions import Action, ReadAction, WriteAction
@@ -78,6 +81,7 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
         self.is_fresh_state = True
         self.clear_queues()
         self.model.reset_states()
+        self.consecutive_silence_decay_count = 0
 
     def reset_early(self) -> None:
         """
@@ -90,6 +94,7 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
     ) -> List[Any]:
         t = torch.from_numpy(segment)
         speech_probs = []
+        # TODO: run self.model in batch?
         for i in range(0, len(t), self.window_size_samples):
             chunk = t[i : i + self.window_size_samples]
             if len(chunk) < self.window_size_samples:
@@ -116,11 +121,6 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
             self.debug_log("use next_input_queue")
             queue = self.next_input_queue
 
-        # NOTE: we don't reset silence_acc_ms here so that once an utterance
-        # becomes longer (accumulating more silence), it has a higher chance
-        # of being segmented.
-        self.silence_acc_ms = self.silence_acc_ms // 2
-
         if self.first_input_ts is None:
             self.first_input_ts = time.time() * 1000
 
@@ -159,6 +159,12 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
                 self.input_chunk = np.empty(0, dtype=np.int16)
             self.input_queue.put_nowait(EmptySegment(finished=True))
             self.source_finished = True
+            self.debug_write_wav(np.empty(0, dtype=np.int16), finished=True)
+
+    def decay_silence_acc_ms(self):
+        if self.consecutive_silence_decay_count <= 2:
+            self.silence_acc_ms = self.silence_acc_ms // 2
+            self.consecutive_silence_decay_count += 1
 
     def update_source(
         self, segment: Union[np.ndarray[Any, np.dtype[np.float32]], Segment]
@@ -180,6 +186,7 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
         speech_probs = self.get_speech_prob_from_np_float32(segment)
         chunk_size_ms = len(segment) * 1000 / self.sample_rate
         window_size_ms = self.window_size_samples * 1000 / self.sample_rate
+        consecutive_silence_decay = False
         if all(i <= SPEECH_PROB_THRESHOLD for i in speech_probs):
             if self.source_finished:
                 return
@@ -193,6 +200,8 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
             # beginning = speech, end = silence
             # pass to process_speech and accumulate silence
             self.speech_acc_ms += chunk_size_ms
+            consecutive_silence_decay = True
+            self.decay_silence_acc_ms()
             self.process_speech(segment, tgt_lang)
             # accumulate contiguous silence
             for i in range(len(speech_probs) - 1, -1, -1):
@@ -208,18 +217,37 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
                 if speech_probs[i] > SPEECH_PROB_THRESHOLD:
                     break
                 self.silence_acc_ms += window_size_ms
+            # try not to split right before speech
+            self.silence_acc_ms = self.silence_acc_ms // 2
             self.check_silence_acc(tgt_lang)
             self.speech_acc_ms += chunk_size_ms
             self.process_speech(segment, tgt_lang)
         else:
             self.speech_acc_ms += chunk_size_ms
             self.debug_log("======== got speech chunk")
+            consecutive_silence_decay = True
+            self.decay_silence_acc_ms()
             self.process_speech(segment, tgt_lang)
+        if not consecutive_silence_decay:
+            self.consecutive_silence_decay_count = 0
 
-    def debug_write_wav(self, chunk: np.ndarray[Any, Any]) -> None:
+    def debug_write_wav(
+        self, chunk: np.ndarray[Any, Any], finished: bool = False
+    ) -> None:
         if self.test_input_segments_wav is not None:
             self.test_input_segments_wav.seek(0, SEEK_END)
             self.test_input_segments_wav.write(chunk)
+            if finished:
+                MODEL_SAMPLE_RATE = 16_000
+                debug_ts = f"{time.time()}_{random.randint(1000, 9999)}"
+                self.test_input_segments_wav = soundfile.SoundFile(
+                    Path(self.test_input_segments_wav.name).parent
+                    / f"{debug_ts}_test_input_segments.wav",
+                    mode="w+",
+                    format="WAV",
+                    samplerate=MODEL_SAMPLE_RATE,
+                    channels=1,
+                )
 
 
 class SileroVADAgent(SpeechToSpeechAgent):
@@ -279,8 +307,6 @@ class SileroVADAgent(SpeechToSpeechAgent):
             content = np.concatenate((content, chunk.content))
 
         states.debug_write_wav(content)
-        if is_finished:
-            states.debug_write_wav(np.zeros(16000))
 
         if len(content) == 0:  # empty queue
             if not states.source_finished: