View source

ngram in scripts/predict & filename fix

cndn, 2 years ago
parent
commit
c8f9a8c0ba

+ 11 - 1
scripts/m4t/predict/predict.py

@@ -44,6 +44,12 @@ def main():
     parser.add_argument(
         "--vocoder_name", type=str, help="Vocoder name", default="vocoder_36langs"
     )
+    parser.add_argument(
+        "--ngram-filtering",
+        type=bool,
+        help="Enable ngram_repeat_block (currently hardcoded to 4, during decoding) and ngram filtering over units (postprocessing)",
+        default=False,
+    )
 
     args = parser.parse_args()
 
@@ -59,7 +65,11 @@ def main():
 
     translator = Translator(args.model_name, args.vocoder_name, device)
     translated_text, wav, sr = translator.predict(
-        args.input, args.task, args.tgt_lang, src_lang=args.src_lang
+        args.input,
+        args.task,
+        args.tgt_lang,
+        src_lang=args.src_lang,
+        ngram_filtering=args.ngram_filtering,
     )
 
     if wav is not None and sr is not None:

+ 0 - 0
src/seamless_communication/models/inference/ngram_repeat_block_logits_processor.py → src/seamless_communication/models/inference/ngram_repeat_block_processor.py


+ 16 - 6
src/seamless_communication/models/inference/translator.py

@@ -18,7 +18,7 @@ from fairseq2.memory import MemoryBlock
 from fairseq2.typing import Device
 from torch import Tensor
 from enum import Enum, auto
-from seamless_communication.models.inference.ngram_repeat_block_logits_processor import (
+from seamless_communication.models.inference.ngram_repeat_block_processor import (
     NGramRepeatBlockProcessor,
 )
 
@@ -99,10 +99,16 @@ class Translator(nn.Module):
         else:
             max_len_a = 1
         text_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
-        unit_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(max_len_a, 50))
+        unit_opts = SequenceGeneratorOptions(
+            beam_size=5, soft_max_seq_len=(max_len_a, 50)
+        )
         if ngram_filtering:
-            text_opts.logits_processor = NGramRepeatBlockProcessor(no_repeat_ngram_size=4)
-            unit_opts.logits_processor = NGramRepeatBlockProcessor(no_repeat_ngram_size=4)
+            text_opts.logits_processor = NGramRepeatBlockProcessor(
+                no_repeat_ngram_size=4
+            )
+            unit_opts.logits_processor = NGramRepeatBlockProcessor(
+                no_repeat_ngram_size=4
+            )
         generator = UnitYGenerator(
             model,
             text_tokenizer,
@@ -112,7 +118,11 @@ class Translator(nn.Module):
             unit_opts=unit_opts,
         )
         return generator(
-            src["seqs"], src["seq_lens"], input_modality.value, output_modality.value, ngram_filtering=ngram_filtering
+            src["seqs"],
+            src["seq_lens"],
+            input_modality.value,
+            output_modality.value,
+            ngram_filtering=ngram_filtering,
         )
 
     def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]:
@@ -204,7 +214,7 @@ class Translator(nn.Module):
             input_modality,
             output_modality,
             tgt_lang=tgt_lang,
-            ngram_filtering=ngram_filtering
+            ngram_filtering=ngram_filtering,
         )
 
         text_out = result[0]