Przeglądaj źródła

handle pretssel vocoder output langs (#261)

Anna Sun 1 rok temu
rodzic
commit
a7749e5b64

+ 15 - 8
src/seamless_communication/streaming/agents/pretssel_vocoder.py

@@ -50,6 +50,7 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):  # type: ign
 
         vocoder_model_card = asset_store.retrieve_card(args.vocoder_name)
         self.vocoder_sample_rate = vocoder_model_card.field("sample_rate").as_(int)
+        self.vocoder_langs = vocoder_model_card.field("model_config").field("langs").as_list(str)
 
         self.upstream_idx = args.upstream_idx
         self.sample_rate = args.sample_rate  # input sample rate
@@ -115,19 +116,25 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):  # type: ign
 
         tgt_lang = states.tgt_lang if states.tgt_lang else self.tgt_lang
 
-        wav = self.vocoder(
-            unit,
-            tgt_lang=tgt_lang,
-            prosody_input_seqs=feats,
-            durations=duration.unsqueeze(0),
-            normalize_before=True,
-        )
+        
+        if tgt_lang not in self.vocoder_langs:
+            logger.warning(f"{tgt_lang} not supported!")
+            content = []
+        else:
+            wav = self.vocoder(
+                unit,
+                tgt_lang=tgt_lang,
+                prosody_input_seqs=feats,
+                durations=duration.unsqueeze(0),
+                normalize_before=True,
+            )
+            content = wav[0][0][0].tolist()
 
         states.source = []
 
         return WriteAction(
             SpeechSegment(
-                content=wav[0][0][0].tolist(),
+                content=content,
                 finished=states.source_finished,
                 sample_rate=self.vocoder_sample_rate,
                 tgt_lang=tgt_lang,