@@ -0,0 +1,266 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Optional
+
+from fairseq2.data import VocabularyInfo
+from fairseq2.models.transformer import (
+    TransformerEmbeddingFrontend,
+    TransformerFrontend,
+)
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
+from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
+from fairseq2.nn.projection import TiedProjection
+from fairseq2.nn.transformer import (
+    FeedForwardNetwork,
+    MultiheadAttention,
+    StandardFeedForwardNetwork,
+    StandardMultiheadAttention,
+    TransformerNormOrder,
+    create_default_sdpa,
+)
+from fairseq2.typing import DataType, Device
+
+from seamless_communication.models.monotonic_decoder.model import MonotonicDecoderModel
+from seamless_communication.models.monotonic_decoder.monotonic_decoder import (
+    MonotonicTransformerDecoder,
+)
+from seamless_communication.models.monotonic_decoder.monotonic_decoder_layer import (
+    MonotonicTransformerDecoderLayer,
+)
+from seamless_communication.models.monotonic_decoder.p_choose import PChooseLayer
+
+
+@dataclass
+class MonotonicDecoderConfig:
+    """Holds the configuration of a Monotonic Decoder model."""
+
+    model_dim: int
+    """The dimensionality of the model."""
+
+    max_seq_len: int
+    """The expected maximum sequence length."""
+
+    vocab_info: VocabularyInfo
+    """The vocabulary information."""
+
+    num_decoder_layers: int
+    """The number of Transformer decoder layers."""
+
+    num_decoder_attn_heads: int
+    """The number of attention heads in Transformer decoder layers."""
+
+    ffn_inner_dim: int
+    """The inner dimensionality of Transformer feed-forward networks."""
+
+    dropout_p: float
+    """The dropout probability in Transformer layers."""
+
+    energy_bias_value: float
+    """The value of the energy bias parameter to be added to the
+    monotonic energy in the PChooseLayer."""
+
+    monotonic_temperature: float
+    """The parameter by which the monotonic energy is divided
+    to compute p_choose."""
+
+    num_monotonic_energy_layers: int
+    """The number of layers in the EnergyProjection module."""
+
+    pre_decision_ratio: int
+    """The kernel size and stride of the average pooling
+    in the PChooseLayer."""
+
+
+monotonic_decoder_archs = ArchitectureRegistry[MonotonicDecoderConfig](
+    "monotonic_decoder"
+)
+
+monotonic_decoder_arch = monotonic_decoder_archs.marker
+
+
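+# ``monotonic_decoder_arch`` registers a config factory under an architecture
+# name (see ``_dense_1b`` below). A registered config can then be looked up by
+# name, e.g. ``monotonic_decoder_archs.get_config("dense_1b")`` (assuming
+# fairseq2's ``ArchitectureRegistry.get_config`` API).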
+@monotonic_decoder_arch("dense_1b")
+def _dense_1b() -> MonotonicDecoderConfig:
+    return MonotonicDecoderConfig(
+        model_dim=1024,
+        max_seq_len=4096,
+        vocab_info=VocabularyInfo(
+            size=256102, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=1
+        ),
+        num_decoder_layers=24,
+        num_decoder_attn_heads=16,
+        ffn_inner_dim=1024 * 8,
+        dropout_p=0.1,
+        energy_bias_value=-0.5,
+        monotonic_temperature=0.2,
+        num_monotonic_energy_layers=4,
+        pre_decision_ratio=2,
+    )
+
+
+class MonotonicDecoderBuilder:
+    """Builds modules of a Monotonic Decoder.
+
+    To tweak the architecture, you can derive from this class and override the
+    corresponding methods.
+    """
+
+    config: MonotonicDecoderConfig
+    device: Optional[Device]
+    dtype: Optional[DataType]
+
+    def __init__(
+        self,
+        config: MonotonicDecoderConfig,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+
+        self.device, self.dtype = device, dtype
+
+    def build_model(self) -> MonotonicDecoderModel:
+        text_embed = self.build_embedding()
+
+        text_decoder_frontend = self.build_frontend(text_embed)
+
+        text_decoder = self.build_decoder()
+
+        # Tie the weights of the final projection to the input embedding table.
+        final_proj = TiedProjection(text_embed.weight, bias=None)
+
+        return MonotonicDecoderModel(
+            text_decoder_frontend,
+            text_decoder,
+            final_proj,
+        )
+
+    def build_embedding(self) -> StandardEmbedding:
+        """Build an embedding table."""
+        return StandardEmbedding(
+            num_embeddings=self.config.vocab_info.size,
+            embedding_dim=self.config.model_dim,
+            pad_idx=self.config.vocab_info.pad_idx,
+            init_fn=init_scaled_embedding,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_frontend(self, embed: Embedding) -> TransformerFrontend:
+        """Build a Transformer decoder front-end."""
+        # _legacy_pad_idx keeps positions consistent with the original fairseq
+        # numbering.
+        pos_encoder = SinusoidalPositionEncoder(
+            self.config.model_dim,
+            self.config.max_seq_len,
+            _legacy_pad_idx=self.config.vocab_info.pad_idx,
+            device=self.device,
+        )
+
+        return TransformerEmbeddingFrontend(
+            embed,
+            pos_encoder,
+            dropout_p=self.config.dropout_p,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder(self) -> MonotonicTransformerDecoder:
+        """Build a Transformer decoder."""
+        num_layers = self.config.num_decoder_layers
+
+        layers = [self.build_decoder_layer() for _ in range(num_layers)]
+
+        return MonotonicTransformerDecoder(
+            layers,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder_layer(self) -> MonotonicTransformerDecoderLayer:
+        """Build a Transformer decoder layer."""
+        self_attn = self.build_attention(self.config.num_decoder_attn_heads)
+
+        encoder_decoder_attn = self.build_attention(self.config.num_decoder_attn_heads)
+
+        p_choose_layer = self.build_p_choose_layer(self.config.num_decoder_attn_heads)
+
+        ffn = self.build_ffn()
+
+        return MonotonicTransformerDecoderLayer(
+            self_attn,
+            encoder_decoder_attn,
+            p_choose_layer,
+            ffn,
+            dropout_p=self.config.dropout_p,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_attention(self, num_heads: int) -> MultiheadAttention:
+        """Build a Transformer multi-head attention layer."""
+        sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
+
+        return StandardMultiheadAttention(
+            self.config.model_dim,
+            num_heads,
+            sdpa=sdpa,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_p_choose_layer(self, num_heads: int) -> PChooseLayer:
+        """Build a PChoose layer."""
+        return PChooseLayer(
+            self.config.model_dim,
+            num_heads,
+            self.config.energy_bias_value,
+            self.config.monotonic_temperature,
+            self.config.num_monotonic_energy_layers,
+            self.config.pre_decision_ratio,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_ffn(self) -> FeedForwardNetwork:
+        """Build a Transformer feed-forward network."""
+        return StandardFeedForwardNetwork(
+            self.config.model_dim,
+            self.config.ffn_inner_dim,
+            bias=True,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
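+# A minimal sketch of tweaking the architecture by subclassing the builder
+# (hypothetical example; the doubled inner dimension is an assumption, not
+# part of this module):
+#
+#   class WideFFNDecoderBuilder(MonotonicDecoderBuilder):
+#       def build_ffn(self) -> FeedForwardNetwork:
+#           return StandardFeedForwardNetwork(
+#               self.config.model_dim,
+#               self.config.ffn_inner_dim * 2,  # wider inner projection
+#               bias=True,
+#               norm_order=TransformerNormOrder.PRE,
+#               device=self.device,
+#               dtype=self.dtype,
+#           )
+
+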
+def create_monotonic_decoder_model(
+    config: MonotonicDecoderConfig,
+    *,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> MonotonicDecoderModel:
+    """Create a Monotonic Decoder model.
+
+    :param config:
+        The configuration to use.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+    return MonotonicDecoderBuilder(config, device=device, dtype=dtype).build_model()
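+
+
+# End-to-end usage sketch (hypothetical; assumes the "dense_1b" architecture
+# registered above and fairseq2's ``ArchitectureRegistry.get_config`` API):
+#
+#   config = monotonic_decoder_archs.get_config("dense_1b")
+#   model = create_monotonic_decoder_model(config)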