|
@@ -138,10 +138,16 @@ class UnitYModel(EncoderDecoderModel):
|
|
|
encoder_padding_mask: Optional[Tensor],
|
|
|
state_bag: Optional[IncrementalStateBag] = None,
|
|
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
|
|
- seqs, padding_mask = self.text_decoder_frontend(seqs, seq_lens, state_bag)
|
|
|
+ seqs, padding_mask = self.text_decoder_frontend(
|
|
|
+ seqs, seq_lens, state_bag=state_bag
|
|
|
+ )
|
|
|
|
|
|
return self.text_decoder( # type: ignore[no-any-return]
|
|
|
- seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
|
|
|
+ seqs,
|
|
|
+ padding_mask,
|
|
|
+ encoder_output,
|
|
|
+ encoder_padding_mask,
|
|
|
+ state_bag=state_bag,
|
|
|
)
|
|
|
|
|
|
@finaloverride
|
|
@@ -199,10 +205,14 @@ class UnitYX2TModel(EncoderDecoderModel):
|
|
|
encoder_padding_mask: Optional[Tensor],
|
|
|
state_bag: Optional[IncrementalStateBag] = None,
|
|
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
|
|
- seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag)
|
|
|
+ seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag=state_bag)
|
|
|
|
|
|
return self.decoder( # type: ignore[no-any-return]
|
|
|
- seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
|
|
|
+ seqs,
|
|
|
+ padding_mask,
|
|
|
+ encoder_output,
|
|
|
+ encoder_padding_mask,
|
|
|
+ state_bag=state_bag,
|
|
|
)
|
|
|
|
|
|
@finaloverride
|
|
@@ -292,10 +302,14 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
|
|
|
encoder_padding_mask: Optional[Tensor],
|
|
|
state_bag: Optional[IncrementalStateBag] = None,
|
|
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
|
|
- seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag)
|
|
|
+ seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag=state_bag)
|
|
|
|
|
|
return self.decoder( # type: ignore[no-any-return]
|
|
|
- seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
|
|
|
+ seqs,
|
|
|
+ padding_mask,
|
|
|
+ encoder_output,
|
|
|
+ encoder_padding_mask,
|
|
|
+ state_bag=state_bag,
|
|
|
)
|
|
|
|
|
|
def project(
|