|  | @@ -11,6 +11,7 @@ from fairseq2.models.encoder_decoder import EncoderDecoderModel, Seq2SeqDecoder
 | 
	
		
			
				|  |  |  from fairseq2.models.sequence import SequenceModelOutput
 | 
	
		
			
				|  |  |  from fairseq2.models.transformer.frontend import TransformerFrontend
 | 
	
		
			
				|  |  |  from fairseq2.nn.incremental_state import IncrementalStateBag
 | 
	
		
			
				|  |  | +from fairseq2.nn.padding import PaddingMask
 | 
	
		
			
				|  |  |  from fairseq2.nn.projection import Projection
 | 
	
		
			
				|  |  |  from fairseq2.nn.transformer import TransformerDecoder, TransformerEncoder
 | 
	
		
			
				|  |  |  from fairseq2.nn.utils.module import check_model_dim
 | 
	
	
		
			
				|  | @@ -97,28 +98,28 @@ class UnitYModel(EncoderDecoderModel):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @finaloverride
 | 
	
		
			
				|  |  |      def encode(
 | 
	
		
			
				|  |  | -        self, seqs: Tensor, seq_lens: Optional[Tensor]
 | 
	
		
			
				|  |  | -    ) -> Tuple[Tensor, Optional[Tensor]]:
 | 
	
		
			
				|  |  | +        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
 | 
	
		
			
				|  |  | +    ) -> Tuple[Tensor, Optional[PaddingMask]]:
 | 
	
		
			
				|  |  |          if self.input_modality == "speech":
 | 
	
		
			
				|  |  | -            return self.encode_speech(seqs, seq_lens)
 | 
	
		
			
				|  |  | +            return self.encode_speech(seqs, padding_mask)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          if self.input_modality == "text":
 | 
	
		
			
				|  |  | -            return self.encode_text(seqs, seq_lens)
 | 
	
		
			
				|  |  | +            return self.encode_text(seqs, padding_mask)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          raise RuntimeError(
 | 
	
		
			
				|  |  |              f"`input_modality` must be 'speech' or 'text', but is '{self.input_modality}' instead."
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def encode_speech(
 | 
	
		
			
				|  |  | -        self, seqs: Tensor, seq_lens: Optional[Tensor]
 | 
	
		
			
				|  |  | -    ) -> Tuple[Tensor, Optional[Tensor]]:
 | 
	
		
			
				|  |  | -        seqs, padding_mask = self.speech_encoder_frontend(seqs, seq_lens)
 | 
	
		
			
				|  |  | +        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
 | 
	
		
			
				|  |  | +    ) -> Tuple[Tensor, Optional[PaddingMask]]:
 | 
	
		
			
				|  |  | +        seqs, padding_mask = self.speech_encoder_frontend(seqs, padding_mask)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          return self.speech_encoder(seqs, padding_mask)  # type: ignore[no-any-return]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def encode_text(
 | 
	
		
			
				|  |  | -        self, seqs: Tensor, seq_lens: Optional[Tensor]
 | 
	
		
			
				|  |  | -    ) -> Tuple[Tensor, Optional[Tensor]]:
 | 
	
		
			
				|  |  | +        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
 | 
	
		
			
				|  |  | +    ) -> Tuple[Tensor, Optional[PaddingMask]]:
 | 
	
		
			
				|  |  |          if self.text_encoder is None:
 | 
	
		
			
				|  |  |              raise ValueError(
 | 
	
		
			
				|  |  |                  "`encode_text()` requires a text encoder, but the current UnitY model does not have one."
 | 
	
	
		
			
				|  | @@ -126,7 +127,7 @@ class UnitYModel(EncoderDecoderModel):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          assert self.text_encoder_frontend is not None
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        seqs, padding_mask = self.text_encoder_frontend(seqs, seq_lens)
 | 
	
		
			
				|  |  | +        seqs, padding_mask = self.text_encoder_frontend(seqs, padding_mask)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          return self.text_encoder(seqs, padding_mask)  # type: ignore[no-any-return]
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -134,13 +135,13 @@ class UnitYModel(EncoderDecoderModel):
 | 
	
		
			
				|  |  |      def decode(
 | 
	
		
			
				|  |  |          self,
 | 
	
		
			
				|  |  |          seqs: Tensor,
 | 
	
		
			
				|  |  | -        seq_lens: Optional[Tensor],
 | 
	
		
			
				|  |  | +        padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  |          encoder_output: Tensor,
 | 
	
		
			
				|  |  | -        encoder_padding_mask: Optional[Tensor],
 | 
	
		
			
				|  |  | +        encoder_padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  |          state_bag: Optional[IncrementalStateBag] = None,
 | 
	
		
			
				|  |  | -    ) -> Tuple[Tensor, Optional[Tensor]]:
 | 
	
		
			
				|  |  | +    ) -> Tuple[Tensor, Optional[PaddingMask]]:
 | 
	
		
			
				|  |  |          seqs, padding_mask = self.text_decoder_frontend(
 | 
	
		
			
				|  |  | -            seqs, seq_lens, state_bag=state_bag
 | 
	
		
			
				|  |  | +            seqs, padding_mask, state_bag=state_bag
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          return self.text_decoder(  # type: ignore[no-any-return]
 | 
	
	
		
			
				|  | @@ -153,7 +154,7 @@ class UnitYModel(EncoderDecoderModel):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @finaloverride
 | 
	
		
			
				|  |  |      def project(
 | 
	
		
			
				|  |  | -        self, decoder_output: Tensor, decoder_padding_mask: Optional[Tensor]
 | 
	
		
			
				|  |  | +        self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask]
 | 
	
		
			
				|  |  |      ) -> SequenceModelOutput:
 | 
	
		
			
				|  |  |          logits = self.final_proj(decoder_output)
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -192,21 +193,23 @@ class UnitYX2TModel(EncoderDecoderModel):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @finaloverride
 | 
	
		
			
				|  |  |      def encode(
 | 
	
		
			
				|  |  | -        self, seqs: Tensor, seq_lens: Optional[Tensor]
 | 
	
		
			
				|  |  | -    ) -> Tuple[Tensor, Optional[Tensor]]:
 | 
	
		
			
				|  |  | -        seqs, padding_mask = self.encoder_frontend(seqs, seq_lens)
 | 
	
		
			
				|  |  | +        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
 | 
	
		
			
				|  |  | +    ) -> Tuple[Tensor, Optional[PaddingMask]]:
 | 
	
		
			
				|  |  | +        seqs, padding_mask = self.encoder_frontend(seqs, padding_mask)
 | 
	
		
			
				|  |  |          return self.encoder(seqs, padding_mask)  # type: ignore[no-any-return]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @finaloverride
 | 
	
		
			
				|  |  |      def decode(
 | 
	
		
			
				|  |  |          self,
 | 
	
		
			
				|  |  |          seqs: Tensor,
 | 
	
		
			
				|  |  | -        seq_lens: Optional[Tensor],
 | 
	
		
			
				|  |  | +        padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  |          encoder_output: Tensor,
 | 
	
		
			
				|  |  | -        encoder_padding_mask: Optional[Tensor],
 | 
	
		
			
				|  |  | +        encoder_padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  |          state_bag: Optional[IncrementalStateBag] = None,
 | 
	
		
			
				|  |  | -    ) -> Tuple[Tensor, Optional[Tensor]]:
 | 
	
		
			
				|  |  | -        seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag=state_bag)
 | 
	
		
			
				|  |  | +    ) -> Tuple[Tensor, Optional[PaddingMask]]:
 | 
	
		
			
				|  |  | +        seqs, padding_mask = self.decoder_frontend(
 | 
	
		
			
				|  |  | +            seqs, padding_mask, state_bag=state_bag
 | 
	
		
			
				|  |  | +        )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          return self.decoder(  # type: ignore[no-any-return]
 | 
	
		
			
				|  |  |              seqs,
 | 
	
	
		
			
				|  | @@ -218,7 +221,7 @@ class UnitYX2TModel(EncoderDecoderModel):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @finaloverride
 | 
	
		
			
				|  |  |      def project(
 | 
	
		
			
				|  |  | -        self, decoder_output: Tensor, decoder_padding_mask: Optional[Tensor]
 | 
	
		
			
				|  |  | +        self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask]
 | 
	
		
			
				|  |  |      ) -> SequenceModelOutput:
 | 
	
		
			
				|  |  |          logits = self.final_proj(decoder_output)
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -274,9 +277,9 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
 | 
	
		
			
				|  |  |      def forward(
 | 
	
		
			
				|  |  |          self,
 | 
	
		
			
				|  |  |          text_decoder_output: Tensor,
 | 
	
		
			
				|  |  | -        text_decoder_padding_mask: Optional[Tensor],
 | 
	
		
			
				|  |  | +        text_decoder_padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  |          target_seqs: Tensor,
 | 
	
		
			
				|  |  | -        target_seq_lens: Optional[Tensor],
 | 
	
		
			
				|  |  | +        target_padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  |      ) -> SequenceModelOutput:
 | 
	
		
			
				|  |  |          encoder_output, encoder_padding_mask = self.encode(
 | 
	
		
			
				|  |  |              text_decoder_output, text_decoder_padding_mask
 | 
	
	
		
			
				|  | @@ -284,7 +287,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          decoder_output, decoder_padding_mask = self.decode(
 | 
	
		
			
				|  |  |              target_seqs,
 | 
	
		
			
				|  |  | -            target_seq_lens,
 | 
	
		
			
				|  |  | +            target_padding_mask,
 | 
	
		
			
				|  |  |              encoder_output,
 | 
	
		
			
				|  |  |              encoder_padding_mask,
 | 
	
		
			
				|  |  |          )
 | 
	
	
		
			
				|  | @@ -294,8 +297,8 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
 | 
	
		
			
				|  |  |      def encode(
 | 
	
		
			
				|  |  |          self,
 | 
	
		
			
				|  |  |          text_decoder_output: Tensor,
 | 
	
		
			
				|  |  | -        text_decoder_padding_mask: Optional[Tensor],
 | 
	
		
			
				|  |  | -    ) -> Tuple[Tensor, Optional[Tensor]]:
 | 
	
		
			
				|  |  | +        text_decoder_padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  | +    ) -> Tuple[Tensor, Optional[PaddingMask]]:
 | 
	
		
			
				|  |  |          if self.encoder is None:
 | 
	
		
			
				|  |  |              return text_decoder_output, text_decoder_padding_mask
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -304,12 +307,14 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
 | 
	
		
			
				|  |  |      def decode(
 | 
	
		
			
				|  |  |          self,
 | 
	
		
			
				|  |  |          seqs: Tensor,
 | 
	
		
			
				|  |  | -        seq_lens: Optional[Tensor],
 | 
	
		
			
				|  |  | +        padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  |          encoder_output: Tensor,
 | 
	
		
			
				|  |  | -        encoder_padding_mask: Optional[Tensor],
 | 
	
		
			
				|  |  | +        encoder_padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  |          state_bag: Optional[IncrementalStateBag] = None,
 | 
	
		
			
				|  |  | -    ) -> Tuple[Tensor, Optional[Tensor]]:
 | 
	
		
			
				|  |  | -        seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag=state_bag)
 | 
	
		
			
				|  |  | +    ) -> Tuple[Tensor, Optional[PaddingMask]]:
 | 
	
		
			
				|  |  | +        seqs, padding_mask = self.decoder_frontend(
 | 
	
		
			
				|  |  | +            seqs, padding_mask, state_bag=state_bag
 | 
	
		
			
				|  |  | +        )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          return self.decoder(  # type: ignore[no-any-return]
 | 
	
		
			
				|  |  |              seqs,
 | 
	
	
		
			
				|  | @@ -320,7 +325,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def project(
 | 
	
		
			
				|  |  | -        self, decoder_output: Tensor, decoder_padding_mask: Optional[Tensor]
 | 
	
		
			
				|  |  | +        self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask]
 | 
	
		
			
				|  |  |      ) -> SequenceModelOutput:
 | 
	
		
			
				|  |  |          logits = self.final_proj(decoder_output)
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -375,18 +380,18 @@ class UnitYNART2UModel(Module):
 | 
	
		
			
				|  |  |      def forward(
 | 
	
		
			
				|  |  |          self,
 | 
	
		
			
				|  |  |          text_decoder_output: Tensor,
 | 
	
		
			
				|  |  | -        text_decoder_padding_mask: Optional[Tensor],
 | 
	
		
			
				|  |  | +        text_decoder_padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  |          target_seqs: Optional[Tensor],
 | 
	
		
			
				|  |  | -        target_seq_lens: Optional[Tensor],
 | 
	
		
			
				|  |  | +        target_padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  |          text_seqs: Optional[Tensor],
 | 
	
		
			
				|  |  | -    ) -> Tuple[SequenceModelOutput, Optional[Tensor]]:
 | 
	
		
			
				|  |  | +    ) -> Tuple[SequenceModelOutput, Optional[PaddingMask]]:
 | 
	
		
			
				|  |  |          encoder_output, encoder_padding_mask = self.encode(
 | 
	
		
			
				|  |  |              text_decoder_output, text_decoder_padding_mask
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          decoder_output, decoder_padding_mask = self.decode(
 | 
	
		
			
				|  |  |              target_seqs,
 | 
	
		
			
				|  |  | -            target_seq_lens,
 | 
	
		
			
				|  |  | +            target_padding_mask,
 | 
	
		
			
				|  |  |              encoder_output,
 | 
	
		
			
				|  |  |              encoder_padding_mask,
 | 
	
		
			
				|  |  |              text_seqs,
 | 
	
	
		
			
				|  | @@ -397,8 +402,8 @@ class UnitYNART2UModel(Module):
 | 
	
		
			
				|  |  |      def encode(
 | 
	
		
			
				|  |  |          self,
 | 
	
		
			
				|  |  |          text_decoder_output: Tensor,
 | 
	
		
			
				|  |  | -        text_decoder_padding_mask: Optional[Tensor],
 | 
	
		
			
				|  |  | -    ) -> Tuple[Tensor, Optional[Tensor]]:
 | 
	
		
			
				|  |  | +        text_decoder_padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  | +    ) -> Tuple[Tensor, Optional[PaddingMask]]:
 | 
	
		
			
				|  |  |          if self.encoder is None:
 | 
	
		
			
				|  |  |              return text_decoder_output, text_decoder_padding_mask
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -407,15 +412,15 @@ class UnitYNART2UModel(Module):
 | 
	
		
			
				|  |  |      def decode(
 | 
	
		
			
				|  |  |          self,
 | 
	
		
			
				|  |  |          seqs: Optional[Tensor],
 | 
	
		
			
				|  |  | -        seq_lens: Optional[Tensor],
 | 
	
		
			
				|  |  | +        padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  |          encoder_output: Tensor,
 | 
	
		
			
				|  |  | -        encoder_padding_mask: Optional[Tensor],
 | 
	
		
			
				|  |  | +        encoder_padding_mask: Optional[PaddingMask],
 | 
	
		
			
				|  |  |          text_seqs: Optional[Tensor],
 | 
	
		
			
				|  |  | -    ) -> Tuple[Tensor, Optional[Tensor]]:
 | 
	
		
			
				|  |  | +    ) -> Tuple[Tensor, Optional[PaddingMask]]:
 | 
	
		
			
				|  |  |          # encoder_output: (N, S, M)
 | 
	
		
			
				|  |  |          # text_seqs: (N, S)
 | 
	
		
			
				|  |  |          seqs, padding_mask = self.decoder_frontend(
 | 
	
		
			
				|  |  | -            seqs, seq_lens, encoder_output, encoder_padding_mask, text_seqs
 | 
	
		
			
				|  |  | +            seqs, padding_mask, encoder_output, encoder_padding_mask, text_seqs
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          return self.decoder(seqs, padding_mask)  # type: ignore[no-any-return]
 |