|
@@ -18,7 +18,7 @@ from fairseq2.memory import MemoryBlock
|
|
from fairseq2.typing import Device
|
|
from fairseq2.typing import Device
|
|
from torch import Tensor
|
|
from torch import Tensor
|
|
from enum import Enum, auto
|
|
from enum import Enum, auto
|
|
-from seamless_communication.models.inference.ngram_repeat_block_logits_processor import (
|
|
|
|
|
|
+from seamless_communication.models.inference.ngram_repeat_block_processor import (
|
|
NGramRepeatBlockProcessor,
|
|
NGramRepeatBlockProcessor,
|
|
)
|
|
)
|
|
|
|
|
|
@@ -99,10 +99,16 @@ class Translator(nn.Module):
|
|
else:
|
|
else:
|
|
max_len_a = 1
|
|
max_len_a = 1
|
|
text_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
|
|
text_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
|
|
- unit_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(max_len_a, 50))
|
|
|
|
|
|
+ unit_opts = SequenceGeneratorOptions(
|
|
|
|
+ beam_size=5, soft_max_seq_len=(max_len_a, 50)
|
|
|
|
+ )
|
|
if ngram_filtering:
|
|
if ngram_filtering:
|
|
- text_opts.logits_processor = NGramRepeatBlockProcessor(no_repeat_ngram_size=4)
|
|
|
|
- unit_opts.logits_processor = NGramRepeatBlockProcessor(no_repeat_ngram_size=4)
|
|
|
|
|
|
+ text_opts.logits_processor = NGramRepeatBlockProcessor(
|
|
|
|
+ no_repeat_ngram_size=4
|
|
|
|
+ )
|
|
|
|
+ unit_opts.logits_processor = NGramRepeatBlockProcessor(
|
|
|
|
+ no_repeat_ngram_size=4
|
|
|
|
+ )
|
|
generator = UnitYGenerator(
|
|
generator = UnitYGenerator(
|
|
model,
|
|
model,
|
|
text_tokenizer,
|
|
text_tokenizer,
|
|
@@ -112,7 +118,11 @@ class Translator(nn.Module):
|
|
unit_opts=unit_opts,
|
|
unit_opts=unit_opts,
|
|
)
|
|
)
|
|
return generator(
|
|
return generator(
|
|
- src["seqs"], src["seq_lens"], input_modality.value, output_modality.value, ngram_filtering=ngram_filtering
|
|
|
|
|
|
+ src["seqs"],
|
|
|
|
+ src["seq_lens"],
|
|
|
|
+ input_modality.value,
|
|
|
|
+ output_modality.value,
|
|
|
|
+ ngram_filtering=ngram_filtering,
|
|
)
|
|
)
|
|
|
|
|
|
def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]:
|
|
def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]:
|
|
@@ -204,7 +214,7 @@ class Translator(nn.Module):
|
|
input_modality,
|
|
input_modality,
|
|
output_modality,
|
|
output_modality,
|
|
tgt_lang=tgt_lang,
|
|
tgt_lang=tgt_lang,
|
|
- ngram_filtering=ngram_filtering
|
|
|
|
|
|
+ ngram_filtering=ngram_filtering,
|
|
)
|
|
)
|
|
|
|
|
|
text_out = result[0]
|
|
text_out = result[0]
|