|
@@ -383,7 +383,7 @@ class UnitYNART2UModel(Module):
|
|
|
text_seqs: Optional[Tensor],
|
|
|
duration_factor: float = 1.0,
|
|
|
film_cond_emb: Optional[Tensor] = None,
|
|
|
- ) -> Tuple[SequenceModelOutput, Optional[PaddingMask]]:
|
|
|
+ ) -> Tuple[SequenceModelOutput, Optional[PaddingMask], Tensor]:
|
|
|
encoder_output, encoder_padding_mask = self.encode(
|
|
|
text_decoder_output, text_decoder_padding_mask
|
|
|
)
|
|
@@ -391,7 +391,7 @@ class UnitYNART2UModel(Module):
|
|
|
if self.prosody_proj is not None and film_cond_emb is not None:
|
|
|
encoder_output = encoder_output + self.prosody_proj(film_cond_emb)
|
|
|
|
|
|
- decoder_output, decoder_padding_mask = self.decode(
|
|
|
+ decoder_output, decoder_padding_mask, durations = self.decode(
|
|
|
encoder_output,
|
|
|
encoder_padding_mask,
|
|
|
text_seqs,
|
|
@@ -399,7 +399,7 @@ class UnitYNART2UModel(Module):
|
|
|
film_cond_emb,
|
|
|
)
|
|
|
|
|
|
- return self.project(decoder_output), decoder_padding_mask
|
|
|
+ return self.project(decoder_output), decoder_padding_mask, durations
|
|
|
|
|
|
def encode(
|
|
|
self,
|
|
@@ -418,10 +418,10 @@ class UnitYNART2UModel(Module):
|
|
|
text_seqs: Optional[Tensor],
|
|
|
duration_factor: float = 1.0,
|
|
|
film_cond_emb: Optional[Tensor] = None,
|
|
|
- ) -> Tuple[Tensor, Optional[PaddingMask]]:
|
|
|
+ ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
|
|
|
# encoder_output: (N, S, M)
|
|
|
# text_seqs: (N, S)
|
|
|
- seqs, padding_mask = self.decoder_frontend(
|
|
|
+ seqs, padding_mask, durations = self.decoder_frontend(
|
|
|
encoder_output,
|
|
|
encoder_padding_mask,
|
|
|
text_seqs,
|
|
@@ -429,7 +429,11 @@ class UnitYNART2UModel(Module):
|
|
|
film_cond_emb,
|
|
|
)
|
|
|
|
|
|
- return self.decoder(seqs, padding_mask, film_cond_emb=film_cond_emb) # type: ignore[no-any-return]
|
|
|
+ seqs, padding_mask = self.decoder(
|
|
|
+ seqs, padding_mask, film_cond_emb=film_cond_emb
|
|
|
+ )
|
|
|
+
|
|
|
+ return seqs, padding_mask, durations # type: ignore[no-any-return]
|
|
|
|
|
|
def project(self, decoder_output: Tensor) -> SequenceModelOutput:
|
|
|
logits = self.final_proj(decoder_output)
|