فهرست منبع

Fix text, unit generation opts to not use mutable default arguments. (#166)

Kaushik Ram Sadagopan 1 سال پیش
والد
کامیت
08710925e0
2 فایل‌های تغییر یافته به همراه 37 افزوده شده و 18 حذف شده
  1. + 26 - 12
      src/seamless_communication/inference/translator.py
  2. + 11 - 6
      src/seamless_communication/toxicity/mintox.py

+ 26 - 12
src/seamless_communication/inference/translator.py

@@ -221,12 +221,8 @@ class Translator(nn.Module):
         task_str: str,
         tgt_lang: str,
         src_lang: Optional[str] = None,
-        text_generation_opts: SequenceGeneratorOptions = SequenceGeneratorOptions(
-            beam_size=5, soft_max_seq_len=(1, 200)
-        ),
-        unit_generation_opts: Optional[
-            SequenceGeneratorOptions
-        ] = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(25, 50)),
+        text_generation_opts: Optional[SequenceGeneratorOptions] = None,
+        unit_generation_opts: Optional[SequenceGeneratorOptions] = None,
         spkr: Optional[int] = -1,
         sample_rate: int = 16000,
         unit_generation_ngram_filtering: bool = False,
@@ -313,6 +309,15 @@ class Translator(nn.Module):
 
         seqs, padding_mask = get_seqs_and_padding_mask(src)
 
+        if text_generation_opts is None:
+            text_generation_opts = SequenceGeneratorOptions(
+                beam_size=5, soft_max_seq_len=(1, 200)
+            )
+        if unit_generation_opts is None:
+            unit_generation_opts = SequenceGeneratorOptions(
+                beam_size=5, soft_max_seq_len=(25, 50)
+            )
+
         text_output, unit_output = self.get_prediction(
             self.model,
             self.text_tokenizer,
@@ -345,9 +350,9 @@ class Translator(nn.Module):
                         sample_rate=sample_rate,
                         unit_generation_ngram_filtering=unit_generation_ngram_filtering,
                     )
-                    src_texts = [asr_text]
+                    src_texts = [str(asr_text)]
             else:
-                src_texts = [input]
+                src_texts = [str(input)]
 
             text_output, unit_output = mintox_pipeline(
                 model=self.model,
@@ -389,18 +394,27 @@ class Translator(nn.Module):
             audio_wavs = []
             speech_units = []
             for i in range(len(units)):
-                padding_mask = (
+                assert self.model.t2u_model is not None
+                unit_padding_mask = (
                     units[i] != self.model.t2u_model.target_vocab_info.pad_idx
                 )
-                u = units[i][padding_mask]
+                u = units[i][unit_padding_mask]
                 speech_units.append(u.tolist())
-            
+
             if self.vocoder is not None:
                 translated_audio_wav = self.vocoder(
                     units, tgt_lang, spkr, dur_prediction=duration_prediction
                 )
                 for i in range(len(units)):
-                    padding_removed_audio_wav = translated_audio_wav[i, :, :int(translated_audio_wav.size(-1)*len(speech_units[i])/len(units[i]))].unsqueeze(0)
+                    padding_removed_audio_wav = translated_audio_wav[
+                        i,
+                        :,
+                        : int(
+                            translated_audio_wav.size(-1)
+                            * len(speech_units[i])
+                            / len(units[i])
+                        ),
+                    ].unsqueeze(0)
                     audio_wavs.append(padding_removed_audio_wav)
             return (
                 text_output.sentences,

+ 11 - 6
src/seamless_communication/toxicity/mintox.py

@@ -143,12 +143,8 @@ def mintox_pipeline(
     original_text_out: SequenceToTextOutput,
     original_unit_out: Optional[SequenceToUnitOutput] = None,
     unit_generation_ngram_filtering: bool = False,
-    text_generation_opts: SequenceGeneratorOptions = SequenceGeneratorOptions(
-        beam_size=5, soft_max_seq_len=(1, 200)
-    ),
-    unit_generation_opts: Optional[SequenceGeneratorOptions] = SequenceGeneratorOptions(
-        beam_size=5, soft_max_seq_len=(25, 50)
-    ),
+    text_generation_opts: Optional[SequenceGeneratorOptions] = None,
+    unit_generation_opts: Optional[SequenceGeneratorOptions] = None,
     bad_word_checker: ETOXBadWordChecker = None,
     duration_factor: float = 1.0,
     prosody_encoder_input: Optional[SequenceData] = None,
@@ -156,6 +152,15 @@ def mintox_pipeline(
     """MinTox: Mitigation at INference time of added TOXicity."""
     from seamless_communication.inference.translator import Modality, Translator
 
+    if text_generation_opts is None:
+        text_generation_opts = SequenceGeneratorOptions(
+            beam_size=5, soft_max_seq_len=(1, 200)
+        )
+    if unit_generation_opts is None:
+        unit_generation_opts = SequenceGeneratorOptions(
+            beam_size=5, soft_max_seq_len=(25, 50)
+        )
+
     def _get_banned_sequence_processor(
         banned_sequences: List[str],
     ) -> BannedSequenceProcessor: