Browse Source

ngram fixes

cndn 2 years ago
parent
commit
55a251dfa6

+ 62 - 0
src/seamless_communication/models/inference/ngram_repeat_block_logits_processor.py

@@ -0,0 +1,62 @@
+from fairseq2.generation.logits_processor import LogitsProcessor
+from typing import List
+from torch import Tensor
+import torch
+
+
+class NGramRepeatBlockLogitsProcessor(LogitsProcessor):
+    def __init__(self, no_repeat_ngram_size: int) -> None:
+        self.no_repeat_ngram_size = no_repeat_ngram_size
+
+    def __call__(self, seqs: Tensor, lprobs: Tensor) -> None:
+        """Remove repeating n-gram tokens."""
+        batch_size, beam_size, vocab_size = lprobs.size()
+        step_nr = seqs.size(2) - 1
+        # (N, B, step_nr + 1) -> (N * B, step_nr + 1)
+        seqs = seqs.view(-1, seqs.size(2))
+        # (N, B, V) -> (N * B, V)
+        lprobs = lprobs.view(-1, vocab_size)
+        self._no_repeat_ngram(seqs, lprobs, batch_size, beam_size, step_nr)
+
+    def _no_repeat_ngram(
+        self,
+        seqs: Tensor,
+        lprobs: Tensor,
+        batch_size: int,
+        beam_size: int,
+        step_nr: int,
+    ) -> Tensor:
+        """For each hypothesis generate a list of previous ngrams
+            and set associated lprobs to -inf
+
+        :param seqs: The generated sequences of tokens for the first
+            `step_nr` steps of decoding (N * B, step_nr + 1)
+        :param lprobs: The next-step log probability reshaped to (N * B, V)
+        :param batch_size: The batch size.
+        :param beam_size: The beam size.
+        :param step_nr: Step number for decoding.
+
+        :returns:
+            modified lprobs tensor with banned tokens set to -inf
+        """
+        banned_tokens = [
+            torch.jit.annotate(List[int], []) for _ in range(batch_size * beam_size)
+        ]
+        # A full n-gram needs (n - 1) already-generated tokens plus the next
+        # prediction, so there is nothing to check before that point.
+        if step_nr + 2 - self.no_repeat_ngram_size >= 0:
+            cpu_tokens: List[List[int]] = seqs.cpu().tolist()
+            check_start_pos = step_nr + 2 - self.no_repeat_ngram_size
+            for bbsz_idx in range(batch_size * beam_size):
+                # The last (n - 1) generated tokens form the prefix that the
+                # next prediction would extend into a full n-gram.
+                ngram_to_check = cpu_tokens[bbsz_idx][
+                    -(self.no_repeat_ngram_size - 1) :
+                ]
+                for i in range(check_start_pos):
+                    # If the prefix already occurred earlier in the sequence,
+                    # ban the token that followed it there.
+                    if (
+                        ngram_to_check
+                        == cpu_tokens[bbsz_idx][i : i + self.no_repeat_ngram_size - 1]
+                    ):
+                        banned_tokens[bbsz_idx].append(
+                            cpu_tokens[bbsz_idx][i + self.no_repeat_ngram_size - 1]
+                        )
+        for bbsz_idx in range(batch_size * beam_size):
+            lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -torch.inf
+        return lprobs
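
A quick way to sanity-check the processor in isolation (a minimal sketch with made-up toy tensors; only the import path is taken from the diff above). Note that the check runs on CPU token lists (`seqs.cpu().tolist()`), so it synchronizes with the device once per decoding step:

    import torch
    from seamless_communication.models.inference.ngram_repeat_block_logits_processor import (
        NGramRepeatBlockLogitsProcessor,
    )

    # One sequence (N=1), one beam (B=1), a vocabulary of 8 tokens;
    # the tokens generated so far are [3, 5, 3].
    seqs = torch.tensor([[[3, 5, 3]]])  # (N, B, step_nr + 1)
    lprobs = torch.zeros(1, 1, 8)       # (N, B, V)

    NGramRepeatBlockLogitsProcessor(no_repeat_ngram_size=2)(seqs, lprobs)

    # The bigram (3, 5) already occurred, so emitting 5 after the trailing 3
    # is banned. The views taken in __call__ share storage with lprobs, so
    # the masking is visible on the original tensor.
    assert lprobs[0, 0, 5] == -torch.inf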

+ 47 - 6
src/seamless_communication/models/inference/translator.py

@@ -18,6 +18,9 @@ from fairseq2.memory import MemoryBlock
 from fairseq2.typing import Device
 from torch import Tensor
 from enum import Enum, auto
+from seamless_communication.models.inference.ngram_repeat_block_logits_processor import (
+    NGramRepeatBlockLogitsProcessor,
+)
 
 from seamless_communication.models.unity import (
     UnitTokenizer,
@@ -30,6 +33,13 @@ from seamless_communication.models.unity import (
 from seamless_communication.models.unity.generator import SequenceToUnitOutput
 from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
 
+import urllib.request
+from urllib.request import urlopen
+import ssl
+import json
+
+ssl._create_default_https_context = ssl._create_unverified_context
+
 
 class Task(Enum):
     S2ST = auto()
@@ -88,25 +98,38 @@ class Translator(nn.Module):
         input_modality: Modality,
         output_modality: Modality,
         tgt_lang: str,
+        ngram_filtering: bool = False,
     ) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
         if input_modality == Modality.TEXT:
             # need to adjust this since src_len is smaller for text.
             max_len_a = 25
         else:
             max_len_a = 1
-
+        text_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
+        unit_opts = SequenceGeneratorOptions(
+            beam_size=5, soft_max_seq_len=(max_len_a, 50)
+        )
+        if ngram_filtering:
+            text_opts.logits_processor = NGramRepeatBlockLogitsProcessor(
+                no_repeat_ngram_size=10
+            )
+            unit_opts.logits_processor = NGramRepeatBlockLogitsProcessor(
+                no_repeat_ngram_size=10
+            )
         generator = UnitYGenerator(
             model,
             text_tokenizer,
             tgt_lang,
             unit_tokenizer if output_modality == Modality.SPEECH else None,
-            text_opts=SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200)),
-            unit_opts=SequenceGeneratorOptions(
-                beam_size=5, soft_max_seq_len=(max_len_a, 50)
-            ),
+            text_opts=text_opts,
+            unit_opts=unit_opts,
         )
         return generator(
-            src["seqs"], src["seq_lens"], input_modality.value, output_modality.value
+            src["seqs"],
+            src["seq_lens"],
+            input_modality.value,
+            output_modality.value,
+            ngram_filtering=ngram_filtering,
         )
 
     def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]:
@@ -138,6 +161,7 @@ class Translator(nn.Module):
         tgt_lang: str,
         src_lang: Optional[str] = None,
         spkr: Optional[int] = -1,
+        ngram_filtering: bool = False,
     ) -> Tuple[StringLike, Optional[List[Tensor]], Optional[int]]:
         """
         The main method used to perform inference on all tasks.
@@ -197,6 +221,7 @@ class Translator(nn.Module):
             input_modality,
             output_modality,
             tgt_lang=tgt_lang,
+            ngram_filtering=ngram_filtering,
         )
 
         text_out = result[0]
@@ -207,3 +232,19 @@ class Translator(nn.Module):
             units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
             wav_out, sr_out = self.synthesize_speech(units, tgt_lang, spkr)
             return text_out.sentences[0], wav_out, sr_out
+
+
+if __name__ == "__main__":
+    import torchaudio
+
+    # audio = "/data/home/dnn/LJ003-0001.wav"
+    audio = "/data/home/dnn/oss_sc/seamless_communication/spanish_repeat.wav"
+    translator = Translator(
+        "seamlessM4T_large", "vocoder_36langs", torch.device("cuda:0")
+    )
+    text_out, wav, sr = translator.predict(audio, "s2st", "deu", ngram_filtering=True)  # type: ignore
+    torchaudio.save(
+        "/data/home/dnn/deu_testing.wav",
+        wav[0].cpu(),
+        sample_rate=sr,
+    )
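
For reference, the new flag threads all the way through `predict`; a minimal sketch (checkpoint names, device, and the input path are placeholders, following the commit's own demo above):

    import torch
    from seamless_communication.models.inference.translator import Translator

    translator = Translator(
        "seamlessM4T_large", "vocoder_36langs", torch.device("cuda:0")
    )
    # ngram_filtering=True enables both the decode-time n-gram blocking
    # (text and unit decoding) and the post-hoc removal of consecutively
    # repeated unit n-grams in UnitYGenerator.
    text_out, wav, sr = translator.predict(
        "input.wav", "s2st", tgt_lang="deu", ngram_filtering=True
    )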

+ 25 - 1
src/seamless_communication/models/unity/generator.py

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 
 import torch
 from fairseq2.data.text import TextTokenizer
@@ -25,6 +25,26 @@ from fairseq2.nn.utils.module import infer_device
 from torch import Tensor
 
 
+def remove_consecutive_repeated_ngrams(
+    sequence: List[int], min_size: int = 1, max_size: int = 40
+) -> List[int]:
+    assert 1 <= min_size <= max_size
+    drop_idx = set()  # indices that will be dropped from the sequence
+
+    # start from the beginning, check if an ngram of size k (for k=max..min) is
+    # followed by its copy, if so delete the first one, and start over after
+    # the deleted ngram.
+    start = 0
+    while start < len(sequence):
+        for k in range(max_size, min_size - 1, -1):
+            if sequence[start : start + k] == sequence[start + k : start + k + k]:
+                drop_idx |= set(range(start, start + k))
+                start += k - 1  # assumes repeating subsequences don't overlap
+                break
+        start += 1
+    return [token for idx, token in enumerate(sequence) if idx not in drop_idx]
+
+
 class UnitYGenerator:
     """Generates text translations and speech units from a UnitY model."""
 
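The effect of `remove_consecutive_repeated_ngrams` is easiest to see on small inputs (toy unit IDs, chosen only for illustration). Because `k` counts down from `max_size`, the longest repeated block at each position is removed first:

    >>> remove_consecutive_repeated_ngrams([1, 2, 3, 1, 2, 3, 4])
    [1, 2, 3, 4]
    >>> remove_consecutive_repeated_ngrams([7, 7, 7, 8])
    [7, 8]
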
@@ -127,6 +147,7 @@ class UnitYGenerator:
         source_seq_lens: Optional[Tensor],
         input_modality: str = "speech",
         output_modality: str = "speech",
+        ngram_filtering: bool = False,
     ) -> Tuple[SequenceToTextOutput, Optional["SequenceToUnitOutput"]]:
         """
         :param source_seqs:
@@ -192,6 +213,9 @@ class UnitYGenerator:
 
         # Convert to speech units.
         units = self.unit_decoder(unit_seqs)
+        if ngram_filtering:
+            units = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist())
+            units = torch.tensor(units)
 
         unit_output = SequenceToUnitOutput(
             units, unit_gen_output, t2u_encoder_output, t2u_encoder_padding_mask