|
@@ -39,8 +39,8 @@ class UnitYModel(EncoderDecoderModel):
|
|
|
speech_encoder: TransformerEncoder
|
|
|
text_encoder_frontend: Optional[TransformerFrontend]
|
|
|
text_encoder: Optional[TransformerEncoder]
|
|
|
- text_decoder_frontend: TransformerFrontend
|
|
|
- text_decoder: TransformerDecoder
|
|
|
+ text_decoder_frontend: Optional[TransformerFrontend]
|
|
|
+ text_decoder: Optional[TransformerDecoder]
|
|
|
final_proj: Projection
|
|
|
t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None]
|
|
|
prosody_encoder_model: Optional[ECAPA_TDNN]
|
|
@@ -51,8 +51,8 @@ class UnitYModel(EncoderDecoderModel):
|
|
|
speech_encoder: TransformerEncoder,
|
|
|
text_encoder_frontend: Optional[TransformerFrontend],
|
|
|
text_encoder: Optional[TransformerEncoder],
|
|
|
- text_decoder_frontend: TransformerFrontend,
|
|
|
- text_decoder: TransformerDecoder,
|
|
|
+ text_decoder_frontend: Optional[TransformerFrontend],
|
|
|
+ text_decoder: Optional[TransformerDecoder],
|
|
|
final_proj: Projection,
|
|
|
t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None],
|
|
|
target_vocab_info: VocabularyInfo,
|
|
@@ -85,8 +85,22 @@ class UnitYModel(EncoderDecoderModel):
|
|
|
self.register_module("text_encoder_frontend", None)
|
|
|
self.register_module("text_encoder", None)
|
|
|
|
|
|
- self.text_decoder_frontend = text_decoder_frontend
|
|
|
- self.text_decoder = text_decoder
|
|
|
+ if text_decoder is not None:
|
|
|
+ if text_decoder_frontend is None:
|
|
|
+ raise ValueError(
|
|
|
+ "Both `text_decoder` and `text_decoder_frontend` must be specified, but `text_decoder_frontend` is `None`."
|
|
|
+ )
|
|
|
+
|
|
|
+ self.text_decoder_frontend = text_decoder_frontend
|
|
|
+ self.text_decoder = text_decoder
|
|
|
+ else:
|
|
|
+ if text_decoder_frontend is not None:
|
|
|
+ raise ValueError(
|
|
|
+ "Both `text_decoder` and `text_decoder_frontend` must be specified, but `text_decoder` is `None`."
|
|
|
+ )
|
|
|
+
|
|
|
+ self.register_module("text_decoder_frontend", None)
|
|
|
+ self.register_module("text_decoder", None)
|
|
|
|
|
|
self.final_proj = final_proj
|
|
|
|
|
@@ -146,6 +160,13 @@ class UnitYModel(EncoderDecoderModel):
|
|
|
*,
|
|
|
state_bag: Optional[IncrementalStateBag] = None,
|
|
|
) -> Tuple[Tensor, Optional[PaddingMask]]:
|
|
|
+ if self.text_decoder is None:
|
|
|
+ raise ValueError(
|
|
|
+ "`decode()` requires a text decoder, but the current UnitY model does not have one."
|
|
|
+ )
|
|
|
+
|
|
|
+ assert self.text_decoder_frontend is not None
|
|
|
+
|
|
|
seqs, padding_mask = self.text_decoder_frontend(
|
|
|
seqs, padding_mask, state_bag=state_bag
|
|
|
)
|