Start using fairseq2's NGramRepeatBlockProcessor (#100)

Can Balioglu 1 year ago
parent
commit
29935570cd

+ 4 - 4
src/seamless_communication/cli/m4t/predict/predict.py

@@ -10,9 +10,9 @@ from typing import Tuple
 
 import torch
 import torchaudio
-from fairseq2.generation import SequenceGeneratorOptions
+from fairseq2.generation import NGramRepeatBlockProcessor, SequenceGeneratorOptions
 
-from seamless_communication.inference import NGramRepeatBlockProcessor, Translator
+from seamless_communication.inference import Translator
 
 logging.basicConfig(
     level=logging.INFO,
@@ -150,7 +150,7 @@ def set_generation_opts(
     )
     if args.text_generation_ngram_blocking:
         text_generation_opts.step_processor = NGramRepeatBlockProcessor(
-            no_repeat_ngram_size=args.no_repeat_ngram_size
+            ngram_size=args.no_repeat_ngram_size
         )
 
     unit_generation_opts = SequenceGeneratorOptions(
@@ -162,7 +162,7 @@ def set_generation_opts(
     )
     if args.unit_generation_ngram_blocking:
         unit_generation_opts.step_processor = NGramRepeatBlockProcessor(
-            no_repeat_ngram_size=args.no_repeat_ngram_size
+            ngram_size=args.no_repeat_ngram_size
         )
     return text_generation_opts, unit_generation_opts
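
For reference, a minimal sketch of the call pattern after this commit. Only the `step_processor` assignment and the `ngram_size` keyword are taken from the diff above; constructing `SequenceGeneratorOptions` with defaults (and the value `ngram_size=3`) is illustrative, as `predict.py` actually passes beam-search settings to the constructor as well.

```python
from fairseq2.generation import NGramRepeatBlockProcessor, SequenceGeneratorOptions

# Build generation options, then attach fairseq2's built-in n-gram blocker.
opts = SequenceGeneratorOptions()
opts.step_processor = NGramRepeatBlockProcessor(ngram_size=3)  # illustrative size
```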
 

+ 0 - 3
src/seamless_communication/inference/__init__.py

@@ -5,9 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 from seamless_communication.inference.generator import UnitYGenerator as UnitYGenerator
-from seamless_communication.inference.ngram_repeat_block_processor import (
-    NGramRepeatBlockProcessor as NGramRepeatBlockProcessor,
-)
 from seamless_communication.inference.translator import (
     BatchedSpeechOutput as BatchedSpeechOutput,
 )
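
With this re-export removed, the import path changes as sketched below; the old line now raises an `ImportError`.

```python
# Before (removed by this commit):
#   from seamless_communication.inference import NGramRepeatBlockProcessor
# After, per the predict.py diff above:
from fairseq2.generation import NGramRepeatBlockProcessor
```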

+ 0 - 66
src/seamless_communication/inference/ngram_repeat_block_processor.py

@@ -1,66 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import List
-
-import torch
-from fairseq2.generation import StepProcessor
-from torch import Tensor
-
-
-class NGramRepeatBlockProcessor(StepProcessor):
-    def __init__(self, no_repeat_ngram_size: int) -> None:
-        self.no_repeat_ngram_size = no_repeat_ngram_size
-
-    def __call__(self, seqs: Tensor, probs: Tensor, lprob: bool = False) -> None:
-        """Remove repeating n-gram tokens."""
-        batch_size, beam_size, vocab_size = probs.size()
-        step_nr = seqs.size(2) - 1
-        # (N, B, S) -> (N * B, S)
-        seqs = seqs.view(-1, seqs.size(2))
-        # (N, B, V) -> (N * B, V)
-        probs = probs.view(-1, vocab_size)
-        self._no_repeat_ngram(seqs, probs, lprob, batch_size, beam_size, step_nr)
-
-    def _no_repeat_ngram(
-        self,
-        seqs: Tensor,
-        probs: Tensor,
-        lprob: bool,
-        batch_size: int,
-        beam_size: int,
-        step_nr: int,
-    ) -> None:
-        """For each hypothesis generate a list of previous ngrams
-            and set associated lprobs to -inf
-
-        :param seqs: The generated sequences of tokens for the first
-            `step_nr` steps of decoding (N * B, step_nr + 1)
-        :param probs: The next-step probability reshaped to (N * B, V)
-        :param lprob: If ``True``, ``probs`` is log probabilities.
-        :param batch_size: The batch size.
-        :param beam_size: The beam size.
-        :param step_nr: Step number for decoding.
-        """
-        banned_tokens: List[List[int]] = [[] for _ in range(batch_size * beam_size)]
-
-        if step_nr + 2 - self.no_repeat_ngram_size >= 0:
-            cpu_tokens: List[List[int]] = seqs.cpu().tolist()
-            check_start_pos = step_nr + 2 - self.no_repeat_ngram_size
-            for bbsz_idx in range(batch_size * beam_size):
-                ngram_to_check = cpu_tokens[bbsz_idx][
-                    -(self.no_repeat_ngram_size - 1) :
-                ]
-                for i in range(check_start_pos):
-                    if (
-                        ngram_to_check
-                        == cpu_tokens[bbsz_idx][i : i + self.no_repeat_ngram_size - 1]
-                    ):
-                        banned_tokens[bbsz_idx].append(
-                            cpu_tokens[bbsz_idx][i + self.no_repeat_ngram_size - 1]
-                        )
-        for bbsz_idx in range(batch_size * beam_size):
-            probs[bbsz_idx, banned_tokens[bbsz_idx]] = -torch.inf if lprob else 0
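
The deleted processor implemented standard n-gram repeat blocking: for each of the N × B beam rows it compared the trailing (n − 1) tokens against every earlier (n − 1)-gram in the sequence and banned the token that had followed each match, setting its score to `-inf` for log-probabilities and `0` for plain probabilities. It ran this check in a Python loop over CPU-copied token lists; the commit swaps it for fairseq2's equivalent built-in. Below is a standalone sketch of the banning rule on a plain token list; the helper name `banned_tokens` is hypothetical.

```python
from typing import List


def banned_tokens(tokens: List[int], ngram_size: int) -> List[int]:
    """Tokens that would complete an n-gram already present in ``tokens``."""
    prefix = tokens[-(ngram_size - 1) :]  # trailing (n - 1)-gram to check
    banned = []
    # Mirrors check_start_pos = step_nr + 2 - ngram_size in the deleted code,
    # where len(tokens) == step_nr + 1.
    for i in range(len(tokens) + 1 - ngram_size):
        if tokens[i : i + ngram_size - 1] == prefix:
            banned.append(tokens[i + ngram_size - 1])
    return banned


# The sequence already contains the bigram (7, 8) and ends in 7, so emitting
# 8 next would repeat that bigram and is banned.
assert banned_tokens([7, 8, 7], ngram_size=2) == [8]
```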