|
@@ -4,7 +4,7 @@ from torch import Tensor
|
|
import torch
|
|
import torch
|
|
|
|
|
|
|
|
|
|
-class NGramRepeatBlockLogitsProcessor(LogitsProcessor):
|
|
|
|
|
|
+class NGramRepeatBlockProcessor(LogitsProcessor):
|
|
def __init__(self, no_repeat_ngram_size: int) -> None:
|
|
def __init__(self, no_repeat_ngram_size: int) -> None:
|
|
self.no_repeat_ngram_size = no_repeat_ngram_size
|
|
self.no_repeat_ngram_size = no_repeat_ngram_size
|
|
|
|
|
|
@@ -12,7 +12,7 @@ class NGramRepeatBlockLogitsProcessor(LogitsProcessor):
|
|
"""Remove repeating n-gram tokens."""
|
|
"""Remove repeating n-gram tokens."""
|
|
batch_size, beam_size, vocab_size = lprobs.size()
|
|
batch_size, beam_size, vocab_size = lprobs.size()
|
|
step_nr = seqs.size(2) - 1
|
|
step_nr = seqs.size(2) - 1
|
|
- # (N, B, step_nr + 1) -> (N * B, step_nr + 1)
|
|
|
|
|
|
+ # (N, B, S) -> (N * B, S)
|
|
seqs = seqs.view(-1, seqs.size(2))
|
|
seqs = seqs.view(-1, seqs.size(2))
|
|
# (N, B, V) -> (N * B, V)
|
|
# (N, B, V) -> (N * B, V)
|
|
lprobs = lprobs.view(-1, vocab_size)
|
|
lprobs = lprobs.view(-1, vocab_size)
|
|
@@ -39,9 +39,8 @@ class NGramRepeatBlockLogitsProcessor(LogitsProcessor):
|
|
:returns:
|
|
:returns:
|
|
modified lprobs tensor with banned tokens set to -inf
|
|
modified lprobs tensor with banned tokens set to -inf
|
|
"""
|
|
"""
|
|
- banned_tokens = [
|
|
|
|
- torch.jit.annotate(List[int], []) for _ in range(batch_size * beam_size)
|
|
|
|
- ]
|
|
|
|
|
|
+ banned_tokens = [[] for _ in range(batch_size * beam_size)]
|
|
|
|
+
|
|
if step_nr + 2 - self.no_repeat_ngram_size >= 0:
|
|
if step_nr + 2 - self.no_repeat_ngram_size >= 0:
|
|
cpu_tokens: List[List[int]] = seqs.cpu().tolist()
|
|
cpu_tokens: List[List[int]] = seqs.cpu().tolist()
|
|
check_start_pos = step_nr + 2 - self.no_repeat_ngram_size
|
|
check_start_pos = step_nr + 2 - self.no_repeat_ngram_size
|