
Update LogitsProcessor to StepProcessor (#74)

Can Balioglu 1 year ago
parent commit 04d9a49462
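
This commit migrates `NGramRepeatBlockProcessor` from the old `LogitsProcessor` base class to fairseq2's `StepProcessor`. Judging from the subclass signature in the diff below (this is an inferred sketch, not quoted from fairseq2, and the real base class may differ in details), the new interface receives the probability tensor plus an `lprob` flag and mutates it in place instead of returning a modified tensor:

```python
# Hypothetical stand-in for fairseq2.generation.StepProcessor, inferred from
# the subclass updated in this commit.
from abc import ABC, abstractmethod

from torch import Tensor


class StepProcessorSketch(ABC):
    @abstractmethod
    def __call__(self, seqs: Tensor, probs: Tensor, lprob: bool = False) -> None:
        """Adjust ``probs`` in place for the current decoding step.

        ``probs`` holds plain probabilities unless ``lprob`` is ``True``,
        in which case it holds log probabilities.
        """
```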

+ 12 - 14
src/seamless_communication/models/inference/ngram_repeat_block_processor.py

@@ -4,46 +4,45 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from fairseq2.generation.logits_processor import LogitsProcessor as LogitsProcessor
+from fairseq2.generation import StepProcessor
 from typing import List
 from torch import Tensor
 import torch
 
 
-class NGramRepeatBlockProcessor(LogitsProcessor):
+class NGramRepeatBlockProcessor(StepProcessor):
     def __init__(self, no_repeat_ngram_size: int) -> None:
         self.no_repeat_ngram_size = no_repeat_ngram_size
 
-    def __call__(self, seqs: Tensor, lprobs: Tensor) -> None:
+    def __call__(self, seqs: Tensor, probs: Tensor, lprob: bool = False) -> None:
         """Remove repeating n-gram tokens."""
-        batch_size, beam_size, vocab_size = lprobs.size()
+        batch_size, beam_size, vocab_size = probs.size()
         step_nr = seqs.size(2) - 1
         # (N, B, S) -> (N * B, S)
         seqs = seqs.view(-1, seqs.size(2))
         # (N, B, V) -> (N * B, V)
-        lprobs = lprobs.view(-1, vocab_size)
-        self._no_repeat_ngram(seqs, lprobs, batch_size, beam_size, step_nr)
+        probs = probs.view(-1, vocab_size)
+        self._no_repeat_ngram(seqs, probs, lprob, batch_size, beam_size, step_nr)
 
     def _no_repeat_ngram(
         self,
         seqs: Tensor,
-        lprobs: Tensor,
+        probs: Tensor,
+        lprob: bool,
         batch_size: int,
         beam_size: int,
         step_nr: int,
-    ) -> Tensor:
+    ) -> None:
         """For each hypothesis generate a list of previous ngrams
             and set associated lprobs to -inf
 
         :param seqs: The generated sequences of tokens for the first
             `step_nr` steps of decoding (N * B, step_nr + 1)
-        :param lprobs: The next-step log probability reshaped to (N * B, V)
+        :param probs: The next-step probability reshaped to (N * B, V)
+        :param lprob: If ``True``, ``probs`` is log probabilities.
         :param batch_size: The batch size.
         :param beam_size: The beam size.
         :param step_nr: Step number for decoding.
-
-        :returns:
-            modified lprobs tensor with banned tokens set to -inf
         """
         banned_tokens: List[List[int]] = [[] for _ in range(batch_size * beam_size)]
 
@@ -63,5 +62,4 @@ class NGramRepeatBlockProcessor(LogitsProcessor):
                             cpu_tokens[bbsz_idx][i + self.no_repeat_ngram_size - 1]
                         )
         for bbsz_idx in range(batch_size * beam_size):
-            lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -torch.inf
-        return lprobs
+            probs[bbsz_idx, banned_tokens[bbsz_idx]] = -torch.inf if lprob else 0
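
For reference, a minimal usage sketch (not part of the commit) of the updated call convention. The shapes follow the docstring: `seqs` is the `(N, B, S)` tensor of generated tokens and the probability tensor is `(N, B, V)`; the dummy sizes and random values below are illustrative assumptions. Passing `lprob=True` keeps the old behaviour of masking banned tokens with `-inf`, while the default masks them with `0` for plain probabilities.

```python
import torch

from seamless_communication.models.inference.ngram_repeat_block_processor import (
    NGramRepeatBlockProcessor,
)

processor = NGramRepeatBlockProcessor(no_repeat_ngram_size=2)

batch_size, beam_size, vocab_size, num_steps = 2, 3, 16, 5

# Generated tokens so far: (N, B, S).
seqs = torch.randint(0, vocab_size, (batch_size, beam_size, num_steps))

# Next-step log probabilities: (N, B, V).
lprobs = torch.log_softmax(torch.randn(batch_size, beam_size, vocab_size), dim=-1)

# The processor now returns None and masks banned tokens through a view that
# shares storage with `lprobs`; lprob=True selects the -inf masking branch.
processor(seqs, lprobs, lprob=True)
```

Because the views created inside `__call__` share storage with the caller's tensor, the masking is visible to the decoding loop without any return value, which is what lets `_no_repeat_ngram` drop its `return lprobs` statement in this commit.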