
Decouple UnitYNART2UModel's decoder and decoder_layer from TransformerDecoder and TransformerDecoderLayer. (#71)

Kaushik Ram Sadagopan 1 year ago
parent
commit
45974cd2f4

+ 3 - 0
src/seamless_communication/models/unity/__init__.py

@@ -49,6 +49,9 @@ from seamless_communication.models.unity.model import UnitYOutput as UnitYOutput
 from seamless_communication.models.unity.nar_decoder_frontend import (
     NARDecoderFrontend as NARDecoderFrontend,
 )
+from seamless_communication.models.unity.nar_decoder import (
+    NARTransformerDecoder as NARTransformerDecoder,
+)
 from seamless_communication.models.unity.nar_decoder_layer import (
     NARTransformerDecoderLayer as NARTransformerDecoderLayer,
 )
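
With this re-export in place, the new decoder can be imported from the package root next to the layer class it stacks. A minimal import sketch (the layer re-export already exists just below in the same file):

    from seamless_communication.models.unity import (
        NARTransformerDecoder,
        NARTransformerDecoderLayer,
    )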

+ 0 - 2
src/seamless_communication/models/unity/generator.py

@@ -231,8 +231,6 @@ class UnitYGenerator:
             unit_decoder_output, decoder_padding_mask = self.model.t2u_model(
                 text_decoder_output=decoder_output,
                 text_decoder_padding_mask=decoder_padding_mask,
-                target_seqs=None,
-                target_padding_mask=None,
                 text_seqs=text_seqs,
             )
             # (B, S_unit, V_unit)
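
After this hunk, the generator drives the NAR T2U model from the text decoder output alone; no target unit sequences are fed in. A hedged usage sketch with names taken from the surrounding generator code, assuming SequenceModelOutput exposes its logits tensor as in fairseq2:

    # Greedily pick unit ids from the (B, S_unit, V_unit) logits.
    unit_decoder_output, decoder_padding_mask = self.model.t2u_model(
        text_decoder_output=decoder_output,
        text_decoder_padding_mask=decoder_padding_mask,
        text_seqs=text_seqs,
    )
    unit_seqs = unit_decoder_output.logits.argmax(dim=-1)  # (B, S_unit)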

+ 4 - 9
src/seamless_communication/models/unity/model.py

@@ -18,6 +18,7 @@ from overrides import final as finaloverride
 from torch import Tensor
 from torch.nn import Module
 
+from seamless_communication.models.unity.nar_decoder import NARTransformerDecoder
 from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
 
 
@@ -335,7 +336,7 @@ class UnitYNART2UModel(Module):
     model_dim: int
     encoder: Optional[TransformerEncoder]
     decoder_frontend: NARDecoderFrontend
-    decoder: TransformerDecoder
+    decoder: NARTransformerDecoder
     final_proj: Projection
     pad_idx: Optional[int]
 
@@ -343,7 +344,7 @@ class UnitYNART2UModel(Module):
         self,
         encoder: Optional[TransformerEncoder],
         decoder_frontend: NARDecoderFrontend,
-        decoder: TransformerDecoder,
+        decoder: NARTransformerDecoder,
         final_proj: Projection,
         pad_idx: Optional[int],
     ) -> None:
@@ -377,8 +378,6 @@ class UnitYNART2UModel(Module):
         self,
         text_decoder_output: Tensor,
         text_decoder_padding_mask: Optional[PaddingMask],
-        target_seqs: Optional[Tensor],
-        target_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
     ) -> Tuple[SequenceModelOutput, Optional[PaddingMask]]:
         encoder_output, encoder_padding_mask = self.encode(
@@ -386,8 +385,6 @@ class UnitYNART2UModel(Module):
         )
 
         decoder_output, decoder_padding_mask = self.decode(
-            target_seqs,
-            target_padding_mask,
             encoder_output,
             encoder_padding_mask,
             text_seqs,
@@ -407,8 +404,6 @@ class UnitYNART2UModel(Module):
 
     def decode(
         self,
-        seqs: Optional[Tensor],
-        padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
@@ -416,7 +411,7 @@ class UnitYNART2UModel(Module):
         # encoder_output: (N, S, M)
         # text_seqs: (N, S)
         seqs, padding_mask = self.decoder_frontend(
-            seqs, padding_mask, encoder_output, encoder_padding_mask, text_seqs
+            encoder_output, encoder_padding_mask, text_seqs
         )
 
         return self.decoder(seqs, padding_mask)  # type: ignore[no-any-return]
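
Taken together, the model.py changes make the NAR T2U pass a straight pipeline from the text decoder output to unit logits. A hedged sketch of the flow, paraphrasing the diff rather than reproducing the full implementation (final_proj is the Projection declared above; the real forward additionally wraps the result in a SequenceModelOutput):

    # Sketch of the simplified UnitYNART2UModel data flow after this change.
    # `model` is a UnitYNART2UModel instance; names follow the diff above.
    def nar_t2u_flow(model, text_decoder_output, text_decoder_padding_mask, text_seqs):
        # Optional T2U encoder over the text decoder output.
        encoder_output, encoder_padding_mask = model.encode(
            text_decoder_output, text_decoder_padding_mask
        )

        # The NAR frontend now derives the unit-side sequences itself; target
        # units and their padding mask are no longer passed in.
        seqs, padding_mask = model.decoder_frontend(
            encoder_output, encoder_padding_mask, text_seqs
        )

        # Non-autoregressive decoding in a single pass, then projection to units.
        decoder_output, decoder_padding_mask = model.decoder(seqs, padding_mask)
        return model.final_proj(decoder_output), decoder_padding_mask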

+ 85 - 0
src/seamless_communication/models/unity/nar_decoder.py

@@ -0,0 +1,85 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Iterable, Optional, Tuple, final
+
+from torch import Tensor
+from torch.nn import Module
+
+from fairseq2.nn.module_list import ModuleList
+from fairseq2.nn.normalization import LayerNorm
+from fairseq2.nn.padding import PaddingMask
+from fairseq2.nn.transformer import (
+    TransformerNormOrder,
+    create_standard_layer_norm,
+)
+from fairseq2.typing import DataType, Device, finaloverride
+from seamless_communication.models.unity.nar_decoder_layer import (
+    NARTransformerDecoderLayer,
+)
+
+
+@final
+class NARTransformerDecoder(Module):
+    """Represents a non-autoregressive Transformer decoder."""
+
+    model_dim: int
+    layer_norm: Optional[LayerNorm]
+    norm_order: TransformerNormOrder
+
+    def __init__(
+        self,
+        layers: Iterable[NARTransformerDecoderLayer],
+        *,
+        norm_order: TransformerNormOrder = TransformerNormOrder.POST,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param layers:
+            The decoder layers.
+        :param norm_order:
+            The Layer Normalization order to use.
+        """
+        super().__init__()
+
+        layer_list = ModuleList(layers)
+
+        if not layer_list:
+            raise ValueError("`layers` must be non-empty.")
+
+        self.model_dim = layer_list[0].model_dim
+
+        self.layers = layer_list
+
+        if norm_order != TransformerNormOrder.POST:
+            self.layer_norm = create_standard_layer_norm(
+                self.model_dim, device=device, dtype=dtype
+            )
+        else:
+            self.register_module("layer_norm", None)
+
+        self.norm_order = norm_order
+
+    @finaloverride
+    def forward(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
+        for layer in self.layers.drop_iter():
+            seqs, padding_mask = layer(seqs, padding_mask)
+
+        if self.layer_norm is not None:
+            seqs = self.layer_norm(seqs)
+
+        return seqs, padding_mask
+
+    def extra_repr(self) -> str:
+        """:meta private:"""
+        s = super().extra_repr()
+
+        return f"{s}, norm_order={self.norm_order}"
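
The new module is intentionally minimal: it stacks the FFT-block layers, optionally applies a final LayerNorm for pre-norm configurations, and exposes only (seqs, padding_mask) in and out. A hypothetical construction sketch; make_layer and num_layers are placeholders, not builder API (see UnitYNART2UBuilder below for the real wiring):

    import torch

    from fairseq2.nn.transformer import TransformerNormOrder
    from seamless_communication.models.unity import NARTransformerDecoder

    # Placeholders: make_layer() must return NARTransformerDecoderLayer instances.
    layers = [make_layer() for _ in range(num_layers)]
    decoder = NARTransformerDecoder(layers, norm_order=TransformerNormOrder.PRE)

    seqs = torch.rand(2, 16, decoder.model_dim)        # (N, S, M)
    seqs, padding_mask = decoder(seqs, padding_mask=None)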

+ 0 - 2
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -303,8 +303,6 @@ class NARDecoderFrontend(Module):
     @finaloverride
     def forward(
         self,
-        target_seqs: Optional[Tensor],
-        target_padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],

+ 7 - 16
src/seamless_communication/models/unity/nar_decoder_layer.py

@@ -10,14 +10,8 @@ from torch import Tensor
 from torch.nn import Conv1d, Dropout, Module, ReLU
 
 from fairseq2.nn.normalization import LayerNorm
-from fairseq2.nn.transformer import (
-    AttentionMask,
-    TransformerDecoderLayer,
-    MultiheadAttention,
-)
-from fairseq2.nn.incremental_state import IncrementalStateBag
+from fairseq2.nn.transformer import MultiheadAttention, create_standard_layer_norm
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
-from fairseq2.nn.transformer import create_standard_layer_norm
 from fairseq2.typing import DataType, Device, finaloverride
 
 
@@ -107,10 +101,11 @@ class Conv1dBlock(Module):
 
 
 @final
-class NARTransformerDecoderLayer(TransformerDecoderLayer):
+class NARTransformerDecoderLayer(Module):
     """Represents the FFT Block as described in
     :cite:t:`https://arxiv.org/pdf/1905.09263.pdf`."""
 
+    model_dim: int
     self_attn: MultiheadAttention
     self_attn_dropout: Optional[Dropout]
     self_attn_layer_norm: LayerNorm
@@ -137,9 +132,9 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
         :param conv1d_dropout_p:
             The dropout probability on the outputs of the conv1d block.
         """
-        model_dim = self_attn.model_dim
+        super().__init__()
 
-        super().__init__(model_dim)
+        self.model_dim = self_attn.model_dim
 
         self.self_attn = self_attn
 
@@ -151,7 +146,7 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
         layer_norm_factory = create_standard_layer_norm
 
         self.self_attn_layer_norm = layer_norm_factory(
-            model_dim, device=device, dtype=dtype
+            self.model_dim, device=device, dtype=dtype
         )
 
         self.conv1d = conv1d
@@ -162,7 +157,7 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
             self.register_module("conv1d_dropout", None)
 
         self.conv1d_layer_norm = layer_norm_factory(
-            model_dim, device=device, dtype=dtype
+            self.model_dim, device=device, dtype=dtype
         )
 
     @finaloverride
@@ -170,10 +165,6 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
         self,
         seqs: Tensor,
         padding_mask: Optional[PaddingMask],
-        self_attn_mask: Optional[AttentionMask] = None,
-        encoder_output: Optional[Tensor] = None,
-        encoder_padding_mask: Optional[PaddingMask] = None,
-        state_bag: Optional[IncrementalStateBag] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs = self._forward_self_attn(seqs, padding_mask)
 

+ 4 - 3
src/seamless_communication/models/unity/t2u_builder.py

@@ -38,6 +38,7 @@ from fairseq2.models.nllb.loader import NllbTokenizerLoader
 
 
 from seamless_communication.assets import asset_store
+from seamless_communication.models.unity.nar_decoder import NARTransformerDecoder
 from seamless_communication.models.unity.nar_decoder_layer import (
     NARTransformerDecoderLayer,
     Conv1dBlock,
@@ -550,21 +551,21 @@ class UnitYNART2UBuilder:
             dtype=self.dtype,
         )
 
-    def build_decoder(self) -> TransformerDecoder:
+    def build_decoder(self) -> NARTransformerDecoder:
         """Build a Transformer decoder."""
 
         num_layers = self.config.num_decoder_layers
 
         layers = [self.build_decoder_layer() for _ in range(num_layers)]
 
-        return StandardTransformerDecoder(
+        return NARTransformerDecoder(
             layers,
             norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,
         )
 
-    def build_decoder_layer(self) -> TransformerDecoderLayer:
+    def build_decoder_layer(self) -> NARTransformerDecoderLayer:
         """Build a Transformer decoder layer."""
 
         assert self.config.nar_decoder_config is not None
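
Because build_decoder passes TransformerNormOrder.PRE, the NARTransformerDecoder constructed here creates the final layer_norm shown in nar_decoder.py above. A hedged sanity-check sketch, assuming builder is an already constructed UnitYNART2UBuilder:

    from fairseq2.nn.transformer import TransformerNormOrder

    decoder = builder.build_decoder()            # -> NARTransformerDecoder
    assert decoder.norm_order == TransformerNormOrder.PRE
    assert decoder.layer_norm is not None        # created because the norm order is PRE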