|
@@ -102,16 +102,28 @@ class Translator(nn.Module):
|
|
|
output_modality: Modality,
|
|
|
tgt_lang: str,
|
|
|
ngram_filtering: bool = False,
|
|
|
+ text_max_len_a: int = 1,
|
|
|
+ text_max_len_b: int = 200,
|
|
|
+ unit_max_len_a: Optional[int] = None,
|
|
|
+ unit_max_len_b: Optional[int] = None,
|
|
|
) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
|
|
|
- if input_modality == Modality.TEXT:
|
|
|
- # need to adjust this since src_len is smaller for text.
|
|
|
- max_len_a = 25
|
|
|
- else:
|
|
|
- max_len_a = 1
|
|
|
- text_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
|
|
|
+
|
|
|
+ if unit_max_len_a is None:
|
|
|
+ # need to adjust this for T2ST since src_len is smaller for text.
|
|
|
+ if input_modality == Modality.TEXT:
|
|
|
+ unit_max_len_a = 25
|
|
|
+ else:
|
|
|
+ unit_max_len_a = 1
|
|
|
+
|
|
|
+ text_opts = SequenceGeneratorOptions(
|
|
|
+ beam_size=5,
|
|
|
+ soft_max_seq_len=(text_max_len_a, text_max_len_b)
|
|
|
+ )
|
|
|
unit_opts = SequenceGeneratorOptions(
|
|
|
- beam_size=5, soft_max_seq_len=(max_len_a, 50)
|
|
|
+ beam_size=5,
|
|
|
+ soft_max_seq_len=(unit_max_len_a, unit_max_len_b or 50)
|
|
|
)
|
|
|
+
|
|
|
if ngram_filtering:
|
|
|
text_opts.logits_processor = NGramRepeatBlockProcessor(
|
|
|
no_repeat_ngram_size=4
|
|
@@ -156,6 +168,10 @@ class Translator(nn.Module):
|
|
|
spkr: Optional[int] = -1,
|
|
|
ngram_filtering: bool = False,
|
|
|
sample_rate: int = 16000,
|
|
|
+ text_max_len_a: int = 1,
|
|
|
+ text_max_len_b: int = 200,
|
|
|
+ unit_max_len_a: Optional[int] = None,
|
|
|
+ unit_max_len_b: Optional[int] = None,
|
|
|
) -> Tuple[StringLike, Optional[Tensor], Optional[int]]:
|
|
|
"""
|
|
|
The main method used to perform inference on all tasks.
|
|
@@ -216,6 +232,10 @@ class Translator(nn.Module):
|
|
|
output_modality,
|
|
|
tgt_lang=tgt_lang,
|
|
|
ngram_filtering=ngram_filtering,
|
|
|
+ text_max_len_a=text_max_len_a,
|
|
|
+ text_max_len_b=text_max_len_b,
|
|
|
+ unit_max_len_a=unit_max_len_a,
|
|
|
+ unit_max_len_b=unit_max_len_b,
|
|
|
)
|
|
|
|
|
|
text_out = result[0]
|