@@ -31,15 +31,15 @@ class UnitYGenerator:
     model: UnitYModel
     s2t_generator: SequenceToTextGenerator
     t2t_generator: Optional[SequenceToTextGenerator]
-    unit_decoder: UnitTokenDecoder
-    unit_generator: Seq2SeqGenerator
+    unit_decoder: Optional[UnitTokenDecoder]
+    unit_generator: Optional[Seq2SeqGenerator]
 
     def __init__(
         self,
         model: UnitYModel,
         text_tokenizer: TextTokenizer,
-        unit_tokenizer: UnitTokenizer,
         target_lang: str,
+        unit_tokenizer: Optional[UnitTokenizer] = None,
         text_opts: Optional[SequenceGeneratorOptions] = None,
         unit_opts: Optional[SequenceGeneratorOptions] = None,
     ) -> None:
@@ -97,25 +97,28 @@ class UnitYGenerator:
             t2t_model, text_tokenizer, target_lang, text_opts
         )
 
+        self.unit_generator = None
+        self.unit_decoder = None
         # Set up unit generator.
-        self.unit_decoder = unit_tokenizer.create_decoder()
+        if unit_tokenizer is not None:
+            self.unit_decoder = unit_tokenizer.create_decoder()
 
-        unit_encoder = unit_tokenizer.create_encoder(
-            lang=target_lang, device=infer_device(model.t2u_model)
-        )
-
-        if unit_opts is None:
-            # Speech sequences are typically much longer than text sequences.
-            unit_opts = SequenceGeneratorOptions(
-                soft_max_seq_len=(1, 50), hard_max_seq_len=5000
+            unit_encoder = unit_tokenizer.create_encoder(
+                lang=target_lang, device=infer_device(model.t2u_model)
             )
 
-        self.unit_generator = Seq2SeqGenerator(
-            model.t2u_model,
-            unit_tokenizer.vocab_info,
-            unit_encoder.prefix_indices,
-            unit_opts,
-        )
+            if unit_opts is None:
+                # Speech sequences are typically much longer than text sequences.
+                unit_opts = SequenceGeneratorOptions(
+                    soft_max_seq_len=(1, 50), hard_max_seq_len=5000
+                )
+
+            self.unit_generator = Seq2SeqGenerator(
+                model.t2u_model,
+                unit_tokenizer.vocab_info,
+                unit_encoder.prefix_indices,
+                unit_opts,
+            )
 
     @torch.inference_mode()
     def __call__(
@@ -176,6 +179,9 @@ class UnitYGenerator:
             decoder_output, decoder_padding_mask
         )
 
+        assert self.unit_generator is not None
+        assert self.unit_decoder is not None
+
         unit_gen_output = self.unit_generator(
             t2u_encoder_output,
             t2u_encoder_padding_mask,
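
The following is a hypothetical usage sketch, not part of the diff, illustrating the effect of the change: unit_tokenizer is now an optional argument with a default of None, so the generator can be constructed for text-only output, while passing a tokenizer restores the previous unit-generation behavior. The variables model, text_tokenizer, and unit_tokenizer stand for objects loaded elsewhere, and "eng" is a placeholder target-language code.

    # Text-only setup: unit_decoder and unit_generator remain None, so the
    # unit-generation path guarded by the asserts in __call__ is unavailable.
    text_only = UnitYGenerator(model, text_tokenizer, target_lang="eng")

    # Full setup: supplying a unit tokenizer builds the T2U Seq2SeqGenerator
    # exactly as before the change.
    full = UnitYGenerator(
        model, text_tokenizer, "eng", unit_tokenizer=unit_tokenizer
    )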