Fix Conv1dBlock to use padding_mask.

Kaushik Ram Sadagopan · 2 years ago · commit afbcd665fb
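
For context, a minimal hypothetical sketch of the leak this commit addresses: with padding="same", a Conv1d mixes neighbouring time steps, so whatever sits in the padded tail bleeds into the valid outputs near the sequence boundary unless those positions are zeroed first, which is what the apply_padding_mask calls in the diff below are for. The shapes and the zeroing convention here are assumptions for illustration only.

import torch
from torch.nn import Conv1d

torch.manual_seed(0)

conv = Conv1d(4, 4, kernel_size=3, stride=1, padding="same")

# Two inputs with identical valid content (first 5 steps) but different
# values in the 3-step padded tail.
valid = torch.randn(1, 4, 5)
a = torch.cat([valid, torch.zeros(1, 4, 3)], dim=-1)
b = torch.cat([valid, torch.randn(1, 4, 3)], dim=-1)

# Without masking, valid outputs near the boundary depend on the pad content.
print(torch.allclose(conv(a)[..., :5], conv(b)[..., :5]))  # False

# Zeroing the padded tail first removes that dependence.
b_masked = b.clone()
b_masked[..., 5:] = 0.0
print(torch.allclose(conv(a)[..., :5], conv(b_masked)[..., :5]))  # True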

+ 0 - 1
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -10,7 +10,6 @@ from torch import Tensor
 from torch.nn import Dropout, Module, Parameter
 
 from fairseq2.data import VocabularyInfo
-from fairseq2.data.text import TextTokenizer
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
 from fairseq2.nn.embedding import Embedding
 from fairseq2.nn.normalization import LayerNorm

+ 20 - 8
src/seamless_communication/models/unity/nar_decoder_layer.py

@@ -14,10 +14,9 @@ from fairseq2.nn.transformer import (
     TransformerDecoderLayer,
     MultiheadAttention,
 )
-from fairseq2.typing import DataType, Device, finaloverride
-
 from fairseq2.nn.incremental_state import IncrementalStateBag
 from fairseq2.nn.transformer import create_default_layer_norm
+from fairseq2.nn.utils.mask import apply_padding_mask
 from fairseq2.nn.utils.module import check_model_dim
 from fairseq2.typing import DataType, Device, finaloverride
 
@@ -58,7 +57,7 @@ class Conv1dBlock(Module):
             inner_dim,
             kernel_size,
             stride=1,
-            padding=(kernel_size - 1) // 2,
+            padding="same",
             bias=bias,
             device=device,
             dtype=dtype,
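
A side note on the padding change above, with a hypothetical check (illustration only): for an odd kernel_size and stride 1, padding="same" is numerically identical to the previous explicit padding=(kernel_size - 1) // 2; "same" simply states the intent directly and, in PyTorch, also handles even kernel sizes.

import torch
from torch.nn import Conv1d

torch.manual_seed(0)

k = 3
c_explicit = Conv1d(4, 4, kernel_size=k, stride=1, padding=(k - 1) // 2)
c_same = Conv1d(4, 4, kernel_size=k, stride=1, padding="same")
c_same.load_state_dict(c_explicit.state_dict())  # copy identical weights

x = torch.randn(2, 4, 10)
print(torch.allclose(c_explicit(x), c_same(x)))  # True
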
@@ -71,28 +70,41 @@ class Conv1dBlock(Module):
             model_dim,
             kernel_size,
             stride=1,
-            padding=(kernel_size - 1) // 2,
+            padding="same",
             bias=bias,
             device=device,
             dtype=dtype,
         )
 
     @finaloverride
-    def forward(self, seqs: Tensor) -> Tensor:
+    def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
+        # Ensure that we do not leak padded positions in the convolution layer.
+        seqs = apply_padding_mask(seqs, padding_mask)
+
         # (N, S, M) -> (N, M, S)
         seqs = seqs.transpose(1, 2)
 
         # (N, M, S) -> (N, inner_dim, S)
         seqs = self.conv1(seqs)
 
+        # (N, inner_dim, S) -> (N, S, inner_dim)
+        seqs = seqs.transpose(1, 2)
+
+        seqs = apply_padding_mask(seqs, padding_mask)
+
         seqs = self.activation(seqs)
 
+        # (N, S, inner_dim) -> (N, inner_dim, S)
+        seqs = seqs.transpose(1, 2)
+
         # (N, inner_dim, S) -> (N, M, S)
         seqs = self.conv2(seqs)
 
         # (N, M, S) -> (N, S, M)
         seqs = seqs.transpose(1, 2)
 
+        seqs = apply_padding_mask(seqs, padding_mask)
+
         return seqs
 
 
@@ -165,7 +177,7 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
     ) -> Tuple[Tensor, Optional[Tensor]]:
         seqs = self._forward_self_attn(seqs, padding_mask)
 
-        seqs = self._forward_conv1d(seqs)
+        seqs = self._forward_conv1d(seqs, padding_mask)
 
         return seqs, padding_mask
 
@@ -193,10 +205,10 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
 
         return seqs
 
-    def _forward_conv1d(self, seqs: Tensor) -> Tensor:
+    def _forward_conv1d(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
         residual = seqs
 
-        seqs = self.conv1d(seqs)
+        seqs = self.conv1d(seqs, padding_mask)
 
         if self.conv1d_dropout is not None:
             seqs = self.conv1d_dropout(seqs)
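
Finally, a hypothetical sketch of why Conv1dBlock.forward above re-applies the mask after each convolution rather than only at the input (illustration only): even when the padded tail enters conv1 as zeros, the convolution's bias makes those positions nonzero again, so they would leak into the next convolution unless they are masked once more.

import torch
from torch.nn import Conv1d

torch.manual_seed(0)

conv1 = Conv1d(4, 8, kernel_size=3, stride=1, padding="same", bias=True)

x = torch.randn(1, 4, 8)
x[..., 5:] = 0.0                 # padded tail zeroed before conv1

y = conv1(x)
# The last position sees only zeroed inputs, yet its output equals the bias,
# which is generally nonzero, so it must be re-masked before the next conv.
print(y[..., -1].abs().max())    # > 0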