Преглед на файлове

Apply suggestions from code review

Co-authored-by: Kaushik Ram Sadagopan <krs@fb.com>
Ning преди 2 години
родител
ревизия
3a2df0e94f

+ 4 - 5
src/seamless_communication/models/inference/ngram_repeat_block_logits_processor.py

@@ -4,7 +4,7 @@ from torch import Tensor
 import torch
 
 
-class NGramRepeatBlockLogitsProcessor(LogitsProcessor):
+class NGramRepeatBlockProcessor(LogitsProcessor):
     def __init__(self, no_repeat_ngram_size: int) -> None:
         self.no_repeat_ngram_size = no_repeat_ngram_size
 
@@ -12,7 +12,7 @@ class NGramRepeatBlockLogitsProcessor(LogitsProcessor):
         """Remove repeating n-gram tokens."""
         batch_size, beam_size, vocab_size = lprobs.size()
         step_nr = seqs.size(2) - 1
-        # (N, B, step_nr + 1) -> (N * B, step_nr + 1)
+        # (N, B, S) -> (N * B, S)
         seqs = seqs.view(-1, seqs.size(2))
         # (N, B, V) -> (N * B, V)
         lprobs = lprobs.view(-1, vocab_size)
@@ -39,9 +39,8 @@ class NGramRepeatBlockLogitsProcessor(LogitsProcessor):
         :returns:
             modified lprobs tensor with banned tokens set to -inf
         """
-        banned_tokens = [
-            torch.jit.annotate(List[int], []) for _ in range(batch_size * beam_size)
-        ]
+        banned_tokens = [[] for _ in range(batch_size * beam_size)]
+
         if step_nr + 2 - self.no_repeat_ngram_size >= 0:
             cpu_tokens: List[List[int]] = seqs.cpu().tolist()
             check_start_pos = step_nr + 2 - self.no_repeat_ngram_size

+ 3 - 3
src/seamless_communication/models/inference/translator.py

@@ -19,7 +19,7 @@ from fairseq2.typing import Device
 from torch import Tensor
 from enum import Enum, auto
 from seamless_communication.models.inference.ngram_repeat_block_logits_processor import (
-    NGramRepeatBlockLogitsProcessor,
+    NGramRepeatBlockProcessor,
 )
 
 from seamless_communication.models.unity import (
@@ -101,8 +101,8 @@ class Translator(nn.Module):
         text_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
         unit_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(max_len_a, 50))
         if ngram_filtering:
-            text_opts.logits_processor = NGramRepeatBlockLogitsProcessor(no_repeat_ngram_size=4)
-            unit_opts.logits_processor = NGramRepeatBlockLogitsProcessor(no_repeat_ngram_size=4)
+            text_opts.logits_processor = NGramRepeatBlockProcessor(no_repeat_ngram_size=4)
+            unit_opts.logits_processor = NGramRepeatBlockProcessor(no_repeat_ngram_size=4)
         generator = UnitYGenerator(
             model,
             text_tokenizer,