
Start using the new PaddingMask type (#49)

* Start using the new PaddingMask type

* Fix

* Fix 2

* Fix 3
Can Balioglu committed 1 year ago · commit 71d26b65ec
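The commit migrates every call site from raw `seq_lens` tensors to fairseq2's `PaddingMask`. A minimal sketch of the pattern that repeats in the hunks below, using only the constructor and the `seq_lens` accessor that the diffs themselves show (tensor values are illustrative):

```python
import torch

from fairseq2.nn.padding import PaddingMask

seqs = torch.zeros(4, 10, 512)          # (N, S, M): a padded batch
seq_lens = torch.tensor([10, 7, 9, 3])  # true length of each row

# Before: APIs took `seqs` plus the raw `seq_lens` tensor.
# After: they take a PaddingMask built from the lengths and the padded
# (batch-wide) sequence length.
padding_mask = PaddingMask(seq_lens, batch_seq_len=seqs.size(1))

# The lengths remain accessible, e.g. for the stride arithmetic in
# adaptor_block.py below.
assert torch.equal(padding_mask.seq_lens, seq_lens)
```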

+ 11 - 6
scripts/m4t/finetune/trainer.py

@@ -91,14 +91,17 @@ class UnitYFinetuneWrapper(nn.Module):
         dummy_context = contextmanager(lambda: iter([None]))()
         with torch.no_grad() if self.freeze_s2t else dummy_context:  # type:ignore
             assert batch.speech_to_text.src_tokens is not None
+            seqs = batch.speech_to_text.src_tokens.to(self.device)
+            seq_lens = batch.speech_to_text.src_lengths.to(self.device)
             speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
-                seqs=batch.speech_to_text.src_tokens.to(self.device),
-                seq_lens=batch.speech_to_text.src_lengths.to(self.device),
+                seqs=seqs, padding_mask=PaddingMask(seq_lens, seqs.size(1))
             )
             assert batch.speech_to_text.prev_output_tokens is not None
+            seqs = batch.speech_to_text.prev_output_tokens.to(self.device)
+            seq_lens = batch.speech_to_text.target_lengths.to(self.device)
             text_decoder_out, text_decoder_padding_mask = self.model.decode(
-                seqs=batch.speech_to_text.prev_output_tokens.to(self.device),
-                seq_lens=batch.speech_to_text.target_lengths.to(self.device),
+                seqs=seqs,
+                padding_mask=PaddingMask(seq_lens, seqs.size(1)),
                 encoder_output=speech_encoder_out,
                 encoder_padding_mask=speech_encoder_padding_mask,
             )
@@ -114,9 +117,11 @@ class UnitYFinetuneWrapper(nn.Module):
                 text_decoder_output=text_decoder_out,
                 text_decoder_padding_mask=text_decoder_padding_mask,
             )
+            seqs = batch.text_to_units.prev_output_tokens.to(self.device)
+            seq_lens = batch.text_to_units.target_lengths.to(self.device)
             unit_decoder_out, _ = self.model.t2u_model.decode(
-                seqs=batch.text_to_units.prev_output_tokens.to(self.device),
-                seq_lens=batch.text_to_units.target_lengths.to(self.device),
+                seqs=seqs,
+                padding_mask=PaddingMask(seq_lens, seqs.size(1)),
                 encoder_output=unit_encoder_out,
                 encoder_padding_mask=unit_encoder_padding_mask,
             )
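All three trainer hunks share one shape: move the tokens and lengths to the device first, then wrap the lengths so `seqs.size(1)` can serve as `batch_seq_len`. A hedged sketch of that pattern as a standalone helper (the helper name is hypothetical, not part of the commit):

```python
from typing import Tuple

import torch
from torch import Tensor

from fairseq2.nn.padding import PaddingMask


def to_device_with_mask(
    tokens: Tensor, lengths: Tensor, device: torch.device
) -> Tuple[Tensor, PaddingMask]:
    """Hypothetical helper mirroring the repeated pattern above."""
    seqs = tokens.to(device)
    seq_lens = lengths.to(device)

    # `batch_seq_len` is the padded length, i.e. dim 1 of the batch.
    return seqs, PaddingMask(seq_lens, batch_seq_len=seqs.size(1))
```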

+ 6 - 4
src/seamless_communication/models/inference/translator.py

@@ -9,12 +9,13 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from fairseq2.assets.card import AssetCard
-from fairseq2.data import Collater
+from fairseq2.data import Collater, SequenceData
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.text.text_tokenizer import TextTokenizer
 from fairseq2.data.typing import StringLike
 from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
 from fairseq2.memory import MemoryBlock
+from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from torch import Tensor
 from enum import Enum, auto
@@ -100,7 +101,7 @@ class Translator(nn.Module):
         model: UnitYModel,
         text_tokenizer: TextTokenizer,
         unit_tokenizer: UnitTokenizer,
-        src: Dict[str, Tensor],
+        src: SequenceData,
         input_modality: Modality,
         output_modality: Modality,
         tgt_lang: str,
@@ -139,9 +140,10 @@ class Translator(nn.Module):
             text_opts=text_opts,
             unit_opts=unit_opts,
         )
+        seqs, padding_mask = get_seqs_and_padding_mask(src)
         return generator(
-            src["seqs"],
-            src["seq_lens"],
+            seqs,
+            padding_mask,
             input_modality.value,
             output_modality.value,
             ngram_filtering=ngram_filtering,
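On the caller side, `get_seqs_and_padding_mask` replaces the manual `src["seqs"]`/`src["seq_lens"]` indexing. A sketch, assuming `SequenceData` is the usual `Collater` output with `seqs`, `seq_lens`, and `is_ragged` fields:

```python
import torch

from fairseq2.nn.padding import get_seqs_and_padding_mask

# A Collater-style SequenceData dict, as the old code indexed it.
src = {
    "seqs": torch.zeros(2, 8, dtype=torch.int64),
    "seq_lens": torch.tensor([8, 5]),
    "is_ragged": True,
}

# Returns the batch plus an Optional[PaddingMask] built from the lengths.
seqs, padding_mask = get_seqs_and_padding_mask(src)
```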

+ 4 - 3
src/seamless_communication/models/unit_extraction/wav2vec2_layer_output.py

@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.transformer import TransformerEncoder, TransformerNormOrder
 from fairseq2.models.wav2vec2 import (
     Wav2Vec2EncoderConfig,
@@ -105,18 +106,18 @@ class Wav2Vec2LayerOutputModel(nn.Module):
         self.encoder = w2v2.encoder
 
     @torch.inference_mode()
-    def forward(self, batch: SequenceBatch, out_layer_idx: int):
+    def forward(self, batch: SequenceBatch, out_layer_idx: int) -> Tensor:
         """
         :param batch:
             The batch of sequences to process.
         """
-        seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.seq_lens)
+        seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.padding_mask)
         w2v2_layer_output = None
 
         def layer_output_hook(
             layer_idx: int,
             layer_output: Tensor,
-            layer_padding_mask: Optional[Tensor],
+            layer_padding_mask: Optional[PaddingMask],
             num_layers: int,
         ) -> bool:
             nonlocal w2v2_layer_output
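The hook signature now takes an `Optional[PaddingMask]` in place of the raw mask tensor. A minimal hook matching the updated signature (the capture-and-stop logic mirrors the surrounding model; the layer index is illustrative):

```python
from typing import List, Optional

from torch import Tensor

from fairseq2.nn.padding import PaddingMask

out_layer_idx = 6            # illustrative target layer
captured: List[Tensor] = []  # capture slot

def layer_output_hook(
    layer_idx: int,
    layer_output: Tensor,
    layer_padding_mask: Optional[PaddingMask],
    num_layers: int,
) -> bool:
    if layer_idx == out_layer_idx:
        captured.append(layer_output)
        return False  # stop iterating over the remaining layers
    return True
```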

+ 23 - 22
src/seamless_communication/models/unity/adaptor_block.py

@@ -10,17 +10,18 @@ import torch
 from fairseq2.models.conformer import ConformerBlock
 from fairseq2.nn.module_list import ModuleList
 from fairseq2.nn.normalization import LayerNorm
+from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.projection import Linear
 from fairseq2.nn.transformer import (
+    AttentionMask,
     EncoderLayerOutputHook,
     FeedForwardNetwork,
     LayerNormFactory,
     MultiheadAttention,
     TransformerEncoder,
     TransformerEncoderLayer,
-    create_default_layer_norm,
+    create_standard_layer_norm,
 )
-from fairseq2.nn.utils.mask import to_padding_mask
 from fairseq2.nn.utils.module import check_model_dim
 from fairseq2.typing import DataType, Device
 from overrides import final as finaloverride
@@ -66,7 +67,7 @@ class UnitYEncoderAdaptor(TransformerEncoder):
         super().__init__(model_dim)
 
         if layer_norm_factory is None:
-            layer_norm_factory = create_default_layer_norm
+            layer_norm_factory = create_standard_layer_norm
 
         self.inner = inner
 
@@ -99,10 +100,10 @@ class UnitYEncoderAdaptor(TransformerEncoder):
     def forward(
         self,
         seqs: Tensor,
-        padding_mask: Optional[Tensor],
+        padding_mask: Optional[PaddingMask],
         *,
         layer_output_hook: Optional[EncoderLayerOutputHook] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs, padding_mask = self.inner(
             seqs, padding_mask, layer_output_hook=layer_output_hook
         )
@@ -185,7 +186,7 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
         super().__init__(model_dim)
 
         if layer_norm_factory is None:
-            layer_norm_factory = create_default_layer_norm
+            layer_norm_factory = create_standard_layer_norm
 
         self.kernel_size = kernel_size
         self.stride = stride
@@ -240,9 +241,9 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
     def forward(
         self,
         seqs: Tensor,
-        padding_mask: Optional[Tensor],
-        self_attn_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+        padding_mask: Optional[PaddingMask],
+        self_attn_mask: Optional[AttentionMask] = None,
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs, padding_mask = self._forward_self_attn(seqs, padding_mask, self_attn_mask)
 
         seqs = self._forward_ffn(seqs)
@@ -252,9 +253,9 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
     def _forward_self_attn(
         self,
         seqs: Tensor,
-        padding_mask: Optional[Tensor],
-        self_attn_mask: Optional[Tensor],
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+        padding_mask: Optional[PaddingMask],
+        self_attn_mask: Optional[AttentionMask],
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         residual = self.residual_layer_norm(seqs)
 
         # Apply pooling to the residual to match the sequence length of the
@@ -292,9 +293,9 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
             seqs,
             padding_mask,
             keys=seqs,
+            key_padding_mask=padding_mask,
             values=seqs,
             attn_mask=self_attn_mask,
-            key_padding_mask=padding_mask,
         )
 
         if self.self_attn_dropout is not None:
@@ -366,7 +367,7 @@ class UnitYConformerAdaptorLayer(TransformerEncoderLayer):
         super().__init__(block.model_dim)
 
         if layer_norm_factory is None:
-            layer_norm_factory = create_default_layer_norm
+            layer_norm_factory = create_standard_layer_norm
 
         self.kernel_size = kernel_size
         self.stride = stride
@@ -394,9 +395,9 @@ class UnitYConformerAdaptorLayer(TransformerEncoderLayer):
     def forward(
         self,
         seqs: Tensor,
-        padding_mask: Optional[Tensor],
-        self_attn_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+        padding_mask: Optional[PaddingMask],
+        self_attn_mask: Optional[AttentionMask] = None,
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         if self.layer_norm is not None:
             seqs = self.layer_norm(seqs)
 
@@ -425,15 +426,15 @@ class UnitYConformerAdaptorLayer(TransformerEncoderLayer):
 
 
 def _compute_new_padding_mask(
-    seqs: Tensor, padding_mask: Optional[Tensor], kernel_size: int, stride: int
-) -> Optional[Tensor]:
+    seqs: Tensor, padding_mask: Optional[PaddingMask], kernel_size: int, stride: int
+) -> Optional[PaddingMask]:
     if padding_mask is None:
         return padding_mask
 
     pad = kernel_size // 2
 
-    seq_lens = padding_mask.size(1) - torch.nan_to_num(padding_mask, neginf=1.0).sum(1)
+    seq_lens = ((padding_mask.seq_lens + 2 * pad - kernel_size) / stride) + 1
 
-    seq_lens = ((seq_lens + 2 * pad - kernel_size) / stride) + 1
+    seq_lens = seq_lens.floor().to(torch.int64)
 
-    return to_padding_mask(seqs, seq_lens.floor())
+    return PaddingMask(seq_lens, batch_seq_len=seqs.size(1))
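The rewritten `_compute_new_padding_mask` applies the standard Conv1d output-length formula, `floor((L + 2·pad − kernel_size) / stride) + 1`, directly to `padding_mask.seq_lens` instead of first recovering the lengths from a float mask. A quick numeric check under the same `pad = kernel_size // 2` convention (values illustrative):

```python
import torch

kernel_size, stride = 8, 8
pad = kernel_size // 2

seq_lens = torch.tensor([160, 97])
new_lens = (((seq_lens + 2 * pad - kernel_size) / stride) + 1).floor().to(torch.int64)

# (160 + 8 - 8) / 8 + 1 = 21.0   -> 21
# (97  + 8 - 8) / 8 + 1 = 13.125 -> 13
print(new_lens)  # tensor([21, 13])
```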

+ 4 - 2
src/seamless_communication/models/unity/builder.py

@@ -236,6 +236,7 @@ class UnitYBuilder:
         w2v2_encoder_builder: Wav2Vec2EncoderBuilder,
         mt_model_builder: NllbBuilder,
         t2u_builder: Optional["UnitYT2UBuilder"],
+        *,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
     ) -> None:
@@ -269,11 +270,12 @@ class UnitYBuilder:
             )
 
         self.config = config
+
         self.w2v2_encoder_builder = w2v2_encoder_builder
         self.mt_model_builder = mt_model_builder
         self.t2u_builder = t2u_builder
-        self.device = device
-        self.dtype = dtype
+
+        self.device, self.dtype = device, dtype
 
     def build_model(self) -> UnitYModel:
         """Build a model."""

+ 18 - 13
src/seamless_communication/models/unity/generator.py

@@ -26,6 +26,7 @@ from seamless_communication.models.unity.unit_tokenizer import (
     UnitTokenizer,
 )
 from fairseq2.nn.utils.module import infer_device
+from fairseq2.nn.padding import PaddingMask
 from torch import Tensor
 
 
@@ -149,7 +150,7 @@ class UnitYGenerator:
     def __call__(
         self,
         source_seqs: Tensor,
-        source_seq_lens: Optional[Tensor],
+        source_padding_mask: Optional[PaddingMask],
         input_modality: str = "speech",
         output_modality: str = "speech",
         ngram_filtering: bool = False,
@@ -160,10 +161,9 @@ class UnitYGenerator:
             where :math:`N` is the batch size, :math:`S` is the sequence length,
             and :math:`*` is any number of sequence-specific dimensions
             including none.
-        :param source_seq_lens:
-            An array where each element represents the length of the sequence at
-            the same index in ``source_seqs``. *Shape:* :math:`(N)`, where
-            :math:`N` is the batch size.
+        :param source_padding_mask:
+            The padding mask of ``source_seqs``. *Shape:* :math:`(N,S)`, where
+            :math:`N` is the batch size and :math:`S` is the sequence length.
         :param input_modality:
             The type of modality to encode.
         :param output_modality:
@@ -175,9 +175,9 @@ class UnitYGenerator:
         """
 
         if input_modality == "speech":
-            text_output = self.s2t_generator.generate_ex(source_seqs, source_seq_lens)
+            text_output = self.s2t_generator.generate_ex(source_seqs, source_padding_mask)
         elif input_modality == "text" and self.t2t_generator is not None:
-            text_output = self.t2t_generator.generate_ex(source_seqs, source_seq_lens)
+            text_output = self.t2t_generator.generate_ex(source_seqs, source_padding_mask)
         elif input_modality == "text" and self.t2t_generator is None:
             raise ValueError(
                 f"Please set use_text_encoder to True in your model config to encode text."
@@ -189,16 +189,18 @@ class UnitYGenerator:
         if output_modality == "text":
             return text_output, None
 
-        text_seqs, text_seq_lens = text_output.generator_output.collate()
+        text_seqs, text_padding_mask = text_output.generator_output.collate()
 
         # Manually trim the final EOS token to be consistent with fairseq.
-        if text_seq_lens is not None:
-            text_seq_lens -= 1
+        text_seqs = text_seqs[:, :-1]
+
+        if text_padding_mask is not None:
+            text_padding_mask = text_padding_mask.trim(1)
 
         # Use the output of the text generator to compute the decoder output.
         decoder_output, decoder_padding_mask = self.model.decode(
             text_seqs,
-            text_seq_lens,
+            text_padding_mask,
             text_output.encoder_output,
             text_output.encoder_padding_mask,
         )
@@ -223,13 +225,16 @@ class UnitYGenerator:
                 text_decoder_output=decoder_output,
                 text_decoder_padding_mask=decoder_padding_mask,
                 target_seqs=None,
-                target_seq_lens=None,
+                target_padding_mask=None,
                 text_seqs=text_seqs,
             )
             # (B, S_unit, V_unit)
             unit_seqs = unit_decoder_output.logits.argmax(dim=2)
             # Apply the padding mask to the generated units.
-            unit_seqs[decoder_padding_mask == -torch.inf] = unit_decoder_output.pad_idx
+            if decoder_padding_mask is not None:
+                m = decoder_padding_mask.materialize()
+
+                unit_seqs[m == -torch.inf] = unit_decoder_output.pad_idx
 
         # Convert to speech units.
         units = self.unit_decoder(unit_seqs)
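`trim(1)` shortens the mask in step with `text_seqs[:, :-1]`, and `materialize()` yields the float mask tensor (with `-inf` at padded positions) that the unit-masking comparison above expects. A sketch, assuming exactly those semantics from the hunk:

```python
import torch

from fairseq2.nn.padding import PaddingMask

mask = PaddingMask(torch.tensor([5, 3]), batch_seq_len=5)

trimmed = mask.trim(1)     # stays aligned with seqs[:, :-1]
m = trimmed.materialize()  # float mask; padded positions hold -inf

print(m.shape)                   # torch.Size([2, 4])
print((m == -torch.inf).sum(1))  # tensor([0, 2]): row 2 has length 2 after the trim
```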

+ 8 - 8
src/seamless_communication/models/unity/length_regulator.py

@@ -11,10 +11,10 @@ from torch.nn import Conv1d, Dropout, Module, ReLU, Sequential
 from typing import Optional, Tuple
 
 from fairseq2.typing import DataType, Device
-from fairseq2.nn.transformer import create_default_layer_norm
+from fairseq2.nn.transformer import create_standard_layer_norm
 from fairseq2.nn.normalization import LayerNorm
+from fairseq2.nn.padding import PaddingMask, apply_padding_mask
 from fairseq2.nn.projection import Linear
-from fairseq2.nn.utils.mask import apply_padding_mask
 
 
 class HardUpsampling(Module):
@@ -75,7 +75,7 @@ class VariancePredictor(Module):
             ReLU(),
         )
 
-        layer_norm_factory = create_default_layer_norm
+        layer_norm_factory = create_standard_layer_norm
 
         self.ln1 = layer_norm_factory(var_pred_hidden_dim, device=device, dtype=dtype)
 
@@ -101,7 +101,7 @@ class VariancePredictor(Module):
             var_pred_hidden_dim, 1, bias=True, device=device, dtype=dtype
         )
 
-    def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
+    def forward(self, seqs: Tensor, padding_mask: Optional[PaddingMask]) -> Tensor:
         # Ensure that we do not leak padded positions in the convolution layer.
         seqs = apply_padding_mask(seqs, padding_mask)
 
@@ -173,10 +173,10 @@ class VarianceAdaptor(Module):
     def forward(
         self,
         seqs: Tensor,
-        padding_mask: Optional[Tensor],
+        padding_mask: Optional[PaddingMask],
         duration_factor: float = 1.0,
         min_duration: int = 0,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> Tuple[Tensor, PaddingMask]:
         log_durations = self.duration_predictor(seqs, padding_mask)
 
         durations = torch.clamp(
@@ -185,10 +185,10 @@ class VarianceAdaptor(Module):
         )
 
         # We need to apply the padding_mask again since we clamp by min_duration.
-        durations = apply_padding_mask(durations, padding_mask)
+        durations = apply_padding_mask(durations, padding_mask, fill_value=0)
 
         # TODO: Implement pitch, energy predictors.
         # TODO: Implement GaussianUpsampling.
         seqs, seq_lens = self.hard_upsampling(seqs, durations)
 
-        return seqs, seq_lens
+        return seqs, PaddingMask(seq_lens, batch_seq_len=seqs.size(1))
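The re-masking after `clamp` matters because `min_duration=1` would otherwise turn padded zero-durations into ones and let padding leak into the upsampling. A small sketch of that reasoning, assuming `apply_padding_mask` writes `fill_value` into padded positions:

```python
import torch

from fairseq2.nn.padding import PaddingMask, apply_padding_mask

durations = torch.tensor([[2, 3, 0, 0]])  # last two positions are padding
mask = PaddingMask(torch.tensor([2]), batch_seq_len=4)

clamped = torch.clamp(durations, min=1)   # padding becomes 1: would get upsampled
fixed = apply_padding_mask(clamped, mask, fill_value=0)

print(fixed)  # tensor([[2, 3, 0, 0]])
```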

+ 48 - 43
src/seamless_communication/models/unity/model.py

@@ -11,6 +11,7 @@ from fairseq2.models.encoder_decoder import EncoderDecoderModel, Seq2SeqDecoder
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.models.transformer.frontend import TransformerFrontend
 from fairseq2.nn.incremental_state import IncrementalStateBag
+from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.projection import Projection
 from fairseq2.nn.transformer import TransformerDecoder, TransformerEncoder
 from fairseq2.nn.utils.module import check_model_dim
@@ -97,28 +98,28 @@ class UnitYModel(EncoderDecoderModel):
 
     @finaloverride
     def encode(
-        self, seqs: Tensor, seq_lens: Optional[Tensor]
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         if self.input_modality == "speech":
-            return self.encode_speech(seqs, seq_lens)
+            return self.encode_speech(seqs, padding_mask)
 
         if self.input_modality == "text":
-            return self.encode_text(seqs, seq_lens)
+            return self.encode_text(seqs, padding_mask)
 
         raise RuntimeError(
             f"`input_modality` must be 'speech' or 'text', but is '{self.input_modality}' instead."
         )
 
     def encode_speech(
-        self, seqs: Tensor, seq_lens: Optional[Tensor]
-    ) -> Tuple[Tensor, Optional[Tensor]]:
-        seqs, padding_mask = self.speech_encoder_frontend(seqs, seq_lens)
+        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
+        seqs, padding_mask = self.speech_encoder_frontend(seqs, padding_mask)
 
         return self.speech_encoder(seqs, padding_mask)  # type: ignore[no-any-return]
 
     def encode_text(
-        self, seqs: Tensor, seq_lens: Optional[Tensor]
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         if self.text_encoder is None:
             raise ValueError(
                 "`encode_text()` requires a text encoder, but the current UnitY model does not have one."
@@ -126,7 +127,7 @@ class UnitYModel(EncoderDecoderModel):
 
         assert self.text_encoder_frontend is not None
 
-        seqs, padding_mask = self.text_encoder_frontend(seqs, seq_lens)
+        seqs, padding_mask = self.text_encoder_frontend(seqs, padding_mask)
 
         return self.text_encoder(seqs, padding_mask)  # type: ignore[no-any-return]
 
@@ -134,13 +135,13 @@ class UnitYModel(EncoderDecoderModel):
     def decode(
         self,
         seqs: Tensor,
-        seq_lens: Optional[Tensor],
+        padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
-        encoder_padding_mask: Optional[Tensor],
+        encoder_padding_mask: Optional[PaddingMask],
         state_bag: Optional[IncrementalStateBag] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs, padding_mask = self.text_decoder_frontend(
-            seqs, seq_lens, state_bag=state_bag
+            seqs, padding_mask, state_bag=state_bag
         )
 
         return self.text_decoder(  # type: ignore[no-any-return]
@@ -153,7 +154,7 @@ class UnitYModel(EncoderDecoderModel):
 
     @finaloverride
     def project(
-        self, decoder_output: Tensor, decoder_padding_mask: Optional[Tensor]
+        self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask]
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
@@ -192,21 +193,23 @@ class UnitYX2TModel(EncoderDecoderModel):
 
     @finaloverride
     def encode(
-        self, seqs: Tensor, seq_lens: Optional[Tensor]
-    ) -> Tuple[Tensor, Optional[Tensor]]:
-        seqs, padding_mask = self.encoder_frontend(seqs, seq_lens)
+        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
+        seqs, padding_mask = self.encoder_frontend(seqs, padding_mask)
         return self.encoder(seqs, padding_mask)  # type: ignore[no-any-return]
 
     @finaloverride
     def decode(
         self,
         seqs: Tensor,
-        seq_lens: Optional[Tensor],
+        padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
-        encoder_padding_mask: Optional[Tensor],
+        encoder_padding_mask: Optional[PaddingMask],
         state_bag: Optional[IncrementalStateBag] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
-        seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag=state_bag)
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
+        seqs, padding_mask = self.decoder_frontend(
+            seqs, padding_mask, state_bag=state_bag
+        )
 
         return self.decoder(  # type: ignore[no-any-return]
             seqs,
@@ -218,7 +221,7 @@ class UnitYX2TModel(EncoderDecoderModel):
 
     @finaloverride
     def project(
-        self, decoder_output: Tensor, decoder_padding_mask: Optional[Tensor]
+        self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask]
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
@@ -274,9 +277,9 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
     def forward(
         self,
         text_decoder_output: Tensor,
-        text_decoder_padding_mask: Optional[Tensor],
+        text_decoder_padding_mask: Optional[PaddingMask],
         target_seqs: Tensor,
-        target_seq_lens: Optional[Tensor],
+        target_padding_mask: Optional[PaddingMask],
     ) -> SequenceModelOutput:
         encoder_output, encoder_padding_mask = self.encode(
             text_decoder_output, text_decoder_padding_mask
@@ -284,7 +287,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
 
         decoder_output, decoder_padding_mask = self.decode(
             target_seqs,
-            target_seq_lens,
+            target_padding_mask,
             encoder_output,
             encoder_padding_mask,
         )
@@ -294,8 +297,8 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
     def encode(
         self,
         text_decoder_output: Tensor,
-        text_decoder_padding_mask: Optional[Tensor],
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+        text_decoder_padding_mask: Optional[PaddingMask],
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         if self.encoder is None:
             return text_decoder_output, text_decoder_padding_mask
 
@@ -304,12 +307,14 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
     def decode(
         self,
         seqs: Tensor,
-        seq_lens: Optional[Tensor],
+        padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
-        encoder_padding_mask: Optional[Tensor],
+        encoder_padding_mask: Optional[PaddingMask],
         state_bag: Optional[IncrementalStateBag] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
-        seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag=state_bag)
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
+        seqs, padding_mask = self.decoder_frontend(
+            seqs, padding_mask, state_bag=state_bag
+        )
 
         return self.decoder(  # type: ignore[no-any-return]
             seqs,
@@ -320,7 +325,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
         )
 
     def project(
-        self, decoder_output: Tensor, decoder_padding_mask: Optional[Tensor]
+        self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask]
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
@@ -375,18 +380,18 @@ class UnitYNART2UModel(Module):
     def forward(
         self,
         text_decoder_output: Tensor,
-        text_decoder_padding_mask: Optional[Tensor],
+        text_decoder_padding_mask: Optional[PaddingMask],
         target_seqs: Optional[Tensor],
-        target_seq_lens: Optional[Tensor],
+        target_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
-    ) -> Tuple[SequenceModelOutput, Optional[Tensor]]:
+    ) -> Tuple[SequenceModelOutput, Optional[PaddingMask]]:
         encoder_output, encoder_padding_mask = self.encode(
             text_decoder_output, text_decoder_padding_mask
         )
 
         decoder_output, decoder_padding_mask = self.decode(
             target_seqs,
-            target_seq_lens,
+            target_padding_mask,
             encoder_output,
             encoder_padding_mask,
             text_seqs,
@@ -397,8 +402,8 @@ class UnitYNART2UModel(Module):
     def encode(
         self,
         text_decoder_output: Tensor,
-        text_decoder_padding_mask: Optional[Tensor],
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+        text_decoder_padding_mask: Optional[PaddingMask],
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         if self.encoder is None:
             return text_decoder_output, text_decoder_padding_mask
 
@@ -407,15 +412,15 @@ class UnitYNART2UModel(Module):
     def decode(
         self,
         seqs: Optional[Tensor],
-        seq_lens: Optional[Tensor],
+        padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
-        encoder_padding_mask: Optional[Tensor],
+        encoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         # encoder_output: (N, S, M)
         # text_seqs: (N, S)
         seqs, padding_mask = self.decoder_frontend(
-            seqs, seq_lens, encoder_output, encoder_padding_mask, text_seqs
+            seqs, padding_mask, encoder_output, encoder_padding_mask, text_seqs
         )
 
         return self.decoder(seqs, padding_mask)  # type: ignore[no-any-return]

+ 14 - 14
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -13,9 +13,9 @@ from fairseq2.data import VocabularyInfo
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
 from fairseq2.nn.embedding import Embedding
 from fairseq2.nn.normalization import LayerNorm
+from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.position_encoder import PositionEncoder
-from fairseq2.nn.transformer import create_default_layer_norm
-from fairseq2.nn.utils.mask import to_padding_mask
+from fairseq2.nn.transformer import create_standard_layer_norm
 from fairseq2.typing import DataType, Device, finaloverride
 
 
@@ -120,7 +120,7 @@ class NARDecoderFrontend(Module):
         self.variance_adaptor = variance_adaptor
 
         if layer_norm:
-            self.layer_norm = create_default_layer_norm(
+            self.layer_norm = create_standard_layer_norm(
                 self.model_dim, device=device, dtype=dtype
             )
         else:
@@ -265,7 +265,7 @@ class NARDecoderFrontend(Module):
     def character_level_upsampling(
         self,
         seqs: Tensor,
-        padding_mask: Optional[Tensor],
+        padding_mask: Optional[PaddingMask],
         char_seqs: Tensor,
         char_lens: Tensor,
     ) -> Tensor:
@@ -287,7 +287,7 @@ class NARDecoderFrontend(Module):
         return seqs
 
     def forward_unit_pos_embedding(
-        self, seqs: Tensor, padding_mask: Optional[Tensor]
+        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
     ) -> Tensor:
         pos_embeds = self.pos_emb_alpha * (
             self.unit_pos_encoder(seqs, padding_mask) - seqs
@@ -304,18 +304,20 @@ class NARDecoderFrontend(Module):
     def forward(
         self,
         target_seqs: Optional[Tensor],
-        target_seq_lens: Optional[Tensor],
+        target_padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
-        encoder_padding_mask: Optional[Tensor],
+        encoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         assert text_seqs is not None
 
         # text_seqs: (N, S_text)
         char_seqs, char_seq_lens, char_lens = self.text_to_char_seqs(text_seqs)
 
         # char_seqs: (N, S_char)
-        encoder_padding_mask = to_padding_mask(char_seqs, char_seq_lens)
+        encoder_padding_mask = PaddingMask(
+            char_seq_lens, batch_seq_len=char_seqs.size(1)
+        )
 
         # (N, S_text, M) -> (N, S_char, M)
         seqs = self.character_level_upsampling(
@@ -323,14 +325,12 @@ class NARDecoderFrontend(Module):
         )
 
         # (N, S_char, M) -> (N, S_unit, M)
-        seqs, seq_lens = self.variance_adaptor(
+        seqs, padding_mask = self.variance_adaptor(
             seqs,
             encoder_padding_mask,
             min_duration=1,
         )
 
-        decoder_padding_mask = to_padding_mask(seqs, seq_lens)
-
-        seqs = self.forward_unit_pos_embedding(seqs, decoder_padding_mask)
+        seqs = self.forward_unit_pos_embedding(seqs, padding_mask)
 
-        return seqs, decoder_padding_mask
+        return seqs, padding_mask

+ 12 - 11
src/seamless_communication/models/unity/nar_decoder_layer.py

@@ -11,12 +11,13 @@ from torch.nn import Conv1d, Dropout, Module, ReLU
 
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.transformer import (
+    AttentionMask,
     TransformerDecoderLayer,
     MultiheadAttention,
 )
 from fairseq2.nn.incremental_state import IncrementalStateBag
-from fairseq2.nn.transformer import create_default_layer_norm
-from fairseq2.nn.utils.mask import apply_padding_mask
+from fairseq2.nn.padding import PaddingMask, apply_padding_mask
+from fairseq2.nn.transformer import create_standard_layer_norm
 from fairseq2.nn.utils.module import check_model_dim
 from fairseq2.typing import DataType, Device, finaloverride
 
@@ -77,7 +78,7 @@ class Conv1dBlock(Module):
         )
 
     @finaloverride
-    def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
+    def forward(self, seqs: Tensor, padding_mask: Optional[PaddingMask]) -> Tensor:
         # Ensure that we do not leak padded positions in the convolution layer.
         seqs = apply_padding_mask(seqs, padding_mask)
 
@@ -148,7 +149,7 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
         else:
             self.register_module("self_attn_dropout", None)
 
-        layer_norm_factory = create_default_layer_norm
+        layer_norm_factory = create_standard_layer_norm
 
         self.self_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
 
@@ -167,12 +168,12 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
     def forward(
         self,
         seqs: Tensor,
-        padding_mask: Optional[Tensor],
-        self_attn_mask: Optional[Tensor] = None,
+        padding_mask: Optional[PaddingMask],
+        self_attn_mask: Optional[AttentionMask] = None,
         encoder_output: Optional[Tensor] = None,
-        encoder_padding_mask: Optional[Tensor] = None,
+        encoder_padding_mask: Optional[PaddingMask] = None,
         state_bag: Optional[IncrementalStateBag] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs = self._forward_self_attn(seqs, padding_mask)
 
         seqs = self._forward_conv1d(seqs, padding_mask)
@@ -182,7 +183,7 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
     def _forward_self_attn(
         self,
         seqs: Tensor,
-        padding_mask: Optional[Tensor],
+        padding_mask: Optional[PaddingMask],
     ) -> Tensor:
         residual = seqs
 
@@ -190,8 +191,8 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
             seqs,
             padding_mask,
             keys=seqs,
-            values=seqs,
             key_padding_mask=padding_mask,
+            values=seqs,
         )
 
         if self.self_attn_dropout is not None:
@@ -203,7 +204,7 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
 
         return seqs
 
-    def _forward_conv1d(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
+    def _forward_conv1d(self, seqs: Tensor, padding_mask: Optional[PaddingMask]) -> Tensor:
         residual = seqs
 
         seqs = self.conv1d(seqs, padding_mask)

+ 4 - 3
src/seamless_communication/models/unity/t2u_builder.py

@@ -219,6 +219,7 @@ class UnitYT2UBuilder:
     def __init__(
         self,
         config: UnitYT2UConfig,
+        *,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
     ) -> None:
@@ -231,8 +232,8 @@ class UnitYT2UBuilder:
             The data type of module parameters and buffers.
         """
         self.config = config
-        self.device = device
-        self.dtype = dtype
+
+        self.device, self.dtype = device, dtype
 
     def build_model(self) -> Union[UnitYT2UModel, UnitYNART2UModel]:
         """Build a model."""
@@ -490,4 +491,4 @@ def create_unity_t2u_model(
     :param dtype:
         The data type of module parameters and buffers.
     """
-    return UnitYT2UBuilder(config, device, dtype).build_model()
+    return UnitYT2UBuilder(config, device=device, dtype=dtype).build_model()
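This call-site change follows from the bare `*` added to the builder constructors above (here, in `unity/builder.py`, and in `vocoder/builder.py` below), which makes `device` and `dtype` keyword-only. A one-liner illustrating the effect:

```python
def build(config, *, device=None, dtype=None):
    return config, device, dtype

# build("cfg", "cuda", "fp16")             # TypeError: takes 1 positional argument
build("cfg", device="cuda", dtype="fp16")  # OK: keyword-only after `*`
```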

+ 4 - 3
src/seamless_communication/models/vocoder/builder.py

@@ -80,6 +80,7 @@ class VocoderBuilder:
     def __init__(
         self,
         config: VocoderConfig,
+        *,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
     ) -> None:
@@ -92,8 +93,8 @@ class VocoderBuilder:
             The data type of module parameters and buffers.
         """
         self.config = config
-        self.device = device
-        self.dtype = dtype
+
+        self.device, self.dtype = device, dtype
 
     def build_model(self) -> Vocoder:
         """Build a model."""
@@ -133,4 +134,4 @@ def create_vocoder_model(
         The data type of module parameters and buffers.
     """
 
-    return VocoderBuilder(config, device, dtype).build_model()
+    return VocoderBuilder(config, device=device, dtype=dtype).build_model()

+ 6 - 6
src/seamless_communication/models/wav2vec2_chunk/chunk_attention_mask.py

@@ -4,13 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Optional
+
 import torch
 from torch import Tensor
 
 from fairseq2.nn.utils.mask import to_float_mask
+from fairseq2.nn.transformer import AttentionMask, CustomAttentionMask
 
 
-class ChunkAttentionMaskGenerator:
+class ChunkAttentionMaskFactory:
     """Generates a chunk attention mask for self attention.
 
     .. note::
@@ -27,7 +30,7 @@ class ChunkAttentionMaskGenerator:
         if self.right_chunk_num != 0:
             raise ValueError("We currently only support `right_chunk_num` == 0.")
 
-    def __call__(self, seqs: Tensor) -> Tensor:
+    def __call__(self, seqs: Tensor) -> Optional[AttentionMask]:
         """
         :param seqs:
             The sequences for which to generate the mask. *Shape:*
@@ -71,7 +74,4 @@ class ChunkAttentionMaskGenerator:
 
         mask = mask[:seq_len, :seq_len]
 
-        return mask
-
-    def __repr__(self) -> str:
-        return "ChunkAttentionMaskGenerator"
+        return CustomAttentionMask(mask)
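The generator becomes a factory that returns a typed `CustomAttentionMask` rather than a raw tensor, matching the mask-factory protocol used by `encoder.py` below. A usage sketch with illustrative sizes (`right_chunk_num` must stay 0 per the check above):

```python
import torch

from seamless_communication.models.wav2vec2_chunk.chunk_attention_mask import (
    ChunkAttentionMaskFactory,
)

factory = ChunkAttentionMaskFactory(16, 1, 0)  # chunk_size * 2, left, right

seqs = torch.zeros(2, 40, 512)  # (N, S, M)
self_attn_mask = factory(seqs)  # Optional[AttentionMask], no longer a raw Tensor
```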

+ 8 - 7
src/seamless_communication/models/wav2vec2_chunk/encoder.py

@@ -12,16 +12,17 @@ from torch.nn import Dropout
 from fairseq2.nn.utils.module import check_model_dim
 from fairseq2.nn.module_list import ModuleList
 from fairseq2.nn.normalization import LayerNorm
+from fairseq2.nn.padding import PaddingMask
 
 from fairseq2.nn.transformer import (
-    AttentionMaskGenerator,
+    AttentionMaskFactory,
     EncoderLayerOutputHook,
     TransformerEncoder,
     TransformerEncoderLayer,
 )
 
 from seamless_communication.models.wav2vec2_chunk.chunk_attention_mask import (
-    ChunkAttentionMaskGenerator,
+    ChunkAttentionMaskFactory,
 )
 
 from fairseq2.typing import finaloverride
@@ -32,7 +33,7 @@ class ChunkTransformerEncoder(TransformerEncoder):
     """Represents a Chunk Transformer encoder."""
 
     preliminary_dropout: Optional[Dropout]
-    self_attn_mask_gen: AttentionMaskGenerator
+    self_attn_mask_factory: ChunkAttentionMaskFactory
     layers: ModuleList
     layer_norm: Optional[LayerNorm]
 
@@ -74,7 +75,7 @@ class ChunkTransformerEncoder(TransformerEncoder):
         else:
             self.register_module("preliminary_dropout", None)
 
-        self.self_attn_mask_gen = ChunkAttentionMaskGenerator(
+        self.self_attn_mask_factory = ChunkAttentionMaskFactory(
             chunk_size * 2, left_chunk_num, right_chunk_num
         )
 
@@ -86,17 +87,17 @@ class ChunkTransformerEncoder(TransformerEncoder):
     def forward(
         self,
         seqs: Tensor,
-        padding_mask: Optional[Tensor],
+        padding_mask: Optional[PaddingMask],
         *,
         layer_output_hook: Optional[EncoderLayerOutputHook] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
         if layer_output_hook is not None and self.layers.drop_p > 0.0:
             raise ValueError("`layer_hook` must be `None` when LayerDrop is enabled.")
 
         if self.preliminary_dropout is not None:
             seqs = self.preliminary_dropout(seqs)
 
-        self_attn_mask = self.self_attn_mask_gen(seqs)
+        self_attn_mask = self.self_attn_mask_factory(seqs)
 
         num_layers = len(self.layers)