Browse Source

Bump simuleval version, override simuleval update_target to save memory (#133)

* port recent changes from fairseq1 for streaming demo

* lint + annotations

* fix eval with tgt-lang only

* bump simuleval version, override simuleval update_target behavior to save memory

* version comment
Anna Sun 1 year ago
parent
commit
d877073d7c

+ 1 - 1
setup.py

@@ -25,7 +25,7 @@ setup(
         "fairseq2==0.2.*",
         "librosa",
         "openai-whisper",
-        "simuleval",
+        "simuleval~=1.1.1",
         "soundfile",
         "torchaudio",
         "tqdm",

+ 34 - 0
src/seamless_communication/streaming/agents/common.py

@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Mixins and common helpers for fairseq2 simuleval agents
+"""
+
+from simuleval.data.segments import Segment
+from simuleval.agents.states import AgentStates as AgentStatesOrig
+
+
+class EarlyStoppingMixin:
+    def reset_early(self) -> None:
+        """
+        Override to define how the agent resets when a reset happens
+        before EOS; this base implementation raises NotImplementedError.
+        """
+        raise NotImplementedError()
+
+
+class AgentStates(AgentStatesOrig):
+    def update_target(self, segment: Segment):
+        """Record only segment.finished in target_finished; states.target is not updated"""
+        self.target_finished = segment.finished
+
+
+class NoUpdateTargetMixin:
+    """Mixin: build_states() returns the AgentStates above. List it before bases that define build_states."""
+
+    def build_states(self) -> AgentStates:
+        return AgentStates()

+ 5 - 2
src/seamless_communication/streaming/agents/detokenizer.py

@@ -10,10 +10,13 @@ from typing import Any, Dict
 
 from simuleval.agents import TextToTextAgent
 from simuleval.agents.actions import Action, ReadAction, WriteAction
-from simuleval.agents.states import AgentStates
+from seamless_communication.streaming.agents.common import (
+    AgentStates,
+    NoUpdateTargetMixin,
+)
 
 
-class DetokenizerAgent(TextToTextAgent):
+class DetokenizerAgent(TextToTextAgent, NoUpdateTargetMixin):
     def __init__(self, args: Namespace):
         super().__init__(args)
         self.detokenize_only = args.detokenize_only

+ 0 - 18
src/seamless_communication/streaming/agents/mixins.py

@@ -1,18 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Mixins for fairseq2 simuleval agents
-"""
-
-
-class EarlyStoppingMixin:
-    def reset_early(self) -> None:
-        """
-        Implement to override for different behavior on a reset that
-        happens before EOS
-        """
-        raise NotImplementedError()

+ 6 - 2
src/seamless_communication/streaming/agents/offline_w2v_bert_encoder.py

@@ -15,12 +15,16 @@ from fairseq2.data.text import TextTokenizer
 from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig
 from fairseq2.nn.padding import get_seqs_and_padding_mask
 from seamless_communication.models.unity.model import UnitYModel
-from simuleval.agents import AgentStates, SpeechToSpeechAgent
+from simuleval.agents import SpeechToSpeechAgent
 from simuleval.agents.actions import Action, ReadAction, WriteAction
 from simuleval.data.segments import SpeechSegment
+from seamless_communication.streaming.agents.common import (
+    AgentStates,
+    NoUpdateTargetMixin,
+)
 
 
-class OfflineWav2VecBertEncoderAgent(SpeechToSpeechAgent):
+class OfflineWav2VecBertEncoderAgent(SpeechToSpeechAgent, NoUpdateTargetMixin):
     """
     Incremental encoding of a wav2vec encoder output.
     It updates the whole encoder states every time there is a new incoming segment.

+ 1 - 1
src/seamless_communication/streaming/agents/online_feature_extractor.py

@@ -16,8 +16,8 @@ from fairseq2.data.audio import WaveformToFbankConverter, WaveformToFbankInput
 
 from simuleval.agents import SpeechToSpeechAgent
 from simuleval.agents.actions import Action, ReadAction, WriteAction
-from simuleval.agents.states import AgentStates
 from simuleval.data.segments import Segment, SpeechSegment
+from seamless_communication.streaming.agents.common import AgentStates
 
 
 SHIFT_SIZE = 10

+ 1 - 1
src/seamless_communication/streaming/agents/online_text_decoder.py

@@ -17,7 +17,7 @@ from seamless_communication.models.monotonic_decoder import (
 )
 from simuleval.agents import GenericAgent
 from simuleval.agents.actions import Action, ReadAction, WriteAction
-from simuleval.agents.states import AgentStates
+from seamless_communication.streaming.agents.common import AgentStates
 from simuleval.data.segments import Segment, TextSegment
 from torch import Tensor
 

+ 5 - 2
src/seamless_communication/streaming/agents/silero_vad.py

@@ -17,8 +17,11 @@ from typing import Any, List, Optional, Union
 import numpy as np
 import torch
 import soundfile
-from seamless_communication.streaming.agents.mixins import EarlyStoppingMixin
-from simuleval.agents import AgentStates, SpeechToSpeechAgent
+from seamless_communication.streaming.agents.common import (
+    AgentStates,
+    EarlyStoppingMixin,
+)
+from simuleval.agents import SpeechToSpeechAgent
 from simuleval.agents.actions import Action, ReadAction, WriteAction
 from simuleval.data.segments import EmptySegment, Segment, SpeechSegment
 

+ 5 - 2
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -22,8 +22,11 @@ from seamless_communication.models.unity import (
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
 )
-from seamless_communication.streaming.agents.mixins import EarlyStoppingMixin
-from simuleval.agents import AgentPipeline, AgentStates
+from seamless_communication.streaming.agents.common import (
+    AgentStates,
+    EarlyStoppingMixin,
+)
+from simuleval.agents import AgentPipeline
 from simuleval.agents.agent import GenericAgent
 from simuleval.data.segments import Segment