Browse Source

[streaming] add s2s + s2t expressive demo (#153)

* add s2s + s2t expressive

* bump simuleval version
Anna Sun 1 year ago
parent
commit
14b013315e

+ 1 - 1
setup.py

@@ -25,7 +25,7 @@ setup(
         "fairseq2==0.2.*",
         "librosa",
         "openai-whisper",
-        "simuleval~=1.1.1",
+        "simuleval~=1.1.2",
         "soundfile",
         "torchaudio",
         "tqdm",

+ 12 - 0
src/seamless_communication/streaming/agents/mma_m4t_s2st.py

@@ -69,3 +69,15 @@ class MonotonicM4TS2STJointVADAgent(UnitYAgentTreePipeline):
         NARUnitYUnitDecoderAgent: [VocoderAgent],
         VocoderAgent: [],
     }
+
+
+class SeamlessS2STJointVADAgent(UnitYAgentTreePipeline):
+    pipeline = {
+        SileroVADAgent: [OnlineFeatureExtractorAgent],
+        OnlineFeatureExtractorAgent: [OfflineWav2VecBertEncoderAgent],
+        OfflineWav2VecBertEncoderAgent: [UnitYMMATextDecoderAgent],
+        UnitYMMATextDecoderAgent: [UnitYDetokenizerAgent, NARUnitYUnitDecoderAgent],
+        UnitYDetokenizerAgent: [],
+        NARUnitYUnitDecoderAgent: [PretsselVocoderAgent],
+        PretsselVocoderAgent: [],
+    }

+ 10 - 3
src/seamless_communication/streaming/agents/pretssel_vocoder.py

@@ -24,7 +24,8 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
         super().__init__(args)
         self.vocoder = vocoder
         self.upstream_idx = args.upstream_idx
-        self.sample_rate = args.sample_rate
+        self.sample_rate = args.sample_rate  # input sample rate
+        self.vocoder_sample_rate = args.vocoder_sample_rate  # output sample rate
         self.tgt_lang = args.tgt_lang
         self.convert_to_fbank = WaveformToFbankConverter(
             num_mel_bins=80,
@@ -72,7 +73,7 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
 
         audio_dict = {
             "waveform": torch.tensor(source, dtype=torch.float32, device=self.device).unsqueeze(1),
-            "sample_rate": 16000, # input audio is fixed to 16kHZ
+            "sample_rate": self.sample_rate,
             "format": -1,
         }
 
@@ -96,7 +97,7 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
             SpeechSegment(
                 content=wav[0][0].tolist(),
                 finished=states.source_finished,
-                sample_rate=self.sample_rate,
+                sample_rate=self.vocoder_sample_rate,
                 tgt_lang=tgt_lang,
             ),
             finished=states.source_finished,
@@ -110,6 +111,12 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
             default=0,
             help="index of encoder states where states.source contains input audio",
         )
+        parser.add_argument(
+            "--vocoder-sample-rate",
+            type=int,
+            default=16000,
+            help="sample rate out of the vocoder"
+        )
 
     @classmethod
     def from_args(cls, args: Namespace, **kwargs: Dict[str, Any]) -> PretsselVocoderAgent: