@@ -75,6 +75,14 @@ class SequenceGeneratorOptions:
     step_processor: Optional[StepProcessor] = None
     """The processor called at each generation step."""
 
+    unk_penalty: float = 0.0
+    """The UNK symbol penalty, where values less than 0 produce more UNKs;
+    values greater than 0 produce fewer UNKs."""
+
+    len_penalty: float = 1.0
+    """The length penalty, where values less than 1.0 favor shorter
+    sequences; values greater than 1.0 favor longer sequences."""
+
 
 class UnitYGenerator:
     """Generates text translations and speech units from a UnitY model."""
@@ -143,6 +151,8 @@ class UnitYGenerator:
             max_seq_len=text_opts.hard_max_seq_len,
             echo_prompt=True,
             step_processors=step_processors,
+            unk_penalty=text_opts.unk_penalty,
+            len_penalty=text_opts.len_penalty,
         )
         self.s2t_converter = SequenceToTextConverter(
             generator, text_tokenizer, "translation", target_lang
@@ -168,6 +178,8 @@ class UnitYGenerator:
             max_seq_len=text_opts.hard_max_seq_len,
             echo_prompt=True,
             step_processors=step_processors,
+            unk_penalty=text_opts.unk_penalty,
+            len_penalty=text_opts.len_penalty,
         )
         self.t2t_converter = SequenceToTextConverter(
             generator, text_tokenizer, "translation", target_lang
@@ -208,6 +220,8 @@ class UnitYGenerator:
             max_seq_len=unit_opts.hard_max_seq_len,
             echo_prompt=True,
             step_processors=step_processors,
+            unk_penalty=unit_opts.unk_penalty,
+            len_penalty=unit_opts.len_penalty,
         )
 
     @torch.inference_mode()
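
For context, here is a minimal usage sketch (not part of the diff) showing how a caller might set the two new fields. The import path and the numeric values are assumptions; the field names and defaults come from the `SequenceGeneratorOptions` hunk above, and the options are consumed as `text_opts` / `unit_opts` by `UnitYGenerator.__init__` as shown in the later hunks.

```python
# Hypothetical sketch: the import path is an assumption; only the field names
# and defaults are taken from the SequenceGeneratorOptions diff above.
from seamless_communication.inference.generator import SequenceGeneratorOptions

text_opts = SequenceGeneratorOptions(
    unk_penalty=0.5,  # > 0 penalizes UNK, producing fewer UNK symbols
    len_penalty=1.2,  # > 1.0 favors longer generated sequences
)

# Passed as UnitYGenerator(..., text_opts=text_opts), these values are now
# forwarded to the underlying sequence generator instead of being ignored.
```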