@@ -11,6 +11,7 @@ from fairseq2.models.encoder_decoder import EncoderDecoderModel, Seq2SeqDecoder
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.models.transformer.frontend import TransformerFrontend
 from fairseq2.nn.incremental_state import IncrementalStateBag
+from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.projection import Projection
 from fairseq2.nn.transformer import TransformerDecoder, TransformerEncoder
 from fairseq2.nn.utils.module import check_model_dim
@@ -97,28 +98,28 @@ class UnitYModel(EncoderDecoderModel):

     @finaloverride
     def encode(
-        self, seqs: Tensor, seq_lens: Optional[Tensor]
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         if self.input_modality == "speech":
-            return self.encode_speech(seqs, seq_lens)
+            return self.encode_speech(seqs, padding_mask)

         if self.input_modality == "text":
-            return self.encode_text(seqs, seq_lens)
+            return self.encode_text(seqs, padding_mask)

         raise RuntimeError(
             f"`input_modality` must be 'speech' or 'text', but is '{self.input_modality}' instead."
         )

     def encode_speech(
-        self, seqs: Tensor, seq_lens: Optional[Tensor]
-    ) -> Tuple[Tensor, Optional[Tensor]]:
-        seqs, padding_mask = self.speech_encoder_frontend(seqs, seq_lens)
+        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
+        seqs, padding_mask = self.speech_encoder_frontend(seqs, padding_mask)

         return self.speech_encoder(seqs, padding_mask)  # type: ignore[no-any-return]

     def encode_text(
-        self, seqs: Tensor, seq_lens: Optional[Tensor]
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         if self.text_encoder is None:
             raise ValueError(
                 "`encode_text()` requires a text encoder, but the current UnitY model does not have one."
@@ -126,7 +127,7 @@ class UnitYModel(EncoderDecoderModel):

         assert self.text_encoder_frontend is not None

-        seqs, padding_mask = self.text_encoder_frontend(seqs, seq_lens)
+        seqs, padding_mask = self.text_encoder_frontend(seqs, padding_mask)

         return self.text_encoder(seqs, padding_mask)  # type: ignore[no-any-return]

@@ -134,13 +135,13 @@ class UnitYModel(EncoderDecoderModel):
     def decode(
         self,
         seqs: Tensor,
-        seq_lens: Optional[Tensor],
+        padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
-        encoder_padding_mask: Optional[Tensor],
+        encoder_padding_mask: Optional[PaddingMask],
         state_bag: Optional[IncrementalStateBag] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs, padding_mask = self.text_decoder_frontend(
-            seqs, seq_lens, state_bag=state_bag
+            seqs, padding_mask, state_bag=state_bag
         )

         return self.text_decoder(  # type: ignore[no-any-return]
@@ -153,7 +154,7 @@ class UnitYModel(EncoderDecoderModel):

     @finaloverride
     def project(
-        self, decoder_output: Tensor, decoder_padding_mask: Optional[Tensor]
+        self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask]
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)

@@ -192,21 +193,23 @@ class UnitYX2TModel(EncoderDecoderModel):

     @finaloverride
     def encode(
-        self, seqs: Tensor, seq_lens: Optional[Tensor]
-    ) -> Tuple[Tensor, Optional[Tensor]]:
-        seqs, padding_mask = self.encoder_frontend(seqs, seq_lens)
+        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
+        seqs, padding_mask = self.encoder_frontend(seqs, padding_mask)
         return self.encoder(seqs, padding_mask)  # type: ignore[no-any-return]

     @finaloverride
     def decode(
         self,
         seqs: Tensor,
-        seq_lens: Optional[Tensor],
+        padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
-        encoder_padding_mask: Optional[Tensor],
+        encoder_padding_mask: Optional[PaddingMask],
         state_bag: Optional[IncrementalStateBag] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
-        seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag=state_bag)
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
+        seqs, padding_mask = self.decoder_frontend(
+            seqs, padding_mask, state_bag=state_bag
+        )

         return self.decoder(  # type: ignore[no-any-return]
             seqs,
@@ -218,7 +221,7 @@ class UnitYX2TModel(EncoderDecoderModel):

     @finaloverride
     def project(
-        self, decoder_output: Tensor, decoder_padding_mask: Optional[Tensor]
+        self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask]
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)

@@ -274,9 +277,9 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
     def forward(
         self,
         text_decoder_output: Tensor,
-        text_decoder_padding_mask: Optional[Tensor],
+        text_decoder_padding_mask: Optional[PaddingMask],
         target_seqs: Tensor,
-        target_seq_lens: Optional[Tensor],
+        target_padding_mask: Optional[PaddingMask],
     ) -> SequenceModelOutput:
         encoder_output, encoder_padding_mask = self.encode(
             text_decoder_output, text_decoder_padding_mask
@@ -284,7 +287,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):

         decoder_output, decoder_padding_mask = self.decode(
             target_seqs,
-            target_seq_lens,
+            target_padding_mask,
             encoder_output,
             encoder_padding_mask,
         )
@@ -294,8 +297,8 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
     def encode(
         self,
         text_decoder_output: Tensor,
-        text_decoder_padding_mask: Optional[Tensor],
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+        text_decoder_padding_mask: Optional[PaddingMask],
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         if self.encoder is None:
             return text_decoder_output, text_decoder_padding_mask

@@ -304,12 +307,14 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
     def decode(
         self,
         seqs: Tensor,
-        seq_lens: Optional[Tensor],
+        padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
-        encoder_padding_mask: Optional[Tensor],
+        encoder_padding_mask: Optional[PaddingMask],
         state_bag: Optional[IncrementalStateBag] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
-        seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag=state_bag)
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
+        seqs, padding_mask = self.decoder_frontend(
+            seqs, padding_mask, state_bag=state_bag
+        )

         return self.decoder(  # type: ignore[no-any-return]
             seqs,
@@ -320,7 +325,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
         )

     def project(
-        self, decoder_output: Tensor, decoder_padding_mask: Optional[Tensor]
+        self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask]
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)

@@ -375,18 +380,18 @@ class UnitYNART2UModel(Module):
     def forward(
         self,
         text_decoder_output: Tensor,
-        text_decoder_padding_mask: Optional[Tensor],
+        text_decoder_padding_mask: Optional[PaddingMask],
         target_seqs: Optional[Tensor],
-        target_seq_lens: Optional[Tensor],
+        target_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
-    ) -> Tuple[SequenceModelOutput, Optional[Tensor]]:
+    ) -> Tuple[SequenceModelOutput, Optional[PaddingMask]]:
         encoder_output, encoder_padding_mask = self.encode(
             text_decoder_output, text_decoder_padding_mask
         )

         decoder_output, decoder_padding_mask = self.decode(
             target_seqs,
-            target_seq_lens,
+            target_padding_mask,
             encoder_output,
             encoder_padding_mask,
             text_seqs,
@@ -397,8 +402,8 @@ class UnitYNART2UModel(Module):
     def encode(
         self,
         text_decoder_output: Tensor,
-        text_decoder_padding_mask: Optional[Tensor],
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+        text_decoder_padding_mask: Optional[PaddingMask],
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         if self.encoder is None:
             return text_decoder_output, text_decoder_padding_mask

@@ -407,15 +412,15 @@ class UnitYNART2UModel(Module):
     def decode(
         self,
         seqs: Optional[Tensor],
-        seq_lens: Optional[Tensor],
+        padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
-        encoder_padding_mask: Optional[Tensor],
+        encoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         # encoder_output: (N, S, M)
         # text_seqs: (N, S)
         seqs, padding_mask = self.decoder_frontend(
-            seqs, seq_lens, encoder_output, encoder_padding_mask, text_seqs
+            seqs, padding_mask, encoder_output, encoder_padding_mask, text_seqs
         )

         return self.decoder(seqs, padding_mask)  # type: ignore[no-any-return]
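
Call-site note (not part of the diff above): with this change, callers stop passing raw `seq_lens` tensors and instead hand the model a `fairseq2.nn.padding.PaddingMask`. The sketch below is a minimal illustration under two assumptions that the diff itself does not establish: that `PaddingMask` is constructed from `(seq_lens, batch_seq_len)`, and that `model` is an already-loaded UnitY-style model exposing the new `encode(seqs, padding_mask)` signature.

# Illustrative sketch only; not part of this diff.
# Assumptions: fairseq2's PaddingMask is constructed as
# PaddingMask(seq_lens, batch_seq_len), and `model` is a loaded UnitY-style
# model with the new encode(seqs, padding_mask) signature.
from typing import Optional

from torch import Tensor

from fairseq2.nn.padding import PaddingMask


def encode_batch(model, seqs: Tensor, seq_lens: Optional[Tensor]):
    # Old call style: model.encode(seqs, seq_lens)
    # New call style: wrap the lengths in a PaddingMask keyed to the padded
    # (batch) sequence length, then pass the mask instead of the lengths.
    if seq_lens is None:
        padding_mask = None
    else:
        padding_mask = PaddingMask(seq_lens, seqs.size(1))

    return model.encode(seqs, padding_mask)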