|
@@ -379,7 +379,7 @@ class UnitYNART2UModel(Module):
|
|
|
target_seqs: Optional[Tensor],
|
|
target_seqs: Optional[Tensor],
|
|
|
target_seq_lens: Optional[Tensor],
|
|
target_seq_lens: Optional[Tensor],
|
|
|
text_seqs: Optional[Tensor],
|
|
text_seqs: Optional[Tensor],
|
|
|
- ) -> SequenceModelOutput:
|
|
|
|
|
|
|
+ ) -> Tuple[SequenceModelOutput, Optional[Tensor]]:
|
|
|
encoder_output, encoder_padding_mask = self.encode(
|
|
encoder_output, encoder_padding_mask = self.encode(
|
|
|
text_decoder_output, text_decoder_padding_mask
|
|
text_decoder_output, text_decoder_padding_mask
|
|
|
)
|
|
)
|
|
@@ -392,7 +392,7 @@ class UnitYNART2UModel(Module):
|
|
|
text_seqs,
|
|
text_seqs,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- return self.project(decoder_output, decoder_padding_mask)
|
|
|
|
|
|
|
+ return self.project(decoder_output), decoder_padding_mask
|
|
|
|
|
|
|
|
def encode(
|
|
def encode(
|
|
|
self,
|
|
self,
|
|
@@ -420,9 +420,7 @@ class UnitYNART2UModel(Module):
|
|
|
|
|
|
|
|
return self.decoder(seqs, padding_mask) # type: ignore[no-any-return]
|
|
return self.decoder(seqs, padding_mask) # type: ignore[no-any-return]
|
|
|
|
|
|
|
|
- def project(
|
|
|
|
|
- self, decoder_output: Tensor, decoder_padding_mask: Optional[Tensor]
|
|
|
|
|
- ) -> SequenceModelOutput:
|
|
|
|
|
|
|
+ def project(self, decoder_output: Tensor) -> SequenceModelOutput:
|
|
|
logits = self.final_proj(decoder_output)
|
|
logits = self.final_proj(decoder_output)
|
|
|
|
|
|
|
|
return SequenceModelOutput(logits, self.pad_idx)
|
|
return SequenceModelOutput(logits, self.pad_idx)
|