Separate T2U and NAR T2U builders for UnitY T2U model. (#70)

Kaushik Ram Sadagopan, 1 year ago
commit 4d679cb67e
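
This commit splits the single `UnitYT2UBuilder`, which previously branched on `nar_decoder_config` inside `build_model`, into separate autoregressive and non-autoregressive builders; the branch moves to the call sites (`create_unity_model` and `create_unity_t2u_model` below). A minimal sketch of the resulting dispatch, using only names exported in this commit; the `pick_builder` helper is illustrative and not part of the change:

    from typing import Union

    from seamless_communication.models.unity import (
        UnitYNART2UBuilder,
        UnitYT2UBuilder,
        UnitYT2UConfig,
    )

    def pick_builder(
        config: UnitYT2UConfig,
    ) -> Union[UnitYT2UBuilder, UnitYNART2UBuilder]:
        # Mirrors the new logic in create_unity_t2u_model: a config without
        # a NAR decoder section builds the autoregressive model.
        if config.nar_decoder_config is None:
            return UnitYT2UBuilder(config)
        return UnitYNART2UBuilder(config)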

+ 3 - 0
src/seamless_communication/models/unity/__init__.py

@@ -55,6 +55,9 @@ from seamless_communication.models.unity.nar_decoder_layer import (
 from seamless_communication.models.unity.t2u_builder import (
     UnitYT2UBuilder as UnitYT2UBuilder,
 )
+from seamless_communication.models.unity.t2u_builder import (
+    UnitYNART2UBuilder as UnitYNART2UBuilder,
+)
 from seamless_communication.models.unity.t2u_builder import (
     UnitYT2UConfig as UnitYT2UConfig,
 )
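
Note: the redundant `UnitYNART2UBuilder as UnitYNART2UBuilder` alias follows the explicit re-export convention from PEP 484, marking the name as part of the package's public surface for type checkers that disable implicit re-exports.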

+ 9 - 5
src/seamless_communication/models/unity/builder.py

@@ -5,8 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from torch.nn import Parameter
-from typing import Optional
+from typing import Union, Optional
 
 from fairseq2.models.conformer import ConformerBlock, ConformerConvolution
 from fairseq2.models.nllb import NllbBuilder, NllbConfig, nllb_archs
@@ -33,6 +32,7 @@ from seamless_communication.models.unity.adaptor_block import (
 from seamless_communication.models.unity.model import UnitYModel
 from seamless_communication.models.unity.t2u_builder import (
     UnitYT2UBuilder,
+    UnitYNART2UBuilder,
     UnitYT2UConfig,
     unity_t2u_archs,
 )
@@ -176,7 +176,7 @@ class UnitYBuilder:
     config: UnitYConfig
     w2v2_encoder_builder: Wav2Vec2EncoderBuilder
     mt_model_builder: NllbBuilder
-    t2u_builder: Optional["UnitYT2UBuilder"]
+    t2u_builder: Union[UnitYT2UBuilder, UnitYNART2UBuilder, None]
     device: Optional[Device]
     dtype: Optional[DataType]
 
@@ -185,7 +185,7 @@ class UnitYBuilder:
         config: UnitYConfig,
         w2v2_encoder_builder: Wav2Vec2EncoderBuilder,
         mt_model_builder: NllbBuilder,
-        t2u_builder: Optional["UnitYT2UBuilder"],
+        t2u_builder: Union[UnitYT2UBuilder, UnitYNART2UBuilder, None],
         *,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
@@ -389,10 +389,14 @@ def create_unity_model(
             config.w2v2_encoder_config, device=device, dtype=dtype
         )
 
+    t2u_builder: Union[UnitYT2UBuilder, UnitYNART2UBuilder, None]
+
     if config.t2u_config is None:
         t2u_builder = None
-    else:
+    elif config.t2u_config.nar_decoder_config is None:
         t2u_builder = UnitYT2UBuilder(config.t2u_config, device=device, dtype=dtype)
+    else:
+        t2u_builder = UnitYNART2UBuilder(config.t2u_config, device=device, dtype=dtype)
 
     mt_model_builder = NllbBuilder(config.mt_model_config, device=device, dtype=dtype)
     unity_builder = UnitYBuilder(

+ 218 - 73
src/seamless_communication/models/unity/t2u_builder.py

@@ -205,7 +205,7 @@ def _base_nar() -> UnitYT2UConfig:
 
 
 class UnitYT2UBuilder:
-    """Builds modules of an AR or NAR UnitY T2U model.
+    """Builds modules of an autoregressive UnitY T2U model.
 
     To tweak the architecture, you can derive from this class and override the
     corresponding methods.
@@ -234,8 +234,180 @@ class UnitYT2UBuilder:
 
         self.device, self.dtype = device, dtype
 
-    def build_model(self) -> Union[UnitYT2UModel, UnitYNART2UModel]:
-        """Build a model."""
+    def build_model(self) -> UnitYT2UModel:
+        """Build an autoregressive UnitYT2U model."""
+
+        embed_unit = self.build_unit_embedding()
+
+        encoder = self.build_encoder()
+
+        decoder = self.build_decoder()
+
+        final_proj = TiedProjection(embed_unit.weight, bias=None)
+
+        decoder_frontend = self.build_decoder_frontend(embed_unit)
+
+        return UnitYT2UModel(
+            encoder,
+            decoder_frontend,
+            decoder,
+            final_proj,
+            self.config.unit_pad_idx,
+        )
+
+    def build_unit_embedding(self) -> StandardEmbedding:
+        """Build a unit embedding table."""
+
+        return StandardEmbedding(
+            num_embeddings=self.config.unit_vocabulary_size,
+            embedding_dim=self.config.model_dim,
+            pad_idx=self.config.unit_pad_idx,
+            init_fn=init_scaled_embedding,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_encoder(self) -> Optional[TransformerEncoder]:
+        """Build a Transformer encoder."""
+
+        num_layers = self.config.num_encoder_layers
+        if num_layers == 0:
+            return None
+
+        layers = [self.build_encoder_layer() for _ in range(num_layers)]
+
+        return StandardTransformerEncoder(
+            layers,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_encoder_layer(self) -> TransformerEncoderLayer:
+        """Build a Transformer encoder layer."""
+
+        self_attn = self.build_attention(self.config.num_encoder_attn_heads)
+
+        ffn = self.build_ffn()
+
+        return StandardTransformerEncoderLayer(
+            self_attn,
+            ffn,
+            dropout_p=self.config.dropout_p,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder_frontend(self, embed_unit: Embedding) -> TransformerFrontend:
+        """Build a Transformer decoder front-end."""
+
+        pos_encoder = SinusoidalPositionEncoder(
+            self.config.model_dim,
+            self.config.unit_max_seq_len,
+            _legacy_pad_idx=self.config.unit_pad_idx,
+            device=self.device,
+        )
+        return TransformerEmbeddingFrontend(
+            embed_unit,
+            pos_encoder,
+            dropout_p=self.config.dropout_p,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder(self) -> TransformerDecoder:
+        """Build a Transformer decoder."""
+
+        num_layers = self.config.num_decoder_layers
+
+        layers = [self.build_decoder_layer() for _ in range(num_layers)]
+
+        return StandardTransformerDecoder(
+            layers,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder_layer(self) -> TransformerDecoderLayer:
+        """Build a Transformer decoder layer."""
+
+        self_attn = self.build_attention(self.config.num_decoder_attn_heads)
+
+        encoder_decoder_attn = self.build_attention(self.config.num_decoder_attn_heads)
+
+        ffn = self.build_ffn()
+
+        return StandardTransformerDecoderLayer(
+            self_attn,
+            encoder_decoder_attn,
+            ffn,
+            dropout_p=self.config.dropout_p,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_attention(self, num_heads: int) -> MultiheadAttention:
+        """Build a Transformer multi-head attention layer."""
+
+        sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
+
+        return StandardMultiheadAttention(
+            self.config.model_dim,
+            num_heads,
+            sdpa=sdpa,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_ffn(self) -> FeedForwardNetwork:
+        """Build a Transformer feed-forward network."""
+
+        return StandardFeedForwardNetwork(
+            self.config.model_dim,
+            self.config.ffn_inner_dim,
+            bias=True,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
+class UnitYNART2UBuilder:
+    """Builds modules of an NAR UnitY T2U model.
+
+    To tweak the architecture, you can derive from this class and override the
+    corresponding methods.
+    """
+
+    config: UnitYT2UConfig
+    device: Optional[Device]
+    dtype: Optional[DataType]
+
+    def __init__(
+        self,
+        config: UnitYT2UConfig,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+
+        self.device, self.dtype = device, dtype
+
+    def build_model(self) -> UnitYNART2UModel:
+        """Build a non-autoregressive UnitY T2U model."""
+
         embed_unit = self.build_unit_embedding()
 
         encoder = self.build_encoder()
@@ -244,27 +416,19 @@ class UnitYT2UBuilder:
 
         final_proj = TiedProjection(embed_unit.weight, bias=None)
 
-        if self.config.nar_decoder_config is None:
-            decoder_frontend = self.build_decoder_frontend(embed_unit)
-            return UnitYT2UModel(
-                encoder,
-                decoder_frontend,
-                decoder,
-                final_proj,
-                self.config.unit_pad_idx,
-            )
-        else:
-            nar_decoder_frontend = self.build_nar_decoder_frontend(embed_unit)
-            return UnitYNART2UModel(
-                encoder,
-                nar_decoder_frontend,
-                decoder,
-                final_proj,
-                self.config.unit_pad_idx,
-            )
+        decoder_frontend = self.build_decoder_frontend(embed_unit)
+
+        return UnitYNART2UModel(
+            encoder,
+            decoder_frontend,
+            decoder,
+            final_proj,
+            self.config.unit_pad_idx,
+        )
 
     def build_unit_embedding(self) -> StandardEmbedding:
         """Build a unit embedding table."""
+
         return StandardEmbedding(
             num_embeddings=self.config.unit_vocabulary_size,
             embedding_dim=self.config.model_dim,
@@ -276,6 +440,7 @@ class UnitYT2UBuilder:
 
     def build_encoder(self) -> Optional[TransformerEncoder]:
         """Build a Transformer encoder."""
+
         num_layers = self.config.num_encoder_layers
         if num_layers == 0:
             return None
@@ -291,6 +456,7 @@ class UnitYT2UBuilder:
 
     def build_encoder_layer(self) -> TransformerEncoderLayer:
         """Build a Transformer encoder layer."""
+
         self_attn = self.build_attention(self.config.num_encoder_attn_heads)
 
         ffn = self.build_ffn()
@@ -307,6 +473,8 @@ class UnitYT2UBuilder:
     def build_variance_adaptor(
         self, nar_decoder_frontend_config: NARDecoderFrontendConfig
     ) -> VarianceAdaptor:
+        """Build a variance adaptor module."""
+
         duration_predictor_config = (
             nar_decoder_frontend_config.duration_predictor_config
         )
@@ -327,24 +495,9 @@ class UnitYT2UBuilder:
 
         return variance_adaptor
 
-    def build_decoder_frontend(self, embed_unit: Embedding) -> TransformerFrontend:
-        """Build a Transformer decoder front-end."""
-        pos_encoder = SinusoidalPositionEncoder(
-            self.config.model_dim,
-            self.config.unit_max_seq_len,
-            _legacy_pad_idx=self.config.unit_pad_idx,
-            device=self.device,
-        )
-        return TransformerEmbeddingFrontend(
-            embed_unit,
-            pos_encoder,
-            dropout_p=self.config.dropout_p,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-    def build_nar_decoder_frontend(self, embed_unit: Embedding) -> NARDecoderFrontend:
+    def build_decoder_frontend(self, embed_unit: Embedding) -> NARDecoderFrontend:
         """Build a non-autoregressive decoder front-end."""
+
         assert self.config.nar_decoder_config is not None
         assert self.config.nar_decoder_frontend_config is not None
 
@@ -399,6 +552,7 @@ class UnitYT2UBuilder:
 
     def build_decoder(self) -> TransformerDecoder:
         """Build a Transformer decoder."""
+
         num_layers = self.config.num_decoder_layers
 
         layers = [self.build_decoder_layer() for _ in range(num_layers)]
@@ -412,45 +566,32 @@ class UnitYT2UBuilder:
 
     def build_decoder_layer(self) -> TransformerDecoderLayer:
         """Build a Transformer decoder layer."""
+
+        assert self.config.nar_decoder_config is not None
+
         self_attn = self.build_attention(self.config.num_decoder_attn_heads)
 
-        if self.config.nar_decoder_config:
-            conv1d = Conv1dBlock(
-                self.config.model_dim,
-                self.config.nar_decoder_config.conv1d_inner_dim,
-                self.config.nar_decoder_config.conv1d_kernel_size,
-                bias=True,
-                device=self.device,
-                dtype=self.dtype,
-            )
-
-            return NARTransformerDecoderLayer(
-                self_attn,
-                conv1d,
-                dropout_p=self.config.dropout_p,
-                conv1d_dropout_p=self.config.nar_decoder_config.conv1d_dropout_p,
-                device=self.device,
-                dtype=self.dtype,
-            )
-        else:
-            encoder_decoder_attn = self.build_attention(
-                self.config.num_decoder_attn_heads
-            )
-
-            ffn = self.build_ffn()
-
-            return StandardTransformerDecoderLayer(
-                self_attn,
-                encoder_decoder_attn,
-                ffn,
-                dropout_p=self.config.dropout_p,
-                norm_order=TransformerNormOrder.PRE,
-                device=self.device,
-                dtype=self.dtype,
-            )
+        conv1d = Conv1dBlock(
+            self.config.model_dim,
+            self.config.nar_decoder_config.conv1d_inner_dim,
+            self.config.nar_decoder_config.conv1d_kernel_size,
+            bias=True,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        return NARTransformerDecoderLayer(
+            self_attn,
+            conv1d,
+            dropout_p=self.config.dropout_p,
+            conv1d_dropout_p=self.config.nar_decoder_config.conv1d_dropout_p,
+            device=self.device,
+            dtype=self.dtype,
+        )
 
     def build_attention(self, num_heads: int) -> MultiheadAttention:
         """Build a Transformer multi-head attention layer."""
+
         sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
 
         return StandardMultiheadAttention(
@@ -463,6 +604,7 @@ class UnitYT2UBuilder:
 
     def build_ffn(self) -> FeedForwardNetwork:
         """Build a Transformer feed-forward network."""
+
         return StandardFeedForwardNetwork(
             self.config.model_dim,
             self.config.ffn_inner_dim,
@@ -487,4 +629,7 @@ def create_unity_t2u_model(
     :param dtype:
         The data type of module parameters and buffers.
     """
-    return UnitYT2UBuilder(config, device=device, dtype=dtype).build_model()
+    if config.nar_decoder_config is None:
+        return UnitYT2UBuilder(config, device=device, dtype=dtype).build_model()
+    else:
+        return UnitYNART2UBuilder(config, device=device, dtype=dtype).build_model()
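
For completeness, a hedged usage sketch of the updated factory. It assumes the NAR architecture is registered in `unity_t2u_archs` under the name "base_nar" (suggested by the `_base_nar()` config factory above, but not confirmed by this diff):

    from seamless_communication.models.unity.t2u_builder import (
        create_unity_t2u_model,
        unity_t2u_archs,
    )

    # "base_nar" is an assumed registry key; any config whose
    # nar_decoder_config is not None now routes to UnitYNART2UBuilder.
    config = unity_t2u_archs.get_config("base_nar")
    model = create_unity_t2u_model(config)  # returns a UnitYNART2UModel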