
Merge pull request #15 from facebookresearch/ngram

Port ngram_repeat_block and ngram cleaning over units to SC
Ning 2 years ago
commit 1604c9ebf6

+ 11 - 1
scripts/m4t/predict/predict.py

@@ -44,6 +44,12 @@ def main():
     parser.add_argument(
         "--vocoder_name", type=str, help="Vocoder name", default="vocoder_36langs"
     )
+    parser.add_argument(
+        "--ngram-filtering",
+        action="store_true",
+        help="Enable ngram_repeat_block during decoding (ngram size currently hardcoded to 4) and ngram filtering over units (postprocessing)",
+        default=False,
+    )
 
     args = parser.parse_args()
 
@@ -59,7 +65,11 @@ def main():
 
     translator = Translator(args.model_name, args.vocoder_name, device)
     translated_text, wav, sr = translator.predict(
-        args.input, args.task, args.tgt_lang, src_lang=args.src_lang
+        args.input,
+        args.task,
+        args.tgt_lang,
+        src_lang=args.src_lang,
+        ngram_filtering=args.ngram_filtering,
     )
 
     if wav is not None and sr is not None:
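
With action="store_true", the new flag is off by default and enabled by its bare presence on the command line. A stand-alone sketch of the parsing behavior (the positional arguments are assumptions inferred from the args.input / args.task / args.tgt_lang usages above, not taken from this diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("input")     # assumed positional, per args.input
parser.add_argument("task")      # assumed positional, per args.task
parser.add_argument("tgt_lang")  # assumed positional, per args.tgt_lang
parser.add_argument("--ngram-filtering", action="store_true", default=False)

args = parser.parse_args(["in.wav", "s2st", "fra", "--ngram-filtering"])
assert args.ngram_filtering is True   # argparse maps --ngram-filtering to ngram_filtering

args = parser.parse_args(["in.wav", "s2st", "fra"])
assert args.ngram_filtering is False  # stays off unless the flag is passed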

+ 61 - 0
src/seamless_communication/models/inference/ngram_repeat_block_processor.py

@@ -0,0 +1,61 @@
+from fairseq2.generation.logits_processor import LogitsProcessor
+from typing import List
+from torch import Tensor
+import torch
+
+
+class NGramRepeatBlockProcessor(LogitsProcessor):
+    def __init__(self, no_repeat_ngram_size: int) -> None:
+        self.no_repeat_ngram_size = no_repeat_ngram_size
+
+    def __call__(self, seqs: Tensor, lprobs: Tensor) -> None:
+        """Remove repeating n-gram tokens."""
+        batch_size, beam_size, vocab_size = lprobs.size()
+        step_nr = seqs.size(2) - 1
+        # (N, B, S) -> (N * B, S)
+        seqs = seqs.view(-1, seqs.size(2))
+        # (N, B, V) -> (N * B, V)
+        lprobs = lprobs.view(-1, vocab_size)
+        self._no_repeat_ngram(seqs, lprobs, batch_size, beam_size, step_nr)
+
+    def _no_repeat_ngram(
+        self,
+        seqs: Tensor,
+        lprobs: Tensor,
+        batch_size: int,
+        beam_size: int,
+        step_nr: int,
+    ) -> Tensor:
+        """For each hypothesis generate a list of previous ngrams
+            and set associated lprobs to -inf
+
+        :param seqs: The generated sequences of tokens for the first
+            `step_nr` steps of decoding (N * B, step_nr + 1)
+        :param lprobs: The next-step log probability reshaped to (N * B, V)
+        :param batch_size: The batch size.
+        :param beam_size: The beam size.
+        :param step_nr: Step number for decoding.
+
+        :returns:
+            modified lprobs tensor with banned tokens set to -inf
+        """
+        banned_tokens = [[] for _ in range(batch_size * beam_size)]
+
+        if step_nr + 2 - self.no_repeat_ngram_size >= 0:
+            cpu_tokens: List[List[int]] = seqs.cpu().tolist()
+            check_start_pos = step_nr + 2 - self.no_repeat_ngram_size
+            for bbsz_idx in range(batch_size * beam_size):
+                ngram_to_check = cpu_tokens[bbsz_idx][
+                    -(self.no_repeat_ngram_size - 1) :
+                ]
+                for i in range(check_start_pos):
+                    if (
+                        ngram_to_check
+                        == cpu_tokens[bbsz_idx][i : i + self.no_repeat_ngram_size - 1]
+                    ):
+                        banned_tokens[bbsz_idx].append(
+                            cpu_tokens[bbsz_idx][i + self.no_repeat_ngram_size - 1]
+                        )
+        for bbsz_idx in range(batch_size * beam_size):
+            lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -torch.inf
+        return lprobs
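
As a quick sanity check of the blocking logic, a minimal sketch (it assumes the module above is importable; batch size 1, beam size 1, a 6-token vocabulary, and no_repeat_ngram_size=3 are chosen for brevity):

import torch

from seamless_communication.models.inference.ngram_repeat_block_processor import (
    NGramRepeatBlockProcessor,
)

# Hypothesis so far: [1, 2, 3, 1, 2]. The trigram (1, 2, 3) has already been
# generated, so emitting token 3 next would repeat it and must be banned.
seqs = torch.tensor([[[1, 2, 3, 1, 2]]])  # (N=1, B=1, S=5)
lprobs = torch.zeros(1, 1, 6)             # (N=1, B=1, V=6)

NGramRepeatBlockProcessor(no_repeat_ngram_size=3)(seqs, lprobs)

assert lprobs[0, 0, 3] == -torch.inf           # token 3 is now banned
assert torch.isfinite(lprobs[0, 0, :3]).all()  # other tokens are untouched

Note that the processor mutates lprobs in place: the (N, B, V) tensor is viewed as (N * B, V), and that view shares storage with the original tensor.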

+ 24 - 6
src/seamless_communication/models/inference/translator.py

@@ -18,6 +18,9 @@ from fairseq2.memory import MemoryBlock
 from fairseq2.typing import Device
 from torch import Tensor
 from enum import Enum, auto
+from seamless_communication.models.inference.ngram_repeat_block_processor import (
+    NGramRepeatBlockProcessor,
+)
 
 from seamless_communication.models.unity import (
     UnitTokenizer,
@@ -88,25 +91,38 @@ class Translator(nn.Module):
         input_modality: Modality,
         output_modality: Modality,
         tgt_lang: str,
+        ngram_filtering: bool = False,
     ) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
         if input_modality == Modality.TEXT:
             # need to adjust this since src_len is smaller for text.
             max_len_a = 25
         else:
             max_len_a = 1
-
+        text_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
+        unit_opts = SequenceGeneratorOptions(
+            beam_size=5, soft_max_seq_len=(max_len_a, 50)
+        )
+        if ngram_filtering:
+            text_opts.logits_processor = NGramRepeatBlockProcessor(
+                no_repeat_ngram_size=4
+            )
+            unit_opts.logits_processor = NGramRepeatBlockProcessor(
+                no_repeat_ngram_size=4
+            )
         generator = UnitYGenerator(
             model,
             text_tokenizer,
             tgt_lang,
             unit_tokenizer if output_modality == Modality.SPEECH else None,
-            text_opts=SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200)),
-            unit_opts=SequenceGeneratorOptions(
-                beam_size=5, soft_max_seq_len=(max_len_a, 50)
-            ),
+            text_opts=text_opts,
+            unit_opts=unit_opts,
         )
         return generator(
-            src["seqs"], src["seq_lens"], input_modality.value, output_modality.value
+            src["seqs"],
+            src["seq_lens"],
+            input_modality.value,
+            output_modality.value,
+            ngram_filtering=ngram_filtering,
         )
 
     def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]:
@@ -138,6 +154,7 @@ class Translator(nn.Module):
         tgt_lang: str,
         src_lang: Optional[str] = None,
         spkr: Optional[int] = -1,
+        ngram_filtering: bool = False,
     ) -> Tuple[StringLike, Optional[List[Tensor]], Optional[int]]:
         """
         The main method used to perform inference on all tasks.
@@ -197,6 +214,7 @@ class Translator(nn.Module):
             input_modality,
             output_modality,
             tgt_lang=tgt_lang,
+            ngram_filtering=ngram_filtering,
         )
 
         text_out = result[0]
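
End to end, the flag threads through Translator.predict as below. A hedged sketch: the model name, task string, input path, and target language are illustrative placeholders, not values taken from this diff (only "vocoder_36langs" appears above as a default):

import torch

from seamless_communication.models.inference.translator import Translator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
translator = Translator("seamlessM4T_large", "vocoder_36langs", device)

# With ngram_filtering=True, decoding blocks repeated 4-grams (the hardcoded
# no_repeat_ngram_size=4 above) and the generated units are post-filtered.
translated_text, wav, sr = translator.predict(
    "input.wav", "s2st", "fra", ngram_filtering=True
)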

+ 25 - 1
src/seamless_communication/models/unity/generator.py

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 
 import torch
 from fairseq2.data.text import TextTokenizer
@@ -25,6 +25,26 @@ from fairseq2.nn.utils.module import infer_device
 from torch import Tensor
 
 
+def remove_consecutive_repeated_ngrams(
+    sequence: List[int], min_size: int = 1, max_size: int = 40
+) -> List[int]:
+    assert 1 <= min_size <= max_size
+    drop_idx = set()  # indices that will be dropped from the sequence
+
+    # start from the beginning, check if an ngram of size k (for k=max..min) is
+    # followed by its copy, if so delete the first one, and start over after
+    # the deleted ngram.
+    start = 0
+    while start < len(sequence):
+        for k in range(max_size, min_size - 1, -1):
+            if sequence[start : start + k] == sequence[start + k : start + k + k]:
+                drop_idx |= set(range(start, start + k))
+                start += k - 1  # assumes repeating subsequences don't overlap
+                break
+        start += 1
+    return [token for idx, token in enumerate(sequence) if idx not in drop_idx]
+
+
 class UnitYGenerator:
     """Generates text translations and speech units from a UnitY model."""
 
@@ -127,6 +147,7 @@ class UnitYGenerator:
         source_seq_lens: Optional[Tensor],
         input_modality: str = "speech",
         output_modality: str = "speech",
+        ngram_filtering: bool = False,
     ) -> Tuple[SequenceToTextOutput, Optional["SequenceToUnitOutput"]]:
         """
         :param source_seqs:
@@ -192,6 +213,9 @@ class UnitYGenerator:
 
         # Convert to speech units.
         units = self.unit_decoder(unit_seqs)
+        if ngram_filtering:
+            units = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist())
+            units = torch.tensor(units)
 
         unit_output = SequenceToUnitOutput(
             units, unit_gen_output, t2u_encoder_output, t2u_encoder_padding_mask
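
The same filtering as applied in the hunk above, sketched on a toy unit tensor (values are arbitrary; the real units come from self.unit_decoder):

import torch

units = torch.tensor([11, 12, 11, 12, 99])
filtered = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist())
units = torch.tensor(filtered)  # tensor([11, 12, 99]), back on CPU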