
Introduce monotonic_decoder. (#73)

* Add the monotonic transformer decoder model with its decoder, layer, and attention.

* Preliminary, unverified full implementation of the monotonic transformer decoder.

* Unverified: integrate into the unity builder and write a builder for monotonic_transformer_decoder.

* Successfully load the mma_s2t checkpoint.

* Complete implementation with a successful forward pass.

* Refactor monotonic_decoder to be separate from unity.

* Separate unity & monotonic decoder checkpoints.

* Address review comments on separating monotonic_decoder from nllb.

* Fix bugs while verifying parity.

---------

Co-authored-by: Can Balioglu <cbalioglu@users.noreply.github.com>
Kaushik Ram Sadagopan committed 1 year ago
commit d6ff19b20b

+ 10 - 0
src/seamless_communication/cards/monotonic_decoder.yaml

@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: monotonic_decoder
+model_type: monotonic_decoder
+model_arch: dense_1b
+checkpoint: "file://large_experiments/seamless/ust/krs/fairseq2_checkpoints/monotonic_decoder.pt"

+ 10 - 0
src/seamless_communication/cards/unity_sans_decoder.yaml

@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: unity_sans_decoder
+base: unity_nllb-100
+model_arch: base_v2
+checkpoint: "file://large_experiments/seamless/ust/krs/fairseq2_checkpoints/unity_sans_decoder.pt"

+ 27 - 0
src/seamless_communication/models/monotonic_decoder/__init__.py

@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from seamless_communication.models.monotonic_decoder.builder import (
+    MonotonicDecoderBuilder as MonotonicDecoderBuilder,
+)
+from seamless_communication.models.monotonic_decoder.builder import (
+    MonotonicDecoderConfig as MonotonicDecoderConfig,
+)
+from seamless_communication.models.monotonic_decoder.builder import (
+    monotonic_decoder_archs as monotonic_decoder_archs,
+)
+from seamless_communication.models.monotonic_decoder.loader import (
+    load_monotonic_decoder_model as load_monotonic_decoder_model,
+)
+from seamless_communication.models.monotonic_decoder.loader import (
+    load_monotonic_decoder_config as load_monotonic_decoder_config,
+)
+from seamless_communication.models.monotonic_decoder.builder import (
+    create_monotonic_decoder_model as create_monotonic_decoder_model,
+)

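For orientation, a minimal loading sketch built on the exports above (hedged: it assumes the `monotonic_decoder` asset card resolves to a reachable checkpoint, and that the `device`/`dtype` keyword arguments follow the fairseq2 `ModelLoader` call convention):

```python
import torch

from seamless_communication.models.monotonic_decoder import (
    load_monotonic_decoder_config,
    load_monotonic_decoder_model,
)

# Inspect the architecture config resolved from the "monotonic_decoder" card.
config = load_monotonic_decoder_config("monotonic_decoder")
print(config.num_decoder_layers, config.pre_decision_ratio)

# Materialize the model weights (fp16 on GPU here, purely illustrative).
model = load_monotonic_decoder_model(
    "monotonic_decoder", device=torch.device("cuda:0"), dtype=torch.float16
)
model.eval()
```
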
+ 266 - 0
src/seamless_communication/models/monotonic_decoder/builder.py

@@ -0,0 +1,266 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Optional
+
+from fairseq2.data import VocabularyInfo
+from fairseq2.models.transformer import (
+    TransformerEmbeddingFrontend,
+    TransformerFrontend,
+)
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
+from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
+from fairseq2.nn.projection import TiedProjection
+from fairseq2.nn.transformer import (
+    FeedForwardNetwork,
+    MultiheadAttention,
+    StandardFeedForwardNetwork,
+    StandardMultiheadAttention,
+    TransformerNormOrder,
+    create_default_sdpa,
+)
+from fairseq2.typing import DataType, Device
+
+from seamless_communication.models.monotonic_decoder.p_choose import (
+    PChooseLayer,
+)
+from seamless_communication.models.monotonic_decoder.model import MonotonicDecoderModel
+from seamless_communication.models.monotonic_decoder.monotonic_decoder import (
+    MonotonicTransformerDecoder,
+)
+from seamless_communication.models.monotonic_decoder.monotonic_decoder_layer import (
+    MonotonicTransformerDecoderLayer,
+)
+
+
+@dataclass
+class MonotonicDecoderConfig:
+    """Holds the configuration of an Monotonic Decoder model."""
+
+    model_dim: int
+    """The dimensionality of the model."""
+
+    max_seq_len: int
+    """The expected maximum sequence length."""
+
+    vocab_info: VocabularyInfo
+    """The vocabulary information."""
+
+    num_decoder_layers: int
+    """The number of Transformer decoder layers."""
+
+    num_decoder_attn_heads: int
+    """The number of attention heads in Transformer decoder layers."""
+
+    ffn_inner_dim: int
+    """The inner dimensionality of Transformer feed-forward networks."""
+
+    dropout_p: float
+    """The dropout probability in Transformer layers."""
+
+    energy_bias_value: float
+    """The value of the energy bias parameter to be added to the
+    monotonic energy in the PChooseLayer."""
+
+    monotonic_temperature: float
+    """The parameter with which to divide the monotonic energy
+    to compute p_choose."""
+
+    num_monotonic_energy_layers: int
+    """The number of layers in the EnergyProjection module."""
+
+    pre_decision_ratio: int
+    """The kernel size and stride of the average pooling
+    in the PChooseLayer."""
+
+
+monotonic_decoder_archs = ArchitectureRegistry[MonotonicDecoderConfig](
+    "monotonic_decoder"
+)
+
+monotonic_decoder_arch = monotonic_decoder_archs.marker
+
+
+@monotonic_decoder_arch("dense_1b")
+def _dense_1b() -> MonotonicDecoderConfig:
+
+    return MonotonicDecoderConfig(
+        model_dim=1024,
+        max_seq_len=4096,
+        vocab_info=VocabularyInfo(
+            size=256102, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=1
+        ),
+        num_decoder_layers=24,
+        num_decoder_attn_heads=16,
+        ffn_inner_dim=1024 * 8,
+        dropout_p=0.1,
+        energy_bias_value=-0.5,
+        monotonic_temperature=0.2,
+        num_monotonic_energy_layers=4,
+        pre_decision_ratio=2,
+    )
+
+
+class MonotonicDecoderBuilder:
+    """Builds modules of a Monotonic Decoder.
+
+    To tweak the architecture, you can derive from this class and override the
+    corresponding methods.
+    """
+
+    config: MonotonicDecoderConfig
+    device: Optional[Device]
+    dtype: Optional[DataType]
+
+    def __init__(
+        self,
+        config: MonotonicDecoderConfig,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+
+        self.device, self.dtype = device, dtype
+
+    def build_model(self) -> MonotonicDecoderModel:
+        text_embed = self.build_embedding()
+
+        text_decoder_frontend = self.build_frontend(text_embed)
+
+        text_decoder = self.build_decoder()
+
+        final_proj = TiedProjection(text_embed.weight, bias=None)
+
+        return MonotonicDecoderModel(
+            text_decoder_frontend,
+            text_decoder,
+            final_proj,
+        )
+
+    def build_embedding(self) -> StandardEmbedding:
+        """Build an embedding table."""
+        return StandardEmbedding(
+            num_embeddings=self.config.vocab_info.size,
+            embedding_dim=self.config.model_dim,
+            pad_idx=self.config.vocab_info.pad_idx,
+            init_fn=init_scaled_embedding,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_frontend(self, embed: Embedding) -> TransformerFrontend:
+        """Build a Transformer decoder front-end."""
+        pos_encoder = SinusoidalPositionEncoder(
+            self.config.model_dim,
+            self.config.max_seq_len,
+            _legacy_pad_idx=self.config.vocab_info.pad_idx,
+            device=self.device,
+        )
+
+        return TransformerEmbeddingFrontend(
+            embed,
+            pos_encoder,
+            dropout_p=self.config.dropout_p,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder(self) -> MonotonicTransformerDecoder:
+        """Build a Transformer decoder."""
+        num_layers = self.config.num_decoder_layers
+
+        layers = [self.build_decoder_layer() for _ in range(num_layers)]
+
+        return MonotonicTransformerDecoder(
+            layers,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder_layer(self) -> MonotonicTransformerDecoderLayer:
+        """Build a Transformer decoder layer."""
+        self_attn = self.build_attention(self.config.num_decoder_attn_heads)
+
+        encoder_decoder_attn = self.build_attention(self.config.num_decoder_attn_heads)
+
+        p_choose_layer = self.build_p_choose_layer(self.config.num_decoder_attn_heads)
+
+        ffn = self.build_ffn()
+
+        return MonotonicTransformerDecoderLayer(
+            self_attn,
+            encoder_decoder_attn,
+            p_choose_layer,
+            ffn,
+            dropout_p=self.config.dropout_p,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_attention(self, num_heads: int) -> MultiheadAttention:
+        """Build a Transformer multi-head attention layer."""
+        sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
+
+        return StandardMultiheadAttention(
+            self.config.model_dim,
+            num_heads,
+            sdpa=sdpa,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_p_choose_layer(self, num_heads: int) -> PChooseLayer:
+        """Build a PChoose layer."""
+        return PChooseLayer(
+            self.config.model_dim,
+            num_heads,
+            self.config.energy_bias_value,
+            self.config.monotonic_temperature,
+            self.config.num_monotonic_energy_layers,
+            self.config.pre_decision_ratio,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_ffn(self) -> FeedForwardNetwork:
+        """Build a Transformer feed-forward network."""
+        return StandardFeedForwardNetwork(
+            self.config.model_dim,
+            self.config.ffn_inner_dim,
+            bias=True,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
+def create_monotonic_decoder_model(
+    config: MonotonicDecoderConfig,
+    *,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> MonotonicDecoderModel:
+    """Create an Monotonic Decoder model.
+
+    :param config:
+        The configuration to use.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+    return MonotonicDecoderBuilder(config, device=device, dtype=dtype).build_model()

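The registered `dense_1b` preset can also be built by hand through the entry point above; a sketch that mirrors `_dense_1b()` without going through the asset card or checkpoint loader:

```python
import torch
from fairseq2.data import VocabularyInfo

from seamless_communication.models.monotonic_decoder import (
    MonotonicDecoderConfig,
    create_monotonic_decoder_model,
)

# Same field values as the dense_1b architecture registered above.
config = MonotonicDecoderConfig(
    model_dim=1024,
    max_seq_len=4096,
    vocab_info=VocabularyInfo(
        size=256102, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=1
    ),
    num_decoder_layers=24,
    num_decoder_attn_heads=16,
    ffn_inner_dim=1024 * 8,
    dropout_p=0.1,
    energy_bias_value=-0.5,
    monotonic_temperature=0.2,
    num_monotonic_energy_layers=4,
    pre_decision_ratio=2,
)

# Randomly initialized (no checkpoint); device/dtype are optional.
model = create_monotonic_decoder_model(config, dtype=torch.float32)
```
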
+ 106 - 0
src/seamless_communication/models/monotonic_decoder/loader.py

@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, Mapping, final
+
+import torch
+
+from fairseq2.assets import (
+    asset_store,
+    download_manager,
+)
+from fairseq2.models.utils.checkpoint_loader import upgrade_fairseq_checkpoint
+from fairseq2.models.utils.model_loader import ModelConfigLoader, ModelLoader
+from fairseq2.typing import finaloverride
+
+from seamless_communication.models.monotonic_decoder.builder import (
+    MonotonicDecoderConfig,
+    create_monotonic_decoder_model,
+    monotonic_decoder_archs,
+)
+from seamless_communication.models.monotonic_decoder.model import MonotonicDecoderModel
+
+
+@final
+class MonotonicDecoderLoader(
+    ModelLoader[MonotonicDecoderModel, MonotonicDecoderConfig]
+):
+    """Loads NLLB models."""
+
+    @finaloverride
+    def _convert_checkpoint(
+        self, checkpoint: Mapping[str, Any], config: MonotonicDecoderConfig
+    ) -> Mapping[str, Any]:
+        state_dict = checkpoint["model"]
+
+        # Check if we have a fairseq2 checkpoint.
+        if "decoder_frontend.embed_weight" in state_dict:
+            return checkpoint
+
+        key_map = self._fairseq_key_map()
+
+        # Convert to fairseq2.
+        checkpoint = upgrade_fairseq_checkpoint(checkpoint, key_map)
+
+        state_dict = checkpoint["model"]
+
+        embeds = state_dict["final_proj.weight"]
+
+        # fairseq had a bug that accidentally introduced a dummy token in the
+        # embedding table of NLLB-100. We just discard it.
+        if embeds.size(0) == 256103:  # means NLLB-100
+            embeds = embeds[:-1]
+
+            state_dict["final_proj.weight"] = embeds
+
+        # fairseq checkpoints have duplicate embedding weights. Ensure that we
+        # use a single embedding table in fairseq2.
+        state_dict["text_decoder_frontend.embed.weight"] = embeds
+
+        # The embedding positions of the control symbols in fairseq's dict do
+        # not match the SentencePiece model of the tokenizer.
+        with torch.inference_mode():
+            # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
+            embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
+
+        return checkpoint
+
+    @staticmethod
+    def _fairseq_key_map() -> Dict[str, str]:
+        return {
+            # fmt: off
+            # Text Decoder
+            r"^decoder\.embed_tokens\.":                                            r"text_decoder_frontend.embed.",
+            r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":                   r"text_decoder.layers.\1.self_attn.output_proj.",
+            r"^decoder\.layers\.([0-9]+)\.self_attn\.":                             r"text_decoder.layers.\1.self_attn.",
+            r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":                  r"text_decoder.layers.\1.self_attn_layer_norm.",
+            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":                r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
+            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.energy_bias":               r"text_decoder.layers.\1.p_choose_layer.energy_bias",
+            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.source_energy_layer\.":     r"text_decoder.layers.\1.p_choose_layer.k_energy_proj.",
+            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.target_energy_layer\.":     r"text_decoder.layers.\1.p_choose_layer.q_energy_proj.",
+            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":                          r"text_decoder.layers.\1.encoder_decoder_attn.",
+            r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.":               r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+            r"^decoder\.layers\.([0-9]+)\.fc1\.":                                   r"text_decoder.layers.\1.ffn.inner_proj.",
+            r"^decoder\.layers\.([0-9]+)\.fc2\.":                                   r"text_decoder.layers.\1.ffn.output_proj.",
+            r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":                      r"text_decoder.layers.\1.ffn_layer_norm.",
+            r"^decoder\.layer_norm\.":                                              r"text_decoder.layer_norm.",
+            r"^decoder\.output_projection\.":                                       r"final_proj.",
+            # fmt: on
+        }
+
+
+load_monotonic_decoder_model = MonotonicDecoderLoader(
+    asset_store,
+    download_manager,
+    create_monotonic_decoder_model,
+    monotonic_decoder_archs,
+    restrict_checkpoints=False,
+)
+
+
+load_monotonic_decoder_config = ModelConfigLoader[MonotonicDecoderConfig](
+    asset_store, monotonic_decoder_archs
+)

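The `_fairseq_key_map()` entries are applied as regex rewrites over fairseq state-dict keys; a rough illustration of one renaming (this only mimics what `upgrade_fairseq_checkpoint` does, and the key suffix is made up for the example):

```python
import re

# One entry from the key map above.
pattern = r"^decoder\.layers\.([0-9]+)\.encoder_attn\.source_energy_layer\."
repl = r"text_decoder.layers.\1.p_choose_layer.k_energy_proj."

old_key = "decoder.layers.3.encoder_attn.source_energy_layer.layers.0.weight"
new_key = re.sub(pattern, repl, old_key)

# -> text_decoder.layers.3.p_choose_layer.k_energy_proj.layers.0.weight
print(new_key)
```
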
+ 68 - 0
src/seamless_communication/models/monotonic_decoder/model.py

@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from overrides import final as finaloverride
+from typing import Optional, Tuple, final
+
+
+from torch import Tensor
+from torch.nn import Module
+from fairseq2.nn.incremental_state import IncrementalStateBag
+from fairseq2.models.transformer.frontend import TransformerFrontend
+
+from fairseq2.nn.projection import Projection
+from fairseq2.nn.padding import PaddingMask
+
+from seamless_communication.models.monotonic_decoder.monotonic_decoder import (
+    MonotonicTransformerDecoder,
+)
+
+
+@final
+class MonotonicDecoderModel(Module):
+    text_decoder_frontend: TransformerFrontend
+    text_decoder: MonotonicTransformerDecoder
+    final_proj: Projection
+
+    def __init__(
+        self,
+        text_decoder_frontend: TransformerFrontend,
+        text_decoder: MonotonicTransformerDecoder,
+        final_proj: Projection,
+    ) -> None:
+        super().__init__()
+
+        self.text_decoder_frontend = text_decoder_frontend
+        self.text_decoder = text_decoder
+        self.final_proj = final_proj
+
+    @finaloverride
+    def decode(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+        encoder_output: Tensor,
+        encoder_padding_mask: Optional[PaddingMask],
+        *,
+        state_bag: Optional[IncrementalStateBag] = None,
+    ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
+        seqs, padding_mask = self.text_decoder_frontend(
+            seqs, padding_mask, state_bag=state_bag
+        )
+
+        return self.text_decoder(  # type: ignore[no-any-return]
+            seqs,
+            padding_mask,
+            encoder_output,
+            encoder_padding_mask,
+            state_bag=state_bag,
+        )
+
+    @finaloverride
+    def project(self, decoder_output: Tensor) -> Tensor:
+        logits = self.final_proj(decoder_output)
+
+        return logits  # type: ignore[no-any-return]

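To make the interface of `MonotonicDecoderModel` concrete, a toy forward-pass sketch with a deliberately tiny, untrained configuration (the hyper-parameter values here are arbitrary, not the shipped `dense_1b` ones):

```python
import torch
from fairseq2.data import VocabularyInfo

from seamless_communication.models.monotonic_decoder import (
    MonotonicDecoderConfig,
    create_monotonic_decoder_model,
)

config = MonotonicDecoderConfig(
    model_dim=64,
    max_seq_len=64,
    vocab_info=VocabularyInfo(size=100, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=1),
    num_decoder_layers=2,
    num_decoder_attn_heads=4,
    ffn_inner_dim=128,
    dropout_p=0.1,
    energy_bias_value=-0.5,
    monotonic_temperature=0.2,
    num_monotonic_energy_layers=2,
    pre_decision_ratio=2,
)

model = create_monotonic_decoder_model(config).eval()

prev_tokens = torch.randint(4, 100, (2, 5))  # (N, S) target prefix token ids
encoder_output = torch.rand(2, 11, 64)       # (N, S_enc, model_dim) encoder states

decoder_output, _, p_choose = model.decode(prev_tokens, None, encoder_output, None)
logits = model.project(decoder_output)       # (N, S, vocab_size)
```
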
+ 98 - 0
src/seamless_communication/models/monotonic_decoder/monotonic_decoder.py

@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Iterable, List, Optional, Tuple, final
+
+import torch
+from fairseq2.nn.incremental_state import IncrementalStateBag
+from fairseq2.nn.module_list import ModuleList
+from fairseq2.nn.normalization import LayerNorm
+from fairseq2.nn.padding import PaddingMask
+from fairseq2.nn.transformer import (
+    AttentionMaskFactory,
+    CausalAttentionMaskFactory,
+    create_standard_layer_norm,
+)
+from fairseq2.typing import DataType, Device, finaloverride
+from torch import Tensor
+from torch.nn import Module
+
+from seamless_communication.models.monotonic_decoder.monotonic_decoder_layer import (
+    MonotonicTransformerDecoderLayer,
+)
+
+
+@final
+class MonotonicTransformerDecoder(Module):
+    """Represents a Monotonic Transformer decoder."""
+
+    model_dim: int
+    self_attn_mask_factory: AttentionMaskFactory
+    layers: ModuleList
+    layer_norm: LayerNorm
+
+    def __init__(
+        self,
+        layers: Iterable[MonotonicTransformerDecoderLayer],
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param layers:
+            The decoder layers.
+        """
+        super().__init__()
+
+        layer_list = ModuleList(layers)
+
+        if not layer_list:
+            raise ValueError("`layers` must be non-empty.")
+
+        self.model_dim = layer_list[0].model_dim
+
+        self.self_attn_mask_factory = CausalAttentionMaskFactory()
+
+        self.layers = layer_list
+
+        self.layer_norm = create_standard_layer_norm(
+            self.model_dim, device=device, dtype=dtype
+        )
+
+    @finaloverride
+    def forward(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+        encoder_output: Optional[Tensor] = None,
+        encoder_padding_mask: Optional[PaddingMask] = None,
+        *,
+        state_bag: Optional[IncrementalStateBag] = None,
+    ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
+        self_attn_mask = self.self_attn_mask_factory(
+            seqs, keys=seqs, training=self.training, state_bag=state_bag
+        )
+
+        p_choose_list: List[Tensor] = []
+
+        for layer in self.layers.drop_iter():
+            seqs, padding_mask, p_choose = layer(
+                seqs,
+                padding_mask,
+                self_attn_mask,
+                encoder_output,
+                encoder_padding_mask,
+                state_bag=state_bag,
+            )
+            p_choose_list.append(p_choose)
+
+        seqs = self.layer_norm(seqs)
+
+        p_choose = torch.cat(p_choose_list, dim=0)
+
+        p_choose = p_choose.flatten(0, 1)
+
+        return seqs, padding_mask, p_choose

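Shape-wise, each layer returns a `p_choose` of shape `(N, H, S, S_p)`; the decoder concatenates the per-layer tensors along the batch dimension and then merges the first two dimensions, so callers receive `(num_layers * N * H, S, S_p)`. A small bookkeeping sketch with dummy tensors:

```python
import torch

N, H, S, S_p, num_layers = 2, 16, 5, 7, 24  # batch, heads, target len, pooled source len, layers

# One p_choose tensor per decoder layer, as produced by PChooseLayer.
p_choose_list = [torch.rand(N, H, S, S_p) for _ in range(num_layers)]

p_choose = torch.cat(p_choose_list, dim=0)  # (num_layers * N, H, S, S_p)
p_choose = p_choose.flatten(0, 1)           # (num_layers * N * H, S, S_p)

assert p_choose.shape == (num_layers * N * H, S, S_p)
```
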
+ 201 - 0
src/seamless_communication/models/monotonic_decoder/monotonic_decoder_layer.py

@@ -0,0 +1,201 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple, final
+
+from fairseq2.nn.incremental_state import IncrementalStateBag
+from fairseq2.nn.normalization import LayerNorm
+from fairseq2.nn.padding import PaddingMask
+from fairseq2.nn.transformer import (
+    AttentionMask,
+    FeedForwardNetwork,
+    MultiheadAttention,
+    create_standard_layer_norm,
+)
+from fairseq2.typing import DataType, Device, finaloverride
+from torch import Tensor
+from torch.nn import Dropout, Module
+
+from seamless_communication.models.monotonic_decoder.p_choose import PChooseLayer
+
+
+@final
+class MonotonicTransformerDecoderLayer(Module):
+    """Represents a Monotonic Transformer decoder layer."""
+
+    self_attn: MultiheadAttention
+    self_attn_dropout: Optional[Dropout]
+    self_attn_layer_norm: LayerNorm
+    encoder_decoder_attn: MultiheadAttention
+    encoder_decoder_attn_dropout: Optional[Dropout]
+    encoder_decoder_attn_layer_norm: LayerNorm
+    p_choose_layer: PChooseLayer
+    ffn: FeedForwardNetwork
+    ffn_dropout: Optional[Dropout]
+    ffn_layer_norm: LayerNorm
+
+    def __init__(
+        self,
+        self_attn: MultiheadAttention,
+        encoder_decoder_attn: MultiheadAttention,
+        p_choose_layer: PChooseLayer,
+        ffn: FeedForwardNetwork,
+        *,
+        dropout_p: float = 0.1,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param self_attn:
+            The self attention layer.
+        :param encoder_decoder_attn:
+            The encoder-decoder attention layer.
+        :param ffn:
+            The feed-forward network.
+        :param dropout_p:
+            The dropout probability on outputs of the attention layers and the
+            feed-forward network.
+        """
+        super().__init__()
+
+        self.model_dim = self_attn.model_dim
+
+        self_attn_layer_norm = create_standard_layer_norm(
+            self.model_dim, device=device, dtype=dtype
+        )
+
+        self.self_attn_layer_norm = self_attn_layer_norm
+
+        self.self_attn = self_attn
+
+        if dropout_p > 0.0:
+            self.self_attn_dropout = Dropout(dropout_p)
+        else:
+            self.register_module("self_attn_dropout", None)
+
+        encoder_decoder_attn_layer_norm = create_standard_layer_norm(
+            self.model_dim, device=device, dtype=dtype
+        )
+
+        self.encoder_decoder_attn_layer_norm = encoder_decoder_attn_layer_norm
+
+        self.encoder_decoder_attn = encoder_decoder_attn
+
+        if dropout_p > 0.0:
+            self.encoder_decoder_attn_dropout = Dropout(dropout_p)
+        else:
+            self.register_module("encoder_decoder_attn_dropout", None)
+
+        self.p_choose_layer = p_choose_layer
+
+        ffn_layer_norm = create_standard_layer_norm(
+            self.model_dim, device=device, dtype=dtype
+        )
+
+        self.ffn_layer_norm = ffn_layer_norm
+
+        self.ffn = ffn
+
+        if dropout_p > 0.0:
+            self.ffn_dropout = Dropout(dropout_p)
+        else:
+            self.register_module("ffn_dropout", None)
+
+    @finaloverride
+    def forward(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+        self_attn_mask: Optional[AttentionMask] = None,
+        encoder_output: Optional[Tensor] = None,
+        encoder_padding_mask: Optional[PaddingMask] = None,
+        *,
+        state_bag: Optional[IncrementalStateBag] = None,
+    ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
+        seqs = self._forward_self_attn(seqs, padding_mask, self_attn_mask, state_bag)
+
+        seqs, p_choose = self._forward_encoder_decoder_attn(
+            seqs, padding_mask, encoder_output, encoder_padding_mask
+        )
+
+        seqs = self._forward_ffn(seqs)
+
+        return seqs, padding_mask, p_choose
+
+    def _forward_self_attn(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+        self_attn_mask: Optional[AttentionMask],
+        state_bag: Optional[IncrementalStateBag],
+    ) -> Tensor:
+        residual = seqs
+
+        seqs = self.self_attn_layer_norm(seqs)
+
+        seqs = self.self_attn(
+            seqs,
+            padding_mask,
+            keys=seqs,
+            key_padding_mask=padding_mask,
+            values=seqs,
+            attn_mask=self_attn_mask,
+            state_bag=state_bag,
+        )
+
+        if self.self_attn_dropout is not None:
+            seqs = self.self_attn_dropout(seqs)
+
+        seqs = seqs + residual
+
+        return seqs
+
+    def _forward_encoder_decoder_attn(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+        encoder_output: Optional[Tensor],
+        encoder_padding_mask: Optional[PaddingMask],
+    ) -> Tuple[Tensor, Tensor]:
+        if encoder_output is None:
+            raise ValueError(
+                "`encoder_output` must not be `None` for encoder-decoder attention."
+            )
+
+        residual = seqs
+
+        seqs = self.encoder_decoder_attn_layer_norm(seqs)
+
+        p_choose = self.p_choose_layer(seqs, encoder_output)
+
+        seqs = self.encoder_decoder_attn(
+            seqs,
+            padding_mask,
+            encoder_output,
+            encoder_padding_mask,
+            encoder_output,
+        )
+
+        if self.encoder_decoder_attn_dropout is not None:
+            seqs = self.encoder_decoder_attn_dropout(seqs)
+
+        seqs = seqs + residual
+
+        return seqs, p_choose
+
+    def _forward_ffn(self, seqs: Tensor) -> Tensor:
+        residual = seqs
+
+        seqs = self.ffn_layer_norm(seqs)
+
+        seqs = self.ffn(seqs)
+
+        if self.ffn_dropout is not None:
+            seqs = self.ffn_dropout(seqs)
+
+        seqs = seqs + residual
+
+        return seqs

+ 148 - 0
src/seamless_communication/models/monotonic_decoder/p_choose.py

@@ -0,0 +1,148 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, final
+from torch import Tensor
+from torch.nn import AvgPool1d, Module, ModuleList, ReLU
+from torch.nn.parameter import Parameter
+import torch
+
+from fairseq2.nn.projection import Linear
+from fairseq2.typing import DataType, Device, finaloverride
+
+
+class EnergyProjection(Module):
+    def __init__(
+        self,
+        model_dim: int,
+        num_layers: int,
+        bias: bool = True,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        super().__init__()
+
+        if num_layers < 1:
+            raise ValueError(
+                f"Invalid `num_layers`: {num_layers} for EnergyProjection."
+            )
+
+        self.layers = ModuleList()
+
+        for _ in range(num_layers):
+            self.layers.append(
+                Linear(model_dim, model_dim, bias, device=device, dtype=dtype)
+            )
+            self.layers.append(ReLU())
+
+    def forward(self, seqs: Tensor) -> Tensor:
+        for layer in self.layers:
+            seqs = layer(seqs)
+        return seqs
+
+
+@final
+class PChooseLayer(Module):
+    """Represents a PChoose layer."""
+
+    model_dim: int
+    num_heads: int
+    energy_bias: Parameter
+    monotonic_temperature: float
+    q_energy_proj: EnergyProjection
+    k_energy_proj: EnergyProjection
+    keys_pooling: AvgPool1d
+
+    def __init__(
+        self,
+        model_dim: int,
+        num_heads: int,
+        energy_bias_value: float,
+        monotonic_temperature: float,
+        num_monotonic_energy_layers: int,
+        pre_decision_ratio: int,
+        *,
+        bias: bool = True,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param model_dim:
+            The dimensionality of the model.
+        :param num_heads:
+            The number of attention heads.
+        :param bias:
+            If ``True``, query, key energy projection layers learn an
+            additive bias.
+        """
+        super().__init__()
+
+        self.model_dim = model_dim
+        self.num_heads = num_heads
+
+        if energy_bias_value != 0.0:
+            self.energy_bias = Parameter(
+                torch.full([1], energy_bias_value, device=device, dtype=dtype)
+            )
+        else:
+            self.register_module("energy_bias", None)
+
+        self.monotonic_temperature = monotonic_temperature
+
+        if num_monotonic_energy_layers <= 0:
+            raise ValueError("Number of monotonic energy layers must be > 0.")
+
+        self.q_energy_proj = EnergyProjection(
+            self.model_dim,
+            num_monotonic_energy_layers,
+            bias,
+            device=device,
+            dtype=dtype,
+        )
+        self.k_energy_proj = EnergyProjection(
+            self.model_dim,
+            num_monotonic_energy_layers,
+            bias,
+            device=device,
+            dtype=dtype,
+        )
+
+        self.keys_pooling = AvgPool1d(
+            kernel_size=pre_decision_ratio,
+            stride=pre_decision_ratio,
+            ceil_mode=True,
+        )
+
+    @finaloverride
+    def forward(self, seqs: Tensor, keys: Tensor) -> Tensor:
+        q = self.q_energy_proj(seqs)
+
+        # (N, S, M) -> (N, H, S, K)
+        q = q.unflatten(-1, (self.num_heads, -1)).transpose(1, 2)
+
+        # (N, S_kv, M) -> (N, M, S_kv) -> (N, M, S_p)
+        pooled_keys = self.keys_pooling(keys.transpose(1, 2))
+
+        # (N, M, S_p) -> (N, S_p, M)
+        pooled_keys = pooled_keys.transpose(1, 2)
+
+        k = self.k_energy_proj(pooled_keys)
+
+        # (N, S_p, M) -> (N, H, S_p, K)
+        k = k.unflatten(-1, (self.num_heads, -1)).transpose(1, 2)
+
+        # (N, H, S, K) @ (N, H, K, S_p) = (N, H, S, S_p)
+        monotonic_energy = torch.matmul(q, k.transpose(-1, -2))
+
+        monotonic_energy = monotonic_energy * (q.size(-1) ** -0.5)
+
+        if self.energy_bias is not None:
+            monotonic_energy += self.energy_bias
+
+        # p_choose: (N, H, S, S_p)
+        p_choose = torch.sigmoid(monotonic_energy / self.monotonic_temperature)
+
+        return p_choose

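A quick shape check of `PChooseLayer` (hedged: dimensions chosen arbitrarily). With `pre_decision_ratio=2` the keys are average-pooled along time before the energy projections, so the pooled source length is `ceil(S_kv / 2)`:

```python
import math

import torch

from seamless_communication.models.monotonic_decoder.p_choose import PChooseLayer

layer = PChooseLayer(
    model_dim=1024,
    num_heads=16,
    energy_bias_value=-0.5,
    monotonic_temperature=0.2,
    num_monotonic_energy_layers=4,
    pre_decision_ratio=2,
)

seqs = torch.rand(2, 5, 1024)   # (N, S, M) decoder states
keys = torch.rand(2, 11, 1024)  # (N, S_kv, M) encoder output

p_choose = layer(seqs, keys)

# ceil(11 / 2) == 6 pooled key positions per head.
assert p_choose.shape == (2, 16, 5, math.ceil(11 / 2))
```
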
+ 8 - 4
src/seamless_communication/models/unity/builder.py

@@ -292,30 +292,34 @@ class UnitYBuilder:
 
     def build_model(self) -> UnitYModel:
         """Build a model."""
-        text_embed = self.mt_model_builder.build_embedding()
-
         speech_encoder_frontend = self.w2v2_encoder_builder.build_frontend()
         speech_encoder = self.build_speech_encoder()
 
         if self.config.use_text_encoder:
+            text_embed = self.mt_model_builder.build_embedding()
             text_encoder_frontend = self.mt_model_builder.build_frontend(text_embed)
             text_encoder = self.mt_model_builder.build_encoder()
         else:
+            text_embed = None
             text_encoder_frontend = None
             text_encoder = None
 
         if self.config.use_text_decoder:
+            if text_embed is None:
+                text_embed = self.mt_model_builder.build_embedding()
+
             if text_encoder_frontend is not None:
                 # We use shared embedding as in NLLB.
                 text_decoder_frontend = text_encoder_frontend
             else:
                 text_decoder_frontend = self.mt_model_builder.build_frontend(text_embed)
+
             text_decoder = self.mt_model_builder.build_decoder()
+            final_proj = TiedProjection(text_embed.weight, bias=None)
         else:
             text_decoder_frontend = None
             text_decoder = None
-
-        final_proj = TiedProjection(text_embed.weight, bias=None)
+            final_proj = None
 
         if self.t2u_builder is None:
             t2u_model = None

+ 41 - 32
src/seamless_communication/models/unity/loader.py

@@ -73,23 +73,30 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             keys_to_delete.append("text_encoder.version")
             keys_to_delete.append("text_encoder.embed_positions._float_tensor")
 
+        if not config.use_text_decoder:
+            text_decoder_keys = [
+                key for key in state_dict if key.startswith(decoder_key)
+            ]
+            keys_to_delete.extend(text_decoder_keys)
+
         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
         keys_to_delete.append(f"{encoder_key}.w2v_encoder.w2v_model.mask_emb")
 
-        keys_to_delete.append(
-            f"{t2u_decoder_key}.char_upsampler.embed_positions._float_tensor"
-        )
-        keys_to_delete.append(
-            f"{t2u_decoder_key}.char_upsampler.embed_tokens_char.weight"
-        )
+        if config.prosody_encoder_config is not None or config.t2u_config is not None:
+            keys_to_delete.append(
+                f"{t2u_decoder_key}.char_upsampler.embed_positions._float_tensor"
+            )
+            keys_to_delete.append(
+                f"{t2u_decoder_key}.char_upsampler.embed_tokens_char.weight"
+            )
 
-        # Delete AlignmentEncoder keys for inference.
-        alignment_encoder_keys = [
-            key
-            for key in state_dict
-            if key.startswith(f"{t2u_decoder_key}.alignment_encoder.")
-        ]
-        keys_to_delete.extend(alignment_encoder_keys)
+            # Delete AlignmentEncoder keys for inference.
+            alignment_encoder_keys = [
+                key
+                for key in state_dict
+                if key.startswith(f"{t2u_decoder_key}.alignment_encoder.")
+            ]
+            keys_to_delete.extend(alignment_encoder_keys)
 
         # Delete character-level projection for inference.
         keys_to_delete.extend(
@@ -114,23 +121,31 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             if key in state_dict:
                 del state_dict[key]
 
-        embeds = state_dict["final_proj.weight"]
+        if config.use_text_decoder:
+            embeds = state_dict["final_proj.weight"]
+
+            # fairseq had a bug that accidentally introduced a dummy token in the
+            # embedding table of NLLB-100. We just discard it.
+            if (
+                isinstance(config.mt_model_config, NllbConfig)
+                and embeds.size(0) == 256103
+            ):  # means NLLB-100
+                embeds = embeds[:-1]
 
-        # fairseq had a bug that accidentally introduced a dummy token in the
-        # embedding table of NLLB-100. We just discard it.
-        if (
-            isinstance(config.mt_model_config, NllbConfig) and embeds.size(0) == 256103
-        ):  # means NLLB-100
-            embeds = embeds[:-1]
+                state_dict["final_proj.weight"] = embeds
 
-            state_dict["final_proj.weight"] = embeds
+            # fairseq checkpoints have duplicate embedding weights. Ensure that we
+            # use a single embedding table in fairseq2.
+            state_dict["text_decoder_frontend.embed.weight"] = embeds
 
-        # fairseq checkpoints have duplicate embedding weights. Ensure that we
-        # use a single embedding table in fairseq2.
-        state_dict["text_decoder_frontend.embed.weight"] = embeds
+            if config.use_text_encoder:
+                state_dict["text_encoder_frontend.embed.weight"] = embeds
 
-        if config.use_text_encoder:
-            state_dict["text_encoder_frontend.embed.weight"] = embeds
+            # The embedding positions of the control symbols in fairseq's dict do
+            # not match the SentencePiece model of the tokenizer.
+            with torch.inference_mode():
+                # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
+                embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
 
         char_embeds = state_dict.get(
             "t2u_model.decoder_frontend.embed_char.weight", None
@@ -140,12 +155,6 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             vocab_size = len(index_mapping)
             char_embeds[torch.arange(vocab_size)] = char_embeds[index_mapping]
 
-        # The embedding positions of the control symbols in fairseq's dict do
-        # not match the SentencePiece model of the tokenizer.
-        with torch.inference_mode():
-            # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
-            embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
-
         if config.t2u_config is not None:
             # fairseq checkpoints have duplicate embedding weights. Ensure that we
             # use a single embedding table in fairseq2.

+ 9 - 4
src/seamless_communication/models/unity/model.py

@@ -41,7 +41,7 @@ class UnitYModel(EncoderDecoderModel):
     text_encoder: Optional[TransformerEncoder]
     text_decoder_frontend: Optional[TransformerFrontend]
     text_decoder: Optional[TransformerDecoder]
-    final_proj: Projection
+    final_proj: Optional[Projection]
     t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None]
     prosody_encoder_model: Optional[ECAPA_TDNN]
 
@@ -53,7 +53,7 @@ class UnitYModel(EncoderDecoderModel):
         text_encoder: Optional[TransformerEncoder],
         text_decoder_frontend: Optional[TransformerFrontend],
         text_decoder: Optional[TransformerDecoder],
-        final_proj: Projection,
+        final_proj: Optional[Projection],
         t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None],
         target_vocab_info: VocabularyInfo,
         prosody_encoder_model: Optional[ECAPA_TDNN] = None,
@@ -93,6 +93,7 @@ class UnitYModel(EncoderDecoderModel):
 
             self.text_decoder_frontend = text_decoder_frontend
             self.text_decoder = text_decoder
+            self.final_proj = final_proj
         else:
             if text_decoder_frontend is not None:
                 raise ValueError(
@@ -101,8 +102,7 @@ class UnitYModel(EncoderDecoderModel):
 
             self.register_module("text_decoder_frontend", None)
             self.register_module("text_decoder", None)
-
-        self.final_proj = final_proj
+            self.register_module("final_proj", None)
 
         if t2u_model is not None:
             self.t2u_model = t2u_model
@@ -183,6 +183,11 @@ class UnitYModel(EncoderDecoderModel):
     def project(
         self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask]
     ) -> SequenceModelOutput:
+        if self.final_proj is None:
+            raise ValueError(
+                "`project()` requires a final_proj layer, but the current UnitY model does not have one."
+            )
+
         logits = self.final_proj(decoder_output)
 
         return SequenceModelOutput(logits, self.target_vocab_info)