
add max_len args in translator

Maha Elbayad, 2 years ago
Commit
096e7e01fc
1 changed file with 27 additions and 7 deletions

+ 27 - 7
src/seamless_communication/models/inference/translator.py

@@ -102,16 +102,28 @@ class Translator(nn.Module):
         output_modality: Modality,
         tgt_lang: str,
         ngram_filtering: bool = False,
+        text_max_len_a: int = 1,
+        text_max_len_b: int = 200,
+        unit_max_len_a: Optional[int] = None,
+        unit_max_len_b: Optional[int] = None,
     ) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
-        if input_modality == Modality.TEXT:
-            # need to adjust this since src_len is smaller for text.
-            max_len_a = 25
-        else:
-            max_len_a = 1
-        text_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
+
+        if unit_max_len_a is None:
+            # need to adjust this for T2ST since src_len is smaller for text.
+            if input_modality == Modality.TEXT:
+                unit_max_len_a = 25
+            else:
+                unit_max_len_a = 1
+
+        text_opts = SequenceGeneratorOptions(
+            beam_size=5,
+            soft_max_seq_len=(text_max_len_a, text_max_len_b)
+        )
         unit_opts = SequenceGeneratorOptions(
-            beam_size=5, soft_max_seq_len=(max_len_a, 50)
+            beam_size=5,
+            soft_max_seq_len=(unit_max_len_a, unit_max_len_b or 50)
         )
+
         if ngram_filtering:
             text_opts.logits_processor = NGramRepeatBlockProcessor(
                 no_repeat_ngram_size=4
@@ -156,6 +168,10 @@ class Translator(nn.Module):
         spkr: Optional[int] = -1,
         ngram_filtering: bool = False,
         sample_rate: int = 16000,
+        text_max_len_a: int = 1,
+        text_max_len_b: int = 200,
+        unit_max_len_a: Optional[int] = None,
+        unit_max_len_b: Optional[int] = None,
     ) -> Tuple[StringLike, Optional[Tensor], Optional[int]]:
         """
         The main method used to perform inference on all tasks.
@@ -216,6 +232,10 @@ class Translator(nn.Module):
             output_modality,
             tgt_lang=tgt_lang,
             ngram_filtering=ngram_filtering,
+            text_max_len_a=text_max_len_a,
+            text_max_len_b=text_max_len_b,
+            unit_max_len_a=unit_max_len_a,
+            unit_max_len_b=unit_max_len_b,
         )
 
         text_out = result[0]
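
For context, the new keyword arguments are forwarded into fairseq2's SequenceGeneratorOptions.soft_max_seq_len, which is an (a, b) pair that caps the generated length at roughly a * src_len + b. Below is a minimal usage sketch of the updated predict() call; the Translator constructor arguments, task string, and language codes are illustrative assumptions and are not part of this commit:

# A minimal usage sketch, assuming the Translator constructor and predict()
# signature of seamless_communication at the time of this commit. The model
# card, vocoder card, and language codes are illustrative assumptions.
import torch

from seamless_communication.models.inference.translator import Translator

translator = Translator(
    "seamlessM4T_large",      # assumed model card name
    "vocoder_36langs",        # assumed vocoder card name
    device=torch.device("cpu"),
)

# soft_max_seq_len is (a, b): output length is capped at roughly a * src_len + b.
# For T2ST the text source is short, hence the larger unit_max_len_a default.
text_out, wav, sr = translator.predict(
    "Hello, how are you?",
    "t2st",
    tgt_lang="fra",
    src_lang="eng",
    text_max_len_a=1,
    text_max_len_b=200,
    unit_max_len_a=25,
    unit_max_len_b=50,
)

Leaving unit_max_len_a and unit_max_len_b at None preserves the previous behavior: the scale factor falls back to 25 for text input (T2ST) or 1 otherwise, and the offset falls back to 50.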