
Separate T2U and NAR T2U builders for UnitY T2U model. (#70)

Kaushik Ram Sadagopan 1 year ago
parent commit 4d679cb67e
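
This commit splits the former dual-mode builder into UnitYT2UBuilder (autoregressive) and UnitYNART2UBuilder (non-autoregressive). A minimal sketch of the resulting dispatch, mirroring the logic added in builder.py below; the helper name build_t2u is illustrative, not part of the diff:

from seamless_communication.models.unity import (
    UnitYNART2UBuilder,
    UnitYT2UBuilder,
    UnitYT2UConfig,
)

def build_t2u(config: UnitYT2UConfig):
    # AR configs leave nar_decoder_config unset; NAR configs set it.
    if config.nar_decoder_config is None:
        return UnitYT2UBuilder(config).build_model()  # UnitYT2UModel
    return UnitYNART2UBuilder(config).build_model()  # UnitYNART2UModel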

+ 3 - 0
src/seamless_communication/models/unity/__init__.py

@@ -55,6 +55,9 @@ from seamless_communication.models.unity.nar_decoder_layer import (
 from seamless_communication.models.unity.t2u_builder import (
     UnitYT2UBuilder as UnitYT2UBuilder,
 )
+from seamless_communication.models.unity.t2u_builder import (
+    UnitYNART2UBuilder as UnitYNART2UBuilder,
+)
 from seamless_communication.models.unity.t2u_builder import (
     UnitYT2UConfig as UnitYT2UConfig,
 )

+ 9 - 5
src/seamless_communication/models/unity/builder.py

@@ -5,8 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from torch.nn import Parameter
-from typing import Optional
+from typing import Union, Optional
 
 from fairseq2.models.conformer import ConformerBlock, ConformerConvolution
 from fairseq2.models.nllb import NllbBuilder, NllbConfig, nllb_archs
@@ -33,6 +32,7 @@ from seamless_communication.models.unity.adaptor_block import (
 from seamless_communication.models.unity.model import UnitYModel
 from seamless_communication.models.unity.t2u_builder import (
     UnitYT2UBuilder,
+    UnitYNART2UBuilder,
     UnitYT2UConfig,
     unity_t2u_archs,
 )
@@ -176,7 +176,7 @@ class UnitYBuilder:
     config: UnitYConfig
     w2v2_encoder_builder: Wav2Vec2EncoderBuilder
     mt_model_builder: NllbBuilder
-    t2u_builder: Optional["UnitYT2UBuilder"]
+    t2u_builder: Union[UnitYT2UBuilder, UnitYNART2UBuilder, None]
     device: Optional[Device]
     dtype: Optional[DataType]
 
@@ -185,7 +185,7 @@ class UnitYBuilder:
         config: UnitYConfig,
         w2v2_encoder_builder: Wav2Vec2EncoderBuilder,
         mt_model_builder: NllbBuilder,
-        t2u_builder: Optional["UnitYT2UBuilder"],
+        t2u_builder: Union[UnitYT2UBuilder, UnitYNART2UBuilder, None],
         *,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
@@ -389,10 +389,14 @@ def create_unity_model(
             config.w2v2_encoder_config, device=device, dtype=dtype
         )
 
+    t2u_builder: Union[UnitYT2UBuilder, UnitYNART2UBuilder, None]
+
     if config.t2u_config is None:
         t2u_builder = None
-    else:
+    elif config.t2u_config.nar_decoder_config is None:
         t2u_builder = UnitYT2UBuilder(config.t2u_config, device=device, dtype=dtype)
+    else:
+        t2u_builder = UnitYNART2UBuilder(config.t2u_config, device=device, dtype=dtype)
 
     mt_model_builder = NllbBuilder(config.mt_model_config, device=device, dtype=dtype)
     unity_builder = UnitYBuilder(

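Note the standalone annotation t2u_builder: Union[UnitYT2UBuilder, UnitYNART2UBuilder, None] ahead of the if/elif/else in create_unity_model: without it, a type checker would pin the variable to the type assigned in the first branch and reject the others. A minimal illustration of the pattern (all names hypothetical):

from typing import Optional, Union

def pick(kind: Optional[str]) -> Union[int, str, None]:
    value: Union[int, str, None]  # pre-annotate so every branch type-checks
    if kind is None:
        value = None
    elif kind == "ar":
        value = 1
    else:
        value = "nar"
    return value
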
+ 218 - 73
src/seamless_communication/models/unity/t2u_builder.py

@@ -205,7 +205,7 @@ def _base_nar() -> UnitYT2UConfig:
 
 
 class UnitYT2UBuilder:
-    """Builds modules of an AR or NAR UnitY T2U model.
+    """Builds modules of an autoregressive UnitY T2U model.
 
     To tweak the architecture, you can derive from this class and override the
     corresponding methods.
@@ -234,8 +234,180 @@ class UnitYT2UBuilder:
 
         self.device, self.dtype = device, dtype
 
-    def build_model(self) -> Union[UnitYT2UModel, UnitYNART2UModel]:
-        """Build a model."""
+    def build_model(self) -> UnitYT2UModel:
+        """Build an autoregressive UnitYT2U model."""
+
+        embed_unit = self.build_unit_embedding()
+
+        encoder = self.build_encoder()
+
+        decoder = self.build_decoder()
+
+        final_proj = TiedProjection(embed_unit.weight, bias=None)
+
+        decoder_frontend = self.build_decoder_frontend(embed_unit)
+
+        return UnitYT2UModel(
+            encoder,
+            decoder_frontend,
+            decoder,
+            final_proj,
+            self.config.unit_pad_idx,
+        )
+
+    def build_unit_embedding(self) -> StandardEmbedding:
+        """Build a unit embedding table."""
+
+        return StandardEmbedding(
+            num_embeddings=self.config.unit_vocabulary_size,
+            embedding_dim=self.config.model_dim,
+            pad_idx=self.config.unit_pad_idx,
+            init_fn=init_scaled_embedding,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_encoder(self) -> Optional[TransformerEncoder]:
+        """Build a Transformer encoder."""
+
+        num_layers = self.config.num_encoder_layers
+        if num_layers == 0:
+            return None
+
+        layers = [self.build_encoder_layer() for _ in range(num_layers)]
+
+        return StandardTransformerEncoder(
+            layers,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_encoder_layer(self) -> TransformerEncoderLayer:
+        """Build a Transformer encoder layer."""
+
+        self_attn = self.build_attention(self.config.num_encoder_attn_heads)
+
+        ffn = self.build_ffn()
+
+        return StandardTransformerEncoderLayer(
+            self_attn,
+            ffn,
+            dropout_p=self.config.dropout_p,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder_frontend(self, embed_unit: Embedding) -> TransformerFrontend:
+        """Build a Transformer decoder front-end."""
+
+        pos_encoder = SinusoidalPositionEncoder(
+            self.config.model_dim,
+            self.config.unit_max_seq_len,
+            _legacy_pad_idx=self.config.unit_pad_idx,
+            device=self.device,
+        )
+        return TransformerEmbeddingFrontend(
+            embed_unit,
+            pos_encoder,
+            dropout_p=self.config.dropout_p,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder(self) -> TransformerDecoder:
+        """Build a Transformer decoder."""
+
+        num_layers = self.config.num_decoder_layers
+
+        layers = [self.build_decoder_layer() for _ in range(num_layers)]
+
+        return StandardTransformerDecoder(
+            layers,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder_layer(self) -> TransformerDecoderLayer:
+        """Build a Transformer decoder layer."""
+
+        self_attn = self.build_attention(self.config.num_decoder_attn_heads)
+
+        encoder_decoder_attn = self.build_attention(self.config.num_decoder_attn_heads)
+
+        ffn = self.build_ffn()
+
+        return StandardTransformerDecoderLayer(
+            self_attn,
+            encoder_decoder_attn,
+            ffn,
+            dropout_p=self.config.dropout_p,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_attention(self, num_heads: int) -> MultiheadAttention:
+        """Build a Transformer multi-head attention layer."""
+
+        sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
+
+        return StandardMultiheadAttention(
+            self.config.model_dim,
+            num_heads,
+            sdpa=sdpa,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_ffn(self) -> FeedForwardNetwork:
+        """Build a Transformer feed-forward network."""
+
+        return StandardFeedForwardNetwork(
+            self.config.model_dim,
+            self.config.ffn_inner_dim,
+            bias=True,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
+class UnitYNART2UBuilder:
+    """Builds modules of an NAR UnitY T2U model.
+
+    To tweak the architecture, you can derive from this class and override the
+    corresponding methods.
+    """
+
+    config: UnitYT2UConfig
+    device: Optional[Device]
+    dtype: Optional[DataType]
+
+    def __init__(
+        self,
+        config: UnitYT2UConfig,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+
+        self.device, self.dtype = device, dtype
+
+    def build_model(self) -> UnitYNART2UModel:
+        """Build a non-autoregressive UnitY T2U model."""
+
         embed_unit = self.build_unit_embedding()
 
         encoder = self.build_encoder()
@@ -244,27 +416,19 @@ class UnitYT2UBuilder:
 
         final_proj = TiedProjection(embed_unit.weight, bias=None)
 
-        if self.config.nar_decoder_config is None:
-            decoder_frontend = self.build_decoder_frontend(embed_unit)
-            return UnitYT2UModel(
-                encoder,
-                decoder_frontend,
-                decoder,
-                final_proj,
-                self.config.unit_pad_idx,
-            )
-        else:
-            nar_decoder_frontend = self.build_nar_decoder_frontend(embed_unit)
-            return UnitYNART2UModel(
-                encoder,
-                nar_decoder_frontend,
-                decoder,
-                final_proj,
-                self.config.unit_pad_idx,
-            )
+        decoder_frontend = self.build_decoder_frontend(embed_unit)
+
+        return UnitYNART2UModel(
+            encoder,
+            decoder_frontend,
+            decoder,
+            final_proj,
+            self.config.unit_pad_idx,
+        )
 
     def build_unit_embedding(self) -> StandardEmbedding:
         """Build a unit embedding table."""
+
         return StandardEmbedding(
             num_embeddings=self.config.unit_vocabulary_size,
             embedding_dim=self.config.model_dim,
@@ -276,6 +440,7 @@ class UnitYT2UBuilder:
 
     def build_encoder(self) -> Optional[TransformerEncoder]:
         """Build a Transformer encoder."""
+
         num_layers = self.config.num_encoder_layers
         if num_layers == 0:
             return None
@@ -291,6 +456,7 @@ class UnitYT2UBuilder:
 
     def build_encoder_layer(self) -> TransformerEncoderLayer:
         """Build a Transformer encoder layer."""
+
         self_attn = self.build_attention(self.config.num_encoder_attn_heads)
 
         ffn = self.build_ffn()
@@ -307,6 +473,8 @@ class UnitYT2UBuilder:
     def build_variance_adaptor(
         self, nar_decoder_frontend_config: NARDecoderFrontendConfig
     ) -> VarianceAdaptor:
+        """Build a variance adaptor module."""
+
         duration_predictor_config = (
             nar_decoder_frontend_config.duration_predictor_config
         )
@@ -327,24 +495,9 @@ class UnitYT2UBuilder:
 
         return variance_adaptor
 
-    def build_decoder_frontend(self, embed_unit: Embedding) -> TransformerFrontend:
-        """Build a Transformer decoder front-end."""
-        pos_encoder = SinusoidalPositionEncoder(
-            self.config.model_dim,
-            self.config.unit_max_seq_len,
-            _legacy_pad_idx=self.config.unit_pad_idx,
-            device=self.device,
-        )
-        return TransformerEmbeddingFrontend(
-            embed_unit,
-            pos_encoder,
-            dropout_p=self.config.dropout_p,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-    def build_nar_decoder_frontend(self, embed_unit: Embedding) -> NARDecoderFrontend:
+    def build_decoder_frontend(self, embed_unit: Embedding) -> NARDecoderFrontend:
         """Build a non-autoregressive decoder front-end."""
         """Build a non-autoregressive decoder front-end."""
+
         assert self.config.nar_decoder_config is not None
         assert self.config.nar_decoder_frontend_config is not None
 
@@ -399,6 +552,7 @@ class UnitYT2UBuilder:
 
     def build_decoder(self) -> TransformerDecoder:
         """Build a Transformer decoder."""
+
         num_layers = self.config.num_decoder_layers
 
         layers = [self.build_decoder_layer() for _ in range(num_layers)]
@@ -412,45 +566,32 @@ class UnitYT2UBuilder:
 
     def build_decoder_layer(self) -> TransformerDecoderLayer:
         """Build a Transformer decoder layer."""
+
+        assert self.config.nar_decoder_config is not None
+
         self_attn = self.build_attention(self.config.num_decoder_attn_heads)
 
-        if self.config.nar_decoder_config:
-            conv1d = Conv1dBlock(
-                self.config.model_dim,
-                self.config.nar_decoder_config.conv1d_inner_dim,
-                self.config.nar_decoder_config.conv1d_kernel_size,
-                bias=True,
-                device=self.device,
-                dtype=self.dtype,
-            )
-
-            return NARTransformerDecoderLayer(
-                self_attn,
-                conv1d,
-                dropout_p=self.config.dropout_p,
-                conv1d_dropout_p=self.config.nar_decoder_config.conv1d_dropout_p,
-                device=self.device,
-                dtype=self.dtype,
-            )
-        else:
-            encoder_decoder_attn = self.build_attention(
-                self.config.num_decoder_attn_heads
-            )
-
-            ffn = self.build_ffn()
-
-            return StandardTransformerDecoderLayer(
-                self_attn,
-                encoder_decoder_attn,
-                ffn,
-                dropout_p=self.config.dropout_p,
-                norm_order=TransformerNormOrder.PRE,
-                device=self.device,
-                dtype=self.dtype,
-            )
+        conv1d = Conv1dBlock(
+            self.config.model_dim,
+            self.config.nar_decoder_config.conv1d_inner_dim,
+            self.config.nar_decoder_config.conv1d_kernel_size,
+            bias=True,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        return NARTransformerDecoderLayer(
+            self_attn,
+            conv1d,
+            dropout_p=self.config.dropout_p,
+            conv1d_dropout_p=self.config.nar_decoder_config.conv1d_dropout_p,
+            device=self.device,
+            dtype=self.dtype,
+        )
 
     def build_attention(self, num_heads: int) -> MultiheadAttention:
         """Build a Transformer multi-head attention layer."""
+
         sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
 
         return StandardMultiheadAttention(
@@ -463,6 +604,7 @@ class UnitYT2UBuilder:
 
     def build_ffn(self) -> FeedForwardNetwork:
         """Build a Transformer feed-forward network."""
+
         return StandardFeedForwardNetwork(
             self.config.model_dim,
             self.config.ffn_inner_dim,
@@ -487,4 +629,7 @@ def create_unity_t2u_model(
     :param dtype:
         The data type of module parameters and buffers.
     """
-    return UnitYT2UBuilder(config, device=device, dtype=dtype).build_model()
+    if config.nar_decoder_config is None:
+        return UnitYT2UBuilder(config, device=device, dtype=dtype).build_model()
+    else:
+        return UnitYNART2UBuilder(config, device=device, dtype=dtype).build_model()
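
A usage sketch for the updated factory, which now routes to UnitYNART2UBuilder whenever nar_decoder_config is set. The arch name "base_nar" is inferred from the _base_nar config above, and the registry call follows fairseq2's ArchitectureRegistry convention; treat both as assumptions:

import torch

from seamless_communication.models.unity.t2u_builder import (
    create_unity_t2u_model,
    unity_t2u_archs,
)

# Assumed registered arch name; see _base_nar above.
config = unity_t2u_archs.get_config("base_nar")

# nar_decoder_config is set on this config, so the factory returns a UnitYNART2UModel.
model = create_unity_t2u_model(config, device=torch.device("cpu"), dtype=torch.float32)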