@@ -4,46 +4,45 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from fairseq2.generation.logits_processor import LogitsProcessor as LogitsProcessor
+from fairseq2.generation import StepProcessor
 from typing import List
 from torch import Tensor
 import torch
 
 
-class NGramRepeatBlockProcessor(LogitsProcessor):
+class NGramRepeatBlockProcessor(StepProcessor):
     def __init__(self, no_repeat_ngram_size: int) -> None:
         self.no_repeat_ngram_size = no_repeat_ngram_size
 
-    def __call__(self, seqs: Tensor, lprobs: Tensor) -> None:
+    def __call__(self, seqs: Tensor, probs: Tensor, lprob: bool = False) -> None:
         """Remove repeating n-gram tokens."""
-        batch_size, beam_size, vocab_size = lprobs.size()
+        batch_size, beam_size, vocab_size = probs.size()
         step_nr = seqs.size(2) - 1
         # (N, B, S) -> (N * B, S)
         seqs = seqs.view(-1, seqs.size(2))
         # (N, B, V) -> (N * B, V)
-        lprobs = lprobs.view(-1, vocab_size)
-        self._no_repeat_ngram(seqs, lprobs, batch_size, beam_size, step_nr)
+        probs = probs.view(-1, vocab_size)
+        self._no_repeat_ngram(seqs, probs, lprob, batch_size, beam_size, step_nr)
 
     def _no_repeat_ngram(
         self,
         seqs: Tensor,
-        lprobs: Tensor,
+        probs: Tensor,
+        lprob: bool,
         batch_size: int,
         beam_size: int,
         step_nr: int,
-    ) -> Tensor:
+    ) -> None:
-        """For each hypothesis generate a list of previous ngrams
-        and set associated lprobs to -inf
+        """For each hypothesis, generate a list of previous n-grams and set
+        each banned token's probability to -inf (log space) or 0 otherwise.
 
         :param seqs: The generated sequences of tokens for the first
             `step_nr` steps of decoding (N * B, step_nr + 1)
-        :param lprobs: The next-step log probability reshaped to (N * B, V)
+        :param probs: The next-step probabilities reshaped to (N * B, V)
+        :param lprob: If ``True``, ``probs`` contains log probabilities.
         :param batch_size: The batch size.
         :param beam_size: The beam size.
         :param step_nr: Step number for decoding.
-
-        :returns:
-            modified lprobs tensor with banned tokens set to -inf
         """
         banned_tokens: List[List[int]] = [[] for _ in range(batch_size * beam_size)]
 
@@ -63,5 +62,4 @@ class NGramRepeatBlockProcessor(LogitsProcessor):
                     cpu_tokens[bbsz_idx][i + self.no_repeat_ngram_size - 1]
                 )
         for bbsz_idx in range(batch_size * beam_size):
-            lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -torch.inf
-        return lprobs
+            probs[bbsz_idx, banned_tokens[bbsz_idx]] = -torch.inf if lprob else 0
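
For reference, a minimal usage sketch of the processor in both modes. This is a hypothetical example, not part of the patch: the import path is an assumption, and the tensor shapes follow the docstring above ((N, B, S) sequences, (N, B, V) probabilities).

import torch

# Hypothetical import path; adjust to wherever this class lives in the repo.
from ngram_repeat_block_processor import NGramRepeatBlockProcessor

# Ban repetition of any 2-gram already present in a hypothesis.
processor = NGramRepeatBlockProcessor(no_repeat_ngram_size=2)

batch_size, beam_size, vocab_size = 2, 3, 10

# (N, B, S): tokens generated so far for each beam hypothesis.
seqs = torch.randint(0, vocab_size, (batch_size, beam_size, 5))

# Log-probability mode: banned next tokens are masked to -inf in place.
lprobs = torch.log_softmax(torch.randn(batch_size, beam_size, vocab_size), dim=-1)
processor(seqs, lprobs, lprob=True)

# Probability mode (the new default): banned next tokens are zeroed in place.
probs = torch.softmax(torch.randn(batch_size, beam_size, vocab_size), dim=-1)
processor(seqs, probs)

Note that the processor mutates the caller's tensor in place through the flattened `.view(-1, vocab_size)`, which shares storage with the input; this is why the `-> Tensor` annotation and the `return lprobs` statement are dropped in this change.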