
Add unk_penalty and len_penalty back to SequenceGeneratorOptions (#176)

* Add unk_penalty and len_penalty back to SequenceGeneratorOptions

* Add text unk blocking option in the predict cli

---------

Co-authored-by: Peng-Jen Chen <pipibjc@devfair0209.h2.fair>
Co-authored-by: Peng-Jen Chen <pipibjc@learnfair2274.h2.fair>
pipibjc committed 1 year ago
Commit dc11a4d74e

src/seamless_communication/cli/m4t/predict/predict.py (+11, -0)

@@ -135,6 +135,15 @@ def add_inference_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         ),
         default=False,
     )
+    parser.add_argument(
+        "--text-unk-blocking",
+        type=bool,
+        help=(
+            "If True, set penalty of UNK to inf in text generator "
+            "to block unk output."
+        ),
+        default=False,
+    )
     return parser
 
 
@@ -149,6 +158,8 @@ def set_generation_opts(
             args.text_generation_max_len_b,
         ),
     )
+    if args.text_unk_blocking:
+        text_generation_opts.unk_penalty = torch.inf
     if args.text_generation_ngram_blocking:
         text_generation_opts.step_processor = NGramRepeatBlockProcessor(
             ngram_size=args.no_repeat_ngram_size
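A note on the flag as committed: `argparse` with `type=bool` applies `bool()` to the raw string, and any non-empty string is truthy, so `--text-unk-blocking False` would still enable blocking. A minimal sketch of the conventional `store_true` pattern (illustrative only, not part of this commit):

```python
import argparse

parser = argparse.ArgumentParser()
# A plain switch: present -> True, absent -> False. This sidesteps the
# type=bool pitfall, where bool("False") evaluates to True because the
# string is non-empty.
parser.add_argument(
    "--text-unk-blocking",
    action="store_true",
    help="Set the UNK penalty to inf in the text generator to block UNK output.",
)

args = parser.parse_args(["--text-unk-blocking"])
assert args.text_unk_blocking is True
```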

src/seamless_communication/inference/generator.py (+14, -0)

@@ -75,6 +75,14 @@ class SequenceGeneratorOptions:
     step_processor: Optional[StepProcessor] = None
     """The processor called at each generation step."""
 
+    unk_penalty: float = 0.0
+    """The UNK symbol penalty, where values less than 0 produce more UNKs;
+    values greater than 0 produce fewer UNKs."""
+
+    len_penalty: float = 1.0
+    """The length penalty, where values less than 1.0 favor shorter
+    sequences; values greater than 1.0 favor longer sequences."""
+
 
 class UnitYGenerator:
     """Generates text translations and speech units from a UnitY model."""
@@ -143,6 +151,8 @@ class UnitYGenerator:
             max_seq_len=text_opts.hard_max_seq_len,
             echo_prompt=True,
             step_processors=step_processors,
+            unk_penalty=text_opts.unk_penalty,
+            len_penalty=text_opts.len_penalty,
         )
         self.s2t_converter = SequenceToTextConverter(
             generator, text_tokenizer, "translation", target_lang
@@ -168,6 +178,8 @@ class UnitYGenerator:
                 max_seq_len=text_opts.hard_max_seq_len,
                 echo_prompt=True,
                 step_processors=step_processors,
+                unk_penalty=text_opts.unk_penalty,
+                len_penalty=text_opts.len_penalty,
             )
             self.t2t_converter = SequenceToTextConverter(
                 generator, text_tokenizer, "translation", target_lang
@@ -208,6 +220,8 @@ class UnitYGenerator:
                     max_seq_len=unit_opts.hard_max_seq_len,
                     echo_prompt=True,
                     step_processors=step_processors,
+                    unk_penalty=unit_opts.unk_penalty,
+                    len_penalty=unit_opts.len_penalty,
                 )
 
     @torch.inference_mode()
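With the fields restored, callers can set the penalties directly when building generation options. A minimal usage sketch, assuming the module path shown in the diff and that `SequenceGeneratorOptions` is a dataclass whose remaining fields keep their defaults:

```python
import torch

from seamless_communication.inference.generator import SequenceGeneratorOptions

# Block UNK tokens outright and mildly favor longer text hypotheses.
text_opts = SequenceGeneratorOptions(
    unk_penalty=torch.inf,
    len_penalty=1.2,
)
```

The CLI takes the simpler route: `--text-unk-blocking` just sets `unk_penalty` to `torch.inf` on the text generation options, which makes UNK unselectable at every decoding step.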