Browse Source

Merge pull request #28 from fairinternal/nar_t2u_offline

Implementing UnitYNART2UModel with a non-autoregressive T2U decoder for UnitY.
Kaushik Ram Sadagopan 2 years ago
parent
commit
4d640ab465

+ 13 - 0
src/seamless_communication/assets/cards/unity_nar_multilingual.yaml

@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: unity_nar_multilingual
+base: unity_nllb-100
+model_arch: nar_multilingual
+char_tokenizer: "file://checkpoint/krs/unity2/spm_char_lang38_tc.model"
+checkpoint: "file://large_experiments/seamless/ust/lpw/M4T_UNITY2/ckpt/checkpoint_9_80000.pt"
+num_units: 10000
+unit_langs: [arb, ben, hin, ind, ita, jpn, por, rus, swh, tha, tur, urd, vie, spa, eng]
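
A minimal sketch of how this card is consumed (assuming the default asset
store resolves it by name; both loaders are introduced or re-exported in this
PR):

    from seamless_communication.models.unity import (
        load_unity_char_tokenizer,
        load_unity_unit_tokenizer,
    )

    # Reads the "char_tokenizer" field above and downloads the SPM model.
    char_tokenizer = load_unity_char_tokenizer("unity_nar_multilingual")

    # Reads "num_units", "unit_langs" and "model_arch" from the same card.
    unit_tokenizer = load_unity_unit_tokenizer("unity_nar_multilingual")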

+ 10 - 1
src/seamless_communication/models/inference/translator.py

@@ -26,6 +26,7 @@ from seamless_communication.models.unity import (
     UnitTokenizer,
     UnitYGenerator,
     UnitYModel,
+    UnitYT2UModel,
     load_unity_model,
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
@@ -242,6 +243,14 @@ class Translator(nn.Module):
         if output_modality == Modality.TEXT:
             return text_out.sentences[0], None, None
         else:
-            units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
+            if isinstance(self.model.t2u_model, UnitYT2UModel):
+                # Remove the lang token for AR UnitY since the vocoder doesn't need it
+                # in the unit sequence. tgt_lang is fed as an argument to the vocoder.
+                units = unit_out.units[:, 1:]
+            else:
+                units = unit_out.units
+
+            # TODO: batch_size set to 1 for now, implement batching.
+            units = units[0].cpu().numpy().tolist()
             wav_out = self.vocoder(units, tgt_lang, spkr, dur_prediction=True)
             return text_out.sentences[0], wav_out, sample_rate
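
The two branches differ only in whether the leading lang token is dropped; a
toy illustration with a dummy unit sequence (hypothetical values, batch size 1):

    import torch

    unit_seqs = torch.tensor([[10010, 512, 87, 903]])  # index 0 is the lang token

    ar_units = unit_seqs[:, 1:][0].cpu().numpy().tolist()  # AR path -> [512, 87, 903]
    nar_units = unit_seqs[0].cpu().numpy().tolist()        # NAR path keeps all positions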

+ 42 - 10
src/seamless_communication/models/unity/__init__.py

@@ -6,21 +6,28 @@
 
 from seamless_communication.models.unity.builder import UnitYBuilder as UnitYBuilder
 from seamless_communication.models.unity.builder import UnitYConfig as UnitYConfig
-from seamless_communication.models.unity.builder import (
-    UnitYT2UBuilder as UnitYT2UBuilder,
-)
-from seamless_communication.models.unity.builder import UnitYT2UConfig as UnitYT2UConfig
 from seamless_communication.models.unity.builder import (
     create_unity_model as create_unity_model,
 )
-from seamless_communication.models.unity.builder import (
-    create_unity_t2u_model as create_unity_t2u_model,
-)
 from seamless_communication.models.unity.builder import unity_arch as unity_arch
 from seamless_communication.models.unity.builder import unity_archs as unity_archs
-from seamless_communication.models.unity.builder import unity_t2u_arch as unity_t2u_arch
-from seamless_communication.models.unity.builder import (
-    unity_t2u_archs as unity_t2u_archs,
+from seamless_communication.models.unity.char_tokenizer import (
+    CharTokenizer as CharTokenizer,
+)
+from seamless_communication.models.unity.char_tokenizer import (
+    UnitYCharTokenizerLoader as UnitYCharTokenizerLoader,
+)
+from seamless_communication.models.unity.char_tokenizer import (
+    load_unity_char_tokenizer as load_unity_char_tokenizer,
+)
+from seamless_communication.models.unity.length_regulator import (
+    HardUpsampling as HardUpsampling,
+)
+from seamless_communication.models.unity.length_regulator import (
+    VariancePredictor as VariancePredictor,
+)
+from seamless_communication.models.unity.length_regulator import (
+    VarianceAdaptor as VarianceAdaptor,
 )
 from seamless_communication.models.unity.loader import UnitYLoader as UnitYLoader
 from seamless_communication.models.unity.loader import (
@@ -34,7 +41,32 @@ from seamless_communication.models.unity.loader import (
 )
 from seamless_communication.models.unity.model import UnitYModel as UnitYModel
 from seamless_communication.models.unity.model import UnitYX2TModel as UnitYX2TModel
+from seamless_communication.models.unity.model import UnitYT2UModel as UnitYT2UModel
+from seamless_communication.models.unity.model import (
+    UnitYNART2UModel as UnitYNART2UModel,
+)
 from seamless_communication.models.unity.model import UnitYOutput as UnitYOutput
+from seamless_communication.models.unity.nar_decoder_frontend import (
+    NARDecoderFrontend as NARDecoderFrontend,
+)
+from seamless_communication.models.unity.nar_decoder_layer import (
+    NARTransformerDecoderLayer as NARTransformerDecoderLayer,
+)
+from seamless_communication.models.unity.t2u_builder import (
+    UnitYT2UBuilder as UnitYT2UBuilder,
+)
+from seamless_communication.models.unity.t2u_builder import (
+    UnitYT2UConfig as UnitYT2UConfig,
+)
+from seamless_communication.models.unity.t2u_builder import (
+    create_unity_t2u_model as create_unity_t2u_model,
+)
+from seamless_communication.models.unity.t2u_builder import (
+    unity_t2u_arch as unity_t2u_arch,
+)
+from seamless_communication.models.unity.t2u_builder import (
+    unity_t2u_archs as unity_t2u_archs,
+)
 from seamless_communication.models.unity.unit_tokenizer import (
     UnitTokenDecoder as UnitTokenDecoder,
 )
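
After this change the NAR components are importable from the package root
alongside the relocated T2U builder symbols, e.g.:

    from seamless_communication.models.unity import (
        CharTokenizer,
        NARDecoderFrontend,
        UnitYNART2UModel,
        UnitYT2UBuilder,
        VarianceAdaptor,
    )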

+ 2 - 2
src/seamless_communication/models/unity/adaptor_block.py

@@ -1,4 +1,4 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright (c) Meta Platforms, Inc. and affiliates
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
@@ -365,7 +365,7 @@ class UnitYConformerAdaptorLayer(TransformerEncoderLayer):
         self.stride = stride
 
         if layer_norm:
-            self.layer_norm = layer_norm_fn(self.model_dim, device, dtype)
+            self.layer_norm = layer_norm_fn(self.model_dim, device=device, dtype=dtype)
         else:
             self.register_module("layer_norm", None)
 

+ 67 - 304
src/seamless_communication/models/unity/builder.py

@@ -1,4 +1,4 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright (c) Meta Platforms, Inc. and affiliates
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
@@ -7,44 +7,36 @@
 from dataclasses import dataclass
 from typing import Optional
 
-from fairseq2.data import VocabularyInfo
 from fairseq2.models.conformer import ConformerBlock, ConformerConvolution
 from fairseq2.models.nllb import NllbBuilder, NllbConfig, nllb_archs
-from fairseq2.models.transformer import (
-    TransformerEmbeddingFrontend,
-    TransformerFrontend,
-)
-from seamless_communication.models.unity.adaptor_block import (
-    UnitYConformerAdaptorLayer,
-    UnitYEncoderAdaptor,
-    UnitYTransformerAdaptorLayer,
-)
-from seamless_communication.models.unity.model import UnitYModel, UnitYT2UModel
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.models.w2vbert import w2vbert_archs
 from fairseq2.models.wav2vec2 import Wav2Vec2EncoderBuilder, Wav2Vec2EncoderConfig
-from fairseq2.nn.embedding import Embedding
-from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
 from fairseq2.nn.projection import TiedProjection
 from fairseq2.nn.transformer import (
-    FeedForwardNetwork,
     MultiheadAttention,
     StandardFeedForwardNetwork,
     StandardMultiheadAttention,
-    StandardTransformerDecoder,
-    StandardTransformerDecoderLayer,
-    StandardTransformerEncoder,
-    StandardTransformerEncoderLayer,
-    TransformerDecoder,
-    TransformerDecoderLayer,
     TransformerEncoder,
     TransformerEncoderLayer,
-    TransformerNormOrder,
     create_default_sdpa,
 )
 from fairseq2.typing import DataType, Device
 
 
+from seamless_communication.models.unity.adaptor_block import (
+    UnitYConformerAdaptorLayer,
+    UnitYEncoderAdaptor,
+    UnitYTransformerAdaptorLayer,
+)
+from seamless_communication.models.unity.model import UnitYModel
+from seamless_communication.models.unity.t2u_builder import (
+    UnitYT2UBuilder,
+    UnitYT2UConfig,
+    unity_t2u_archs,
+)
+
+
 @dataclass
 class UnitYConfig:
     """Holds the configuration of a UnitY model as described in
@@ -56,14 +48,14 @@ class UnitYConfig:
     w2v2_encoder_config: Wav2Vec2EncoderConfig
     """The configuration of the underlying wav2vec 2.0 encoder."""
 
-    nllb_config: NllbConfig
-    """The configuration of the underlying NLLB text encoder-decoder."""
+    mt_model_config: NllbConfig
+    """The configuration of the underlying MT text encoder-decoder."""
 
-    t2u_config: Optional["UnitYT2UConfig"]
+    t2u_config: Optional[UnitYT2UConfig]
     """The configuration of the UnitY T2U sub-model."""
 
     use_text_encoder: bool
-    """If ``True``, uses an aligned NLLB encoder for the MT task."""
+    """If ``True``, uses an aligned MT encoder for the MT task."""
 
     use_conformer_adaptor: bool
     """If ``True``, uses a Conformer-based adaptor block."""
@@ -95,16 +87,16 @@ unity_arch = unity_archs.marker
 def _base() -> UnitYConfig:
     w2vbert_config = w2vbert_archs.get_config("600m")
 
-    nllb_config = nllb_archs.get_config("dense_1b")
+    mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
 
-    nllb_config.vocabulary_size = 256102  # NLLB-100
+    mt_model_config.vocabulary_size = 256102  # NLLB-100
 
     t2u_config = unity_t2u_archs.get_config("base")
 
     return UnitYConfig(
         model_dim=1024,
         w2v2_encoder_config=w2vbert_config.w2v2_config.encoder_config,
-        nllb_config=nllb_config,
+        mt_model_config=mt_model_config,
         t2u_config=t2u_config,
         use_text_encoder=True,
         use_conformer_adaptor=False,
@@ -120,16 +112,16 @@ def _base() -> UnitYConfig:
 def _medium() -> UnitYConfig:
     w2vbert_config = w2vbert_archs.get_config("300m")
 
-    nllb_config = nllb_archs.get_config("dense_600m")
+    mt_model_config: NllbConfig = nllb_archs.get_config("dense_600m")
 
-    nllb_config.vocabulary_size = 256206  # NLLB-200
+    mt_model_config.vocabulary_size = 256206  # NLLB-200
 
     t2u_config = unity_t2u_archs.get_config("medium")
 
     return UnitYConfig(
         model_dim=1024,
         w2v2_encoder_config=w2vbert_config.w2v2_config.encoder_config,
-        nllb_config=nllb_config,
+        mt_model_config=mt_model_config,
         t2u_config=t2u_config,
         use_text_encoder=True,
         use_conformer_adaptor=False,
@@ -141,6 +133,35 @@ def _medium() -> UnitYConfig:
     )
 
 
+@unity_arch("nar_multilingual")
+def _nar_multilingual() -> UnitYConfig:
+    w2vbert_config = w2vbert_archs.get_config("600m")
+    w2v2_encoder_config = w2vbert_config.w2v2_config.encoder_config
+    w2v2_encoder_config.pos_encoder_depth = 1
+    w2v2_encoder_config.pos_conv_kernel_size = 128
+    w2v2_encoder_config.num_pos_conv_groups = 16
+
+    mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
+
+    mt_model_config.vocabulary_size = 256102  # NLLB-100
+
+    t2u_config = unity_t2u_archs.get_config("nar_multilingual")
+
+    return UnitYConfig(
+        model_dim=1024,
+        w2v2_encoder_config=w2v2_encoder_config,
+        mt_model_config=mt_model_config,
+        t2u_config=t2u_config,
+        use_text_encoder=False,
+        use_conformer_adaptor=False,
+        num_adaptor_layers=1,
+        adaptor_kernel_size=8,
+        adaptor_stride=8,
+        adaptor_layer_norm=True,
+        adaptor_dropout_p=0.1,
+    )
+
+
 class UnitYBuilder:
     """Builds modules of a UnitY model.
 
@@ -150,7 +171,7 @@ class UnitYBuilder:
 
     config: UnitYConfig
     w2v2_encoder_builder: Wav2Vec2EncoderBuilder
-    nllb_builder: NllbBuilder
+    mt_model_builder: NllbBuilder
     t2u_builder: Optional["UnitYT2UBuilder"]
     device: Optional[Device]
     dtype: Optional[DataType]
@@ -159,7 +180,7 @@ class UnitYBuilder:
         self,
         config: UnitYConfig,
         w2v2_encoder_builder: Wav2Vec2EncoderBuilder,
-        nllb_builder: NllbBuilder,
+        mt_model_builder: NllbBuilder,
         t2u_builder: Optional["UnitYT2UBuilder"],
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
@@ -169,8 +190,8 @@ class UnitYBuilder:
             The configuration to use.
         :param w2v2_encoder_builder:
             The wav2vec 2.0 encoder builder.
-        :param nllb_builder:
-            The NLLB model builder.
+        :param mt_model_builder:
+            The MT model builder.
         :param t2u_builder:
             The UnitY T2U model builder.
         :param device:
@@ -183,9 +204,9 @@ class UnitYBuilder:
                 f"`model_dim` and `model_dim` of `w2v2_encoder_builder.config` must be equal, but are {config.model_dim} and {w2v2_encoder_builder.config.model_dim} instead."
             )
 
-        if nllb_builder.config.model_dim != config.model_dim:
+        if mt_model_builder.config.model_dim != config.model_dim:
             raise ValueError(
-                f"`model_dim` and `model_dim` of `nllb_builder.config` must be equal, but are {config.model_dim} and {nllb_builder.config.model_dim} instead."
+                f"`model_dim` and `model_dim` of `mt_model_builder.config` must be equal, but are {config.model_dim} and {mt_model_builder.config.model_dim} instead."
             )
 
         if t2u_builder is not None and t2u_builder.config.model_dim != config.model_dim:
@@ -195,25 +216,25 @@ class UnitYBuilder:
 
         self.config = config
         self.w2v2_encoder_builder = w2v2_encoder_builder
-        self.nllb_builder = nllb_builder
+        self.mt_model_builder = mt_model_builder
         self.t2u_builder = t2u_builder
         self.device = device
         self.dtype = dtype
 
     def build_model(self) -> UnitYModel:
         """Build a model."""
-        text_embed = self.nllb_builder.build_embedding()
+        text_embed = self.mt_model_builder.build_embedding()
 
         speech_encoder_frontend = self.w2v2_encoder_builder.build_frontend()
         speech_encoder = self.build_speech_encoder()
 
-        text_decoder_frontend = self.nllb_builder.build_frontend(text_embed)
-        text_decoder = self.nllb_builder.build_decoder()
+        text_decoder_frontend = self.mt_model_builder.build_frontend(text_embed)
+        text_decoder = self.mt_model_builder.build_decoder()
 
         if self.config.use_text_encoder:
             # We use shared embedding as in NLLB.
             text_encoder_frontend = text_decoder_frontend
-            text_encoder = self.nllb_builder.build_encoder()
+            text_encoder = self.mt_model_builder.build_encoder()
         else:
             text_encoder_frontend = None
             text_encoder = None
@@ -234,7 +255,7 @@ class UnitYBuilder:
             text_decoder,
             final_proj,
             t2u_model,
-            self.config.nllb_config.pad_idx,
+            self.config.mt_model_config.pad_idx,
         )
 
     def build_speech_encoder(self) -> TransformerEncoder:
@@ -357,277 +378,19 @@ def create_unity_model(
         config.w2v2_encoder_config, device=device, dtype=dtype
     )
 
-    nllb_builder = NllbBuilder(config.nllb_config, device=device, dtype=dtype)
-
     if config.t2u_config is None:
         t2u_builder = None
     else:
         t2u_builder = UnitYT2UBuilder(config.t2u_config, device=device, dtype=dtype)
 
+    mt_model_builder = NllbBuilder(config.mt_model_config, device=device, dtype=dtype)
     unity_builder = UnitYBuilder(
         config,
         w2v2_encoder_builder,
-        nllb_builder,
+        mt_model_builder,
         t2u_builder,
         device=device,
         dtype=dtype,
     )
 
     return unity_builder.build_model()
-
-
-@dataclass
-class UnitYT2UConfig:
-    """Holds the configuration of a UnitY T2U model as described in
-    :cite:t`https://doi.org/10.48550/arxiv.2212.08055`"""
-
-    model_dim: int
-    """The dimensionality of the model."""
-
-    unit_max_seq_len: int
-    """The expected maximum unit sequence length."""
-
-    unit_vocabulary_size: int
-    """The size of the unit vocabulary."""
-
-    unit_pad_idx: Optional[int]
-    """The index of the pad symbol in the unit vocabulary."""
-
-    num_encoder_layers: int
-    """The number of Transformer encoder layers."""
-
-    num_decoder_layers: int
-    """The number of Transformer decoder layers."""
-
-    num_encoder_attn_heads: int
-    """The number of attention heads in Transformer encoder layers."""
-
-    num_decoder_attn_heads: int
-    """The number of attention heads in Transformer decoder layers."""
-
-    ffn_inner_dim: int
-    """The inner dimensionality of Transformer feed-forward networks."""
-
-    dropout_p: float
-    """The dropout probability in Transformer layers."""
-
-    def update_unit_vocabulary(self, info: VocabularyInfo) -> None:
-        """Update unit vocabulary configuration from ``info``."""
-        self.unit_vocabulary_size, self.unit_pad_idx = info.size, info.pad_idx
-
-
-unity_t2u_archs = ArchitectureRegistry[UnitYT2UConfig]("unity_t2u")
-
-
-unity_t2u_arch = unity_t2u_archs.marker
-
-
-@unity_t2u_arch("base")
-def _base_t2u() -> UnitYT2UConfig:
-    return UnitYT2UConfig(
-        model_dim=1024,
-        unit_max_seq_len=2048,
-        unit_vocabulary_size=10082,
-        unit_pad_idx=1,
-        num_encoder_layers=6,
-        num_decoder_layers=6,
-        num_encoder_attn_heads=16,
-        num_decoder_attn_heads=16,
-        ffn_inner_dim=1024 * 8,
-        dropout_p=0.1,
-    )
-
-
-@unity_t2u_arch("medium")
-def _medium_t2u() -> UnitYT2UConfig:
-    return UnitYT2UConfig(
-        model_dim=1024,
-        unit_max_seq_len=2048,
-        unit_vocabulary_size=10082,
-        unit_pad_idx=1,
-        num_encoder_layers=4,
-        num_decoder_layers=4,
-        num_encoder_attn_heads=16,
-        num_decoder_attn_heads=16,
-        ffn_inner_dim=1024 * 8,
-        dropout_p=0.1,
-    )
-
-
-class UnitYT2UBuilder:
-    """Builds modules of a UnitY T2U model.
-
-    To tweak the architecture, you can derive from this class and override the
-    corresponding methods.
-    """
-
-    config: UnitYT2UConfig
-    device: Optional[Device]
-    dtype: Optional[DataType]
-
-    def __init__(
-        self,
-        config: UnitYT2UConfig,
-        device: Optional[Device] = None,
-        dtype: Optional[DataType] = None,
-    ) -> None:
-        """
-        :param config:
-            The configuration to use.
-        :param device:
-            The device on which to initialize modules.
-        :param dtype:
-            The data type of module parameters and buffers.
-        """
-        self.config = config
-        self.device = device
-        self.dtype = dtype
-
-    def build_model(self) -> UnitYT2UModel:
-        """Build a model."""
-        embed = self.build_embedding()
-
-        encoder = self.build_encoder()
-
-        decoder_frontend = self.build_decoder_frontend(embed)
-        decoder = self.build_decoder()
-
-        final_proj = TiedProjection(embed.weight, bias=None)
-
-        return UnitYT2UModel(
-            encoder,
-            decoder_frontend,
-            decoder,
-            final_proj,
-            self.config.unit_pad_idx,
-        )
-
-    def build_embedding(self) -> Embedding:
-        """Build a unit embedding table."""
-        return Embedding(
-            num_embeddings=self.config.unit_vocabulary_size,
-            embedding_dim=self.config.model_dim,
-            pad_idx=self.config.unit_pad_idx,
-            scaled=True,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-    def build_encoder(self) -> Optional[TransformerEncoder]:
-        """Build a Transformer encoder."""
-        num_layers = self.config.num_encoder_layers
-        if num_layers == 0:
-            return None
-
-        layers = [self.build_encoder_layer() for _ in range(num_layers)]
-
-        return StandardTransformerEncoder(
-            layers,
-            norm_order=TransformerNormOrder.PRE,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-    def build_encoder_layer(self) -> TransformerEncoderLayer:
-        """Build a Transformer encoder layer."""
-        self_attn = self.build_attention(self.config.num_encoder_attn_heads)
-
-        ffn = self.build_ffn()
-
-        return StandardTransformerEncoderLayer(
-            self_attn,
-            ffn,
-            dropout_p=self.config.dropout_p,
-            norm_order=TransformerNormOrder.PRE,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-    def build_decoder_frontend(self, embed: Embedding) -> TransformerFrontend:
-        """Build a Transformer decoder front-end."""
-        pos_encoder = SinusoidalPositionEncoder(
-            self.config.model_dim,
-            self.config.unit_max_seq_len,
-            _legacy_pad_idx=self.config.unit_pad_idx,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-        return TransformerEmbeddingFrontend(
-            embed,
-            pos_encoder,
-            dropout_p=self.config.dropout_p,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-    def build_decoder(self) -> TransformerDecoder:
-        """Build a Transformer decoder."""
-        num_layers = self.config.num_decoder_layers
-
-        layers = [self.build_decoder_layer() for _ in range(num_layers)]
-
-        return StandardTransformerDecoder(
-            layers,
-            norm_order=TransformerNormOrder.PRE,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-    def build_decoder_layer(self) -> TransformerDecoderLayer:
-        """Build a Transformer decoder layer."""
-        self_attn = self.build_attention(self.config.num_decoder_attn_heads)
-
-        encoder_decoder_attn = self.build_attention(self.config.num_decoder_attn_heads)
-
-        ffn = self.build_ffn()
-
-        return StandardTransformerDecoderLayer(
-            self_attn,
-            encoder_decoder_attn,
-            ffn,
-            dropout_p=self.config.dropout_p,
-            norm_order=TransformerNormOrder.PRE,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-    def build_attention(self, num_heads: int) -> MultiheadAttention:
-        """Build a Transformer multi-head attention layer."""
-        sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
-
-        return StandardMultiheadAttention(
-            self.config.model_dim,
-            num_heads,
-            sdpa=sdpa,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-    def build_ffn(self) -> FeedForwardNetwork:
-        """Build a Transformer feed-forward network."""
-        return StandardFeedForwardNetwork(
-            self.config.model_dim,
-            self.config.ffn_inner_dim,
-            bias=True,
-            norm_order=TransformerNormOrder.PRE,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-
-def create_unity_t2u_model(
-    config: UnitYT2UConfig,
-    device: Optional[Device] = None,
-    dtype: Optional[DataType] = None,
-) -> UnitYT2UModel:
-    """Create a UnitY T2U model.
-
-    :param config:
-        The configuration to use.
-    :param device:
-        The device on which to initialize modules.
-    :param dtype:
-        The data type of module parameters and buffers.
-    """
-    return UnitYT2UBuilder(config, device, dtype).build_model()
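
With the T2U code moved to t2u_builder.py, the new "nar_multilingual"
architecture goes through the same registry path as the existing ones; a
minimal sketch:

    from seamless_communication.models.unity import create_unity_model, unity_archs

    config = unity_archs.get_config("nar_multilingual")
    # Its t2u_config is the "nar_multilingual" T2U arch, so the resulting
    # UnitYModel carries a UnitYNART2UModel as its t2u_model.
    model = create_unity_model(config)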

+ 104 - 0
src/seamless_communication/models/unity/char_tokenizer.py

@@ -0,0 +1,104 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Union, final
+
+from fairseq2.assets import AssetStore, AssetDownloadManager, download_manager
+from fairseq2.assets.card import AssetCard
+from fairseq2.data.text import (
+    SentencePieceDecoder,
+    SentencePieceEncoder,
+    SentencePieceModel,
+    TextTokenDecoder,
+    TextTokenEncoder,
+    TextTokenizer,
+    vocabulary_from_sentencepiece,
+)
+from fairseq2.data.typing import PathLike
+from fairseq2.typing import Device, finaloverride
+
+from seamless_communication.assets import asset_store
+
+
+@final
+class CharTokenizer(TextTokenizer):
+    """A character-level tokenizer used during non-autoregressive T2U decoding."""
+
+    model: SentencePieceModel
+
+    def __init__(self, pathname: PathLike) -> None:
+        """
+        :param pathname:
+            The pathname of the SentencePiece model file.
+        """
+        self.model = SentencePieceModel(pathname)
+
+        vocab_info = vocabulary_from_sentencepiece(self.model)
+
+        super().__init__(vocab_info)
+
+    @finaloverride
+    def create_encoder(
+        self,
+        task: Optional[str] = None,
+        lang: Optional[str] = None,
+        mode: Optional[str] = None,
+        device: Optional[Device] = None,
+        pin_memory: bool = False,
+    ) -> TextTokenEncoder:
+        """Creates a character level encoder."""
+        return SentencePieceEncoder(
+            self.model,
+            device=device,
+            pin_memory=pin_memory,
+        )
+
+    @finaloverride
+    def create_decoder(self) -> TextTokenDecoder:
+        return SentencePieceDecoder(self.model)
+
+
+class UnitYCharTokenizerLoader:
+    """Loads character-level tokenizers of UnitY models."""
+
+    def __init__(
+        self, asset_store: AssetStore, download_manager: AssetDownloadManager
+    ) -> None:
+        """
+        :param asset_store:
+            The asset store to retrieve the model information.
+        :param download_manager:
+            The download manager to use.
+        """
+        self.asset_store = asset_store
+        self.download_manager = download_manager
+
+    def __call__(
+        self,
+        model_name_or_card: Union[str, AssetCard],
+        force: bool = False,
+        progress: bool = True,
+    ) -> CharTokenizer:
+        """
+        :param model_name_or_card:
+            The name of the model or an already loaded AssetCard
+        """
+
+        if isinstance(model_name_or_card, AssetCard):
+            card = model_name_or_card
+        else:
+            card = self.asset_store.retrieve_card(model_name_or_card)
+
+        uri = card.field("char_tokenizer").as_uri()
+
+        pathname = self.download_manager.download_tokenizer(
+            uri, card.name, force=force, progress=progress
+        )
+
+        return CharTokenizer(pathname)
+
+
+load_unity_char_tokenizer = UnitYCharTokenizerLoader(asset_store, download_manager)
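
A short usage sketch (assuming the card's SPM model is character-level, as the
spm_char_lang38_tc filename suggests):

    from seamless_communication.models.unity import load_unity_char_tokenizer

    tokenizer = load_unity_char_tokenizer("unity_nar_multilingual")
    encoder = tokenizer.create_encoder()
    char_indices = encoder("hello")  # roughly one index per character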

+ 49 - 41
src/seamless_communication/models/unity/generator.py

@@ -1,4 +1,4 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright (c) Meta Platforms, Inc. and affiliates
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
@@ -16,7 +16,11 @@ from fairseq2.generation import (
     SequenceToTextGenerator,
     SequenceToTextOutput,
 )
-from seamless_communication.models.unity.model import UnitYModel, UnitYX2TModel
+from seamless_communication.models.unity.model import (
+    UnitYModel,
+    UnitYX2TModel,
+    UnitYT2UModel,
+)
 from seamless_communication.models.unity.unit_tokenizer import (
     UnitTokenDecoder,
     UnitTokenizer,
@@ -127,19 +131,20 @@ class UnitYGenerator:
                 lang=target_lang, device=infer_device(model.t2u_model)
             )
 
-            if unit_opts is None:
-                # Speech sequences are typically much longer than text sequences.
-                unit_opts = SequenceGeneratorOptions(
-                    soft_max_seq_len=(1, 50), hard_max_seq_len=5000
+            if isinstance(self.model.t2u_model, UnitYT2UModel):
+                if unit_opts is None:
+                    # Speech sequences are typically much longer than text sequences.
+                    unit_opts = SequenceGeneratorOptions(
+                        soft_max_seq_len=(1, 50), hard_max_seq_len=5000
+                    )
+
+                self.unit_generator = Seq2SeqGenerator(
+                    self.model.t2u_model,
+                    unit_tokenizer.vocab_info,
+                    unit_encoder.prefix_indices,
+                    unit_opts,
                 )
 
-            self.unit_generator = Seq2SeqGenerator(
-                model.t2u_model,
-                unit_tokenizer.vocab_info,
-                unit_encoder.prefix_indices,
-                unit_opts,
-            )
-
     @torch.inference_mode()
     def __call__(
         self,
@@ -186,6 +191,10 @@ class UnitYGenerator:
 
         text_seqs, text_seq_lens = text_output.generator_output.collate()
 
+        # Manually trim the final EOS token to be consistent with fairseq.
+        if text_seq_lens is not None:
+            text_seq_lens -= 1
+
         # Use the output of the text generator to compute the decoder output.
         decoder_output, decoder_padding_mask = self.model.decode(
             text_seqs,
@@ -195,31 +204,41 @@ class UnitYGenerator:
         )
 
         assert self.model.t2u_model is not None
-
-        t2u_encoder_output, t2u_encoder_padding_mask = self.model.t2u_model.encode(
-            decoder_output, decoder_padding_mask
-        )
-
-        assert self.unit_generator is not None
         assert self.unit_decoder is not None
 
-        unit_gen_output = self.unit_generator(
-            t2u_encoder_output,
-            t2u_encoder_padding_mask,
-            source_seq_len=source_seqs.size(1),
-        )
-
-        unit_seqs, _ = unit_gen_output.collate()
+        unit_gen_output = None
+        if isinstance(self.model.t2u_model, UnitYT2UModel):
+            assert self.unit_generator is not None
+            t2u_encoder_output, t2u_encoder_padding_mask = self.model.t2u_model.encode(
+                decoder_output, decoder_padding_mask
+            )
+            unit_gen_output = self.unit_generator(
+                t2u_encoder_output,
+                t2u_encoder_padding_mask,
+                source_seq_len=source_seqs.size(1),
+            )
+            unit_seqs, _ = unit_gen_output.collate()
+        else:
+            unit_decoder_output, decoder_padding_mask = self.model.t2u_model(
+                text_decoder_output=decoder_output,
+                text_decoder_padding_mask=decoder_padding_mask,
+                target_seqs=None,
+                target_seq_lens=None,
+                text_seqs=text_seqs,
+            )
+            # (B, S_unit, V_unit)
+            unit_seqs = unit_decoder_output.logits.argmax(dim=2)
+            # Apply the padding mask to the generated units.
+            unit_seqs[decoder_padding_mask == -torch.inf] = unit_decoder_output.pad_idx
 
         # Convert to speech units.
         units = self.unit_decoder(unit_seqs)
+
         if ngram_filtering:
             units = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist())
             units = torch.tensor(units)
 
-        unit_output = SequenceToUnitOutput(
-            units, unit_gen_output, t2u_encoder_output, t2u_encoder_padding_mask
-        )
+        unit_output = SequenceToUnitOutput(units, unit_gen_output)
 
         return text_output, unit_output
 
@@ -229,16 +248,5 @@ class SequenceToUnitOutput:
     units: Tensor
     """The generated units."""
 
-    generator_output: SequenceGeneratorOutput
+    generator_output: Optional[SequenceGeneratorOutput]
     """The output of the underlying :class:`Seq2SeqGenerator`."""
-
-    t2u_encoder_output: Tensor
-    """The encoder output of the underlying UnitY T2U model used to generate the
-    units. *Shape:* :math:`(N,S_{enc},M)`, where :math:`N` is the batch size,
-    :math:`S_{enc}` is the encoder output sequence length, and :math:`M` is the
-    dimensionality of the model."""
-
-    t2u_encoder_padding_mask: Optional[Tensor]
-    """The float padding mask of :attr:`encoder_output`. *Shape:*
-    :math:`(N,S_{enc})`, where :math:`N` is the batch size and :math:`S_{enc}`
-    is the encoder output sequence length."""
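
The NAR branch above picks units greedily from the decoder logits and then
masks padded positions; the core of that step with dummy shapes (pad_idx 1
matches the T2U configs):

    import torch

    logits = torch.randn(2, 6, 10000)        # (B, S_unit, V_unit)
    padding_mask = torch.zeros(2, 6)
    padding_mask[1, 4:] = -torch.inf         # float mask, -inf marks padding

    unit_seqs = logits.argmax(dim=2)         # greedy unit per position
    unit_seqs[padding_mask == -torch.inf] = 1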

+ 194 - 0
src/seamless_communication/models/unity/length_regulator.py

@@ -0,0 +1,194 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from torch import Tensor
+from torch.nn import Conv1d, Dropout, Module, ReLU, Sequential
+
+from typing import Optional, Tuple
+
+from fairseq2.typing import DataType, Device
+from fairseq2.nn.transformer import create_default_layer_norm
+from fairseq2.nn.normalization import LayerNorm
+from fairseq2.nn.projection import Linear
+from fairseq2.nn.utils.mask import apply_padding_mask
+
+
+class HardUpsampling(Module):
+    """Upsamples sequences in a deterministic way as governed by durations."""
+
+    def forward(self, seqs: Tensor, durations: Tensor) -> Tuple[Tensor, Tensor]:
+        # seqs: (N, S, M), durations: (N, S)
+        if durations.dtype not in (torch.int16, torch.int32, torch.int64):
+            raise TypeError("The durations tensor should have an integer dtype.")
+
+        upsampled_seq_lens = durations.sum(dim=1)
+        max_len = int(upsampled_seq_lens.max().item())
+        N, _, M = seqs.shape
+        upsampled_seqs = seqs.new_zeros((N, max_len, M))
+
+        for b in range(N):
+            upsampled_seqs[b, : upsampled_seq_lens[b]] = seqs[b].repeat_interleave(
+                durations[b], dim=0
+            )
+
+        return upsampled_seqs, upsampled_seq_lens
+
+
+class VariancePredictor(Module):
+    """Represents the duration/pitch/energy predictor as described in
+    :cite:t:`https://arxiv.org/pdf/2006.04558.pdf`"""
+
+    conv1: Sequential
+    ln1: LayerNorm
+    dropout_module: Dropout
+    conv2: Sequential
+    ln2: LayerNorm
+    proj: Linear
+
+    def __init__(
+        self,
+        encoder_embed_dim: int,
+        var_pred_hidden_dim: int,
+        var_pred_kernel_size: int,
+        var_pred_dropout: float,
+        bias: bool = True,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+
+        self.conv1 = Sequential(
+            Conv1d(
+                encoder_embed_dim,
+                var_pred_hidden_dim,
+                var_pred_kernel_size,
+                stride=1,
+                padding=(var_pred_kernel_size - 1) // 2,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            ),
+            ReLU(),
+        )
+
+        layer_norm_fn = create_default_layer_norm
+
+        self.ln1 = layer_norm_fn(var_pred_hidden_dim, device=device, dtype=dtype)
+
+        self.dropout_module = Dropout(p=var_pred_dropout)
+
+        self.conv2 = Sequential(
+            Conv1d(
+                var_pred_hidden_dim,
+                var_pred_hidden_dim,
+                var_pred_kernel_size,
+                stride=1,
+                padding=1,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            ),
+            ReLU(),
+        )
+
+        self.ln2 = layer_norm_fn(var_pred_hidden_dim, device=device, dtype=dtype)
+
+        self.proj = Linear(
+            var_pred_hidden_dim, 1, bias=True, device=device, dtype=dtype
+        )
+
+    def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
+        # Ensure that we do not leak padded positions in the convolution layer.
+        seqs = apply_padding_mask(seqs, padding_mask)
+
+        # (N, S, M) -> (N, M, S)
+        seqs = seqs.transpose(1, 2)
+
+        # (N, M, S) -> (N, H, S)
+        seqs = self.conv1(seqs)
+
+        # (N, H, S) -> (N, S, H)
+        seqs = seqs.transpose(1, 2)
+
+        seqs = self.ln1(seqs)
+
+        seqs = self.dropout_module(seqs)
+
+        seqs = apply_padding_mask(seqs, padding_mask)
+
+        # (N, S, H) -> (N, H, S)
+        seqs = seqs.transpose(1, 2)
+
+        # (N, H, S) -> (N, H, S)
+        seqs = self.conv2(seqs)
+
+        # (N, H, S) -> (N, S, H)
+        seqs = seqs.transpose(1, 2)
+
+        seqs = self.ln2(seqs)
+
+        seqs = self.dropout_module(seqs)
+
+        # (N, S, H) -> (N, S, 1) -> (N, S)
+        seqs = self.proj(seqs).squeeze(dim=2)
+
+        return seqs
+
+
+class VarianceAdaptor(Module):
+    """Represent the Variance adaptor as described in
+    :cite:t:`https://arxiv.org/pdf/2006.04558.pdf`"""
+
+    duration_predictor: VariancePredictor
+    pitch_predictor: Optional[VariancePredictor]
+    energy_predictor: Optional[VariancePredictor]
+    hard_upsampling: HardUpsampling
+
+    def __init__(
+        self,
+        duration_predictor: VariancePredictor,
+        pitch_predictor: Optional[VariancePredictor] = None,
+        energy_predictor: Optional[VariancePredictor] = None,
+    ):
+        super().__init__()
+
+        self.duration_predictor = duration_predictor
+
+        if pitch_predictor:
+            self.pitch_predictor = pitch_predictor
+        else:
+            self.register_module("pitch_predictor", None)
+
+        if energy_predictor:
+            self.energy_predictor = energy_predictor
+        else:
+            self.register_module("energy_predictor", None)
+
+        self.hard_upsampling = HardUpsampling()
+
+    def forward(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[Tensor],
+        duration_factor: float = 1.0,
+        min_duration: int = 0,
+    ) -> Tuple[Tensor, Tensor]:
+        log_durations = self.duration_predictor(seqs, padding_mask)
+
+        durations = torch.clamp(
+            torch.round((torch.exp(log_durations) - 1) * duration_factor).long(),
+            min=min_duration,
+        )
+
+        # We need to apply the padding_mask again since we clamp by min_duration.
+        durations = apply_padding_mask(durations, padding_mask)
+
+        # TODO: Implement pitch, energy predictors.
+        # TODO: Implement GaussianUpsampling.
+        seqs, seq_lens = self.hard_upsampling(seqs, durations)
+
+        return seqs, seq_lens
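
The round-trip is easy to check by hand: the predictor outputs log-durations,
exp(log_durations) - 1 recovers per-position durations, and HardUpsampling
then repeats each position that many times:

    import torch
    from seamless_communication.models.unity.length_regulator import HardUpsampling

    upsample = HardUpsampling()

    seqs = torch.arange(6.0).view(1, 3, 2)  # (N=1, S=3, M=2)
    durations = torch.tensor([[2, 0, 3]])   # integer durations per position

    upsampled, seq_lens = upsample(seqs, durations)
    assert seq_lens.item() == 5             # 2 + 0 + 3
    assert upsampled.shape == (1, 5, 2)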

+ 51 - 5
src/seamless_communication/models/unity/loader.py

@@ -1,4 +1,4 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright (c) Meta Platforms, Inc. and affiliates
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
@@ -6,9 +6,11 @@
 
 from typing import Any, Dict, Mapping, Union, final
 
+import numpy as np
 import torch
 from fairseq2.assets import AssetStore, download_manager
 from fairseq2.assets.card import AssetCard
+from fairseq2.models.nllb import NllbConfig
 from fairseq2.models.nllb.loader import NllbTokenizerLoader
 from seamless_communication.models.unity.builder import (
     UnitYConfig,
@@ -57,11 +59,28 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
         del state_dict["encoder.w2v_encoder.w2v_model.mask_emb"]
 
+        # Delete AlignmentEncoder keys for inference.
+        alignment_encoder_keys = [
+            key for key in state_dict if key.startswith("decoder.alignment_encoder.")
+        ]
+        for key in alignment_encoder_keys:
+            del state_dict[key]
+
+        # Delete character-level projection for inference.
+        for key in [
+            "decoder_target_letter_decoder.proj.weight",
+            "decoder_target_letter_decoder.proj.bias",
+        ]:
+            if key in state_dict:
+                del state_dict[key]
+
         embeds = state_dict["final_proj.weight"]
 
         # fairseq had a bug that accidentally introduced a dummy token in the
         # embedding table of NLLB-100. We just discard it.
-        if embeds.size(0) == 256103:  # means NLLB-100
+        if (
+            isinstance(config.mt_model_config, NllbConfig) and embeds.size(0) == 256103
+        ):  # means NLLB-100
             embeds = embeds[:-1]
 
             state_dict["final_proj.weight"] = embeds
@@ -73,6 +92,15 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         if config.use_text_encoder:
             state_dict["text_encoder_frontend.embed.weight"] = embeds
 
+        # TODO: Remove this hack once we get the correct char SPM .model file.
+        char_embeds = state_dict.get(
+            "t2u_model.decoder_frontend.embed_char.weight", None
+        )
+        if char_embeds is not None:
+            vocab_size = char_embeds.shape[0]
+            index_mapping = np.load("/checkpoint/krs/unity2/char_dict_mapping.npy")
+            char_embeds[torch.arange(vocab_size)] = char_embeds[index_mapping]
+
         # The embedding positions of the control symbols in fairseq's dict do
         # not match the SentencePiece model of the tokenizer.
         with torch.inference_mode():
@@ -84,7 +112,8 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             # use a single embedding table in fairseq2.
             embeds = state_dict["t2u_model.final_proj.weight"]
 
-            state_dict["t2u_model.decoder_frontend.embed.weight"] = embeds
+            if "t2u_model.decoder_frontend.embed.weight" in state_dict:
+                state_dict["t2u_model.decoder_frontend.embed.weight"] = embeds
 
         return checkpoint
 
@@ -97,6 +126,10 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.pos_conv\.0\.":                                    r"speech_encoder_frontend.pos_encoder.conv.",
             r"^encoder\.w2v_encoder\.w2v_model\.layer_norm\.":                                              r"speech_encoder_frontend.post_extract_layer_norm.",
             r"^encoder\.w2v_encoder\.w2v_model\.post_extract_proj\.":                                       r"speech_encoder_frontend.model_dim_proj.",
+            r"^encoder\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.0\.":             r"speech_encoder_frontend.feature_extractor.layers.\1.conv.",
+            r"^encoder\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.":          r"speech_encoder_frontend.feature_extractor.layers.\1.layer_norm.",
+            r"^encoder\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.":                    r"speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
+
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.inner.layers.\1.conv.batch_norm.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.inner.layers.\1.conv.depthwise_conv.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.inner.layers.\1.conv_layer_norm.",
@@ -157,17 +190,28 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             r"^synthesizer_encoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.encoder.layers.\1.ffn_layer_norm.",
             r"^synthesizer_encoder\.layer_norm\.":                                r"t2u_model.encoder.layer_norm.",
 
+            # T2U Decoder frontend
+            r"^decoder\.embed_tokens_text\.":                           r"t2u_model.decoder_frontend.embed_char.",
+            r"^decoder\.embed_tokens_unit\.":                           r"t2u_model.decoder_frontend.embed.",
+            r"^decoder\.embed_tokens\.":                                r"t2u_model.decoder_frontend.embed.",
+            r"^decoder\.var_adaptor\.duration_predictor\.":             r"t2u_model.decoder_frontend.variance_adaptor.duration_predictor.",
+            r"^decoder\.dec_pos_emb_alpha":                             r"t2u_model.decoder_frontend.pos_emb_alpha",
+            r"^decoder\.dec_pos_emb_alpha_char":                        r"t2u_model.decoder_frontend.pos_emb_alpha_char",
+
             # T2U Decoder
-            r"^decoder\.embed_tokens\.":                              r"t2u_model.decoder_frontend.embed.",
             r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.decoder.layers.\1.self_attn.output_proj.",
             r"^decoder\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.decoder.layers.\1.self_attn.",
             r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
+            r"^decoder\.layers\.([0-9]+)\.layer_norm\.":              r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
             r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"t2u_model.decoder.layers.\1.encoder_decoder_attn.output_proj.",
             r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":            r"t2u_model.decoder.layers.\1.encoder_decoder_attn.",
             r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn_layer_norm.",
             r"^decoder\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.decoder.layers.\1.ffn.inner_proj.",
             r"^decoder\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.decoder.layers.\1.ffn.output_proj.",
             r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.decoder.layers.\1.ffn_layer_norm.",
+            r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"t2u_model.decoder.layers.\1.conv1d.conv1.",
+            r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"t2u_model.decoder.layers.\1.conv1d.conv2.",
+            r"^decoder\.layers\.([0-9]+)\.ffn\.layer_norm\.":         r"t2u_model.decoder.layers.\1.conv1d_layer_norm.",
             r"^decoder\.layer_norm\.":                                r"t2u_model.decoder.layer_norm.",
             r"^decoder\.output_projection\.":                         r"t2u_model.final_proj.",
             # fmt: on
@@ -267,7 +311,9 @@ class UnitYUnitTokenizerLoader:
             card = self.asset_store.retrieve_card(model_name_or_card)
 
         return UnitTokenizer(
-            card.field("num_units").as_(int), card.field("unit_langs").as_list(str)
+            card.field("num_units").as_(int),
+            card.field("unit_langs").as_list(str),
+            card.field("model_arch").as_(str),
         )
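
Each entry in the key map is applied as a regex substitution over the fairseq
checkpoint keys; for instance, the new decoder-frontend rule for the unit
embedding:

    import re

    key = "decoder.embed_tokens_unit.weight"
    new_key = re.sub(
        r"^decoder\.embed_tokens_unit\.", "t2u_model.decoder_frontend.embed.", key
    )
    assert new_key == "t2u_model.decoder_frontend.embed.weight"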
 
 

+ 115 - 9
src/seamless_communication/models/unity/model.py

@@ -1,14 +1,13 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright (c) Meta Platforms, Inc. and affiliates
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple, final
+from typing import Optional, Tuple, Union, final
 
 from fairseq2.models.encoder_decoder import EncoderDecoderModel, Seq2SeqDecoder
-from fairseq2.models.seq2seq import Seq2SeqBatch
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.models.transformer.frontend import TransformerFrontend
 from fairseq2.nn.incremental_state import IncrementalStateBag
@@ -19,6 +18,8 @@ from overrides import final as finaloverride
 from torch import Tensor
 from torch.nn import Module
 
+from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
+
 
 @final
 class UnitYModel(EncoderDecoderModel):
@@ -38,7 +39,7 @@ class UnitYModel(EncoderDecoderModel):
     text_decoder_frontend: TransformerFrontend
     text_decoder: TransformerDecoder
     final_proj: Projection
-    t2u_model: Optional["UnitYT2UModel"]
+    t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None]
     pad_idx: Optional[int]
 
     def __init__(
@@ -50,7 +51,7 @@ class UnitYModel(EncoderDecoderModel):
         text_decoder_frontend: TransformerFrontend,
         text_decoder: TransformerDecoder,
         final_proj: Projection,
-        t2u_model: Optional["UnitYT2UModel"],
+        t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None],
         pad_idx: Optional[int],
         input_modality: str = "speech",
     ) -> None:
@@ -270,14 +271,20 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
 
         self.pad_idx = pad_idx
 
-    def forward(self, batch: Seq2SeqBatch) -> SequenceModelOutput:
+    def forward(
+        self,
+        text_decoder_output: Tensor,
+        text_decoder_padding_mask: Optional[Tensor],
+        target_seqs: Tensor,
+        target_seq_lens: Optional[Tensor],
+    ) -> SequenceModelOutput:
         encoder_output, encoder_padding_mask = self.encode(
-            batch.source_seqs, batch.source_seq_lens
+            text_decoder_output, text_decoder_padding_mask
         )
 
         decoder_output, decoder_padding_mask = self.decode(
-            batch.target_seqs,
-            batch.target_seq_lens,
+            target_seqs,
+            target_seq_lens,
             encoder_output,
             encoder_padding_mask,
         )
@@ -320,6 +327,105 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
         return SequenceModelOutput(logits, self.pad_idx)
 
 
+@final
+class UnitYNART2UModel(Module):
+    """Represents a non-autoregressive UnitY T2U model."""
+
+    model_dim: int
+    encoder: Optional[TransformerEncoder]
+    decoder_frontend: NARDecoderFrontend
+    decoder: TransformerDecoder
+    final_proj: Projection
+    pad_idx: Optional[int]
+
+    def __init__(
+        self,
+        encoder: Optional[TransformerEncoder],
+        decoder_frontend: NARDecoderFrontend,
+        decoder: TransformerDecoder,
+        final_proj: Projection,
+        pad_idx: Optional[int],
+    ) -> None:
+        super().__init__()
+
+        self.model_dim = decoder.model_dim
+
+        if encoder is not None:
+            if encoder.model_dim != self.model_dim:
+                raise ValueError(
+                    f"`model_dim` of `encoder` and `model_dim` of `decoder` must be equal, but are {encoder.model_dim} and {self.model_dim} instead."
+                )
+
+            self.encoder = encoder
+        else:
+            self.register_module("encoder", None)
+
+        if decoder_frontend.model_dim != self.model_dim:
+            raise ValueError(
+                f"`model_dim` of `decoder_frontend` and `model_dim` of `decoder` must be equal, but are {decoder_frontend.model_dim} and {self.model_dim} instead."
+            )
+
+        self.decoder_frontend = decoder_frontend
+        self.decoder = decoder
+
+        self.final_proj = final_proj
+
+        self.pad_idx = pad_idx
+
+    def forward(
+        self,
+        text_decoder_output: Tensor,
+        text_decoder_padding_mask: Optional[Tensor],
+        target_seqs: Optional[Tensor],
+        target_seq_lens: Optional[Tensor],
+        text_seqs: Optional[Tensor],
+    ) -> Tuple[SequenceModelOutput, Optional[Tensor]]:
+        encoder_output, encoder_padding_mask = self.encode(
+            text_decoder_output, text_decoder_padding_mask
+        )
+
+        decoder_output, decoder_padding_mask = self.decode(
+            target_seqs,
+            target_seq_lens,
+            encoder_output,
+            encoder_padding_mask,
+            text_seqs,
+        )
+
+        return self.project(decoder_output), decoder_padding_mask
+
+    def encode(
+        self,
+        text_decoder_output: Tensor,
+        text_decoder_padding_mask: Optional[Tensor],
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        if self.encoder is None:
+            return text_decoder_output, text_decoder_padding_mask
+
+        return self.encoder(text_decoder_output, text_decoder_padding_mask)  # type: ignore[no-any-return]
+
+    def decode(
+        self,
+        seqs: Optional[Tensor],
+        seq_lens: Optional[Tensor],
+        encoder_output: Tensor,
+        encoder_padding_mask: Optional[Tensor],
+        text_seqs: Optional[Tensor],
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        # encoder_output: (N, S, M)
+        # text_seqs: (N, S)
+        seqs, padding_mask = self.decoder_frontend(
+            seqs, seq_lens, encoder_output, encoder_padding_mask, text_seqs
+        )
+
+        return self.decoder(seqs, padding_mask)  # type: ignore[no-any-return]
+
+    def project(self, decoder_output: Tensor) -> SequenceModelOutput:
+        logits = self.final_proj(decoder_output)
+
+        return SequenceModelOutput(logits, self.pad_idx)
+
+
 @dataclass
 class UnitYOutput:
     """Holds the output of a UnitY model."""

+ 336 - 0
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -0,0 +1,336 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Tuple, final
+
+from torch import Tensor
+from torch.nn import Dropout, Module, Parameter
+
+from fairseq2.data import VocabularyInfo
+from fairseq2.models.nllb.tokenizer import NllbTokenizer
+from fairseq2.nn.embedding import Embedding
+from fairseq2.nn.normalization import LayerNorm
+from fairseq2.nn.position_encoder import PositionEncoder
+from fairseq2.nn.transformer import create_default_layer_norm
+from fairseq2.nn.utils.mask import to_padding_mask
+from fairseq2.typing import DataType, Device, finaloverride
+
+
+from seamless_communication.models.unity.length_regulator import (
+    HardUpsampling,
+    VarianceAdaptor,
+)
+from seamless_communication.models.unity.char_tokenizer import CharTokenizer
+
+import math
+import torch
+
+
+SPACE = "▁"
+
+
+class TagManager:
+    def __init__(self, vocab_info: VocabularyInfo):
+        self.vocab_info = vocab_info
+
+    def preprocess_text_seqs(self, text_seqs: Tensor) -> Tensor:
+        # Remove EOS, lang tokens as per NLLB "target" tokenizer mode.
+        text_seqs = text_seqs[:, 2:]
+        assert self.vocab_info.pad_idx is not None
+        text_seqs.masked_fill_(
+            text_seqs == self.vocab_info.eos_idx, self.vocab_info.pad_idx
+        )
+        return text_seqs
+
+    def postprocess_dur_or_len(self, dur_or_len: Tensor) -> Tensor:
+        N = dur_or_len.shape[0]
+        pad_zero = dur_or_len.new_zeros((N, 1))
+        # Add pads for lang, EOS tokens as per NLLB "source" tokenizer mode.
+        dur_or_len = torch.cat([pad_zero, dur_or_len, pad_zero], dim=1)
+        return dur_or_len
+
+
+@final
+class NARDecoderFrontend(Module):
+    """Represents a Non-autoregressive decoder front-end."""
+
+    char_pos_encoder: PositionEncoder
+    pos_emb_alpha_char: Parameter
+    unit_pos_encoder: PositionEncoder
+    pos_emb_alpha: Parameter
+    scale: float
+    char_length_regulator: HardUpsampling
+    variance_adaptor: VarianceAdaptor
+    layer_norm: Optional[LayerNorm]
+    dropout: Optional[Dropout]
+
+    def __init__(
+        self,
+        embed: Embedding,
+        embed_char: Embedding,
+        text_tokenizer: NllbTokenizer,
+        char_tokenizer: CharTokenizer,
+        unit_pos_encoder: PositionEncoder,
+        char_pos_encoder: PositionEncoder,
+        variance_adaptor: VarianceAdaptor,
+        no_scale: bool = False,
+        layer_norm: bool = False,
+        dropout_p: float = 0.1,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        self.model_dim = embed.embedding_dim
+
+        super().__init__()
+
+        self.embed = embed
+        self.embed_char = embed_char
+        self.text_tokenizer = text_tokenizer
+        self.char_tokenizer = char_tokenizer
+        self.tag_manager = TagManager(text_tokenizer.vocab_info)
+
+        self.unk_idx = self.text_tokenizer.vocab_info.unk_idx
+        self.pad_idx = self.text_tokenizer.vocab_info.pad_idx
+
+        # TODO: Implement AlignmentEncoder for training.
+
+        if unit_pos_encoder.encoding_dim != self.model_dim:
+            raise ValueError(
+                f"`encoding_dim` of `unit_pos_encoder` and `embedding_dim` of `embed` must be equal, but are {unit_pos_encoder.encoding_dim} and {self.model_dim} instead."
+            )
+
+        if char_pos_encoder.encoding_dim != self.model_dim:
+            raise ValueError(
+                f"`encoding_dim` of `char_pos_encoder` and `embedding_dim` of `embed` must be equal, but are {char_pos_encoder.encoding_dim} and {self.model_dim} instead."
+            )
+
+        self.unit_pos_encoder = unit_pos_encoder
+
+        self.pos_emb_alpha = Parameter(torch.ones(1, device=device, dtype=dtype))
+        self.char_pos_encoder = char_pos_encoder
+
+        self.pos_emb_alpha_char = Parameter(torch.ones(1, device=device, dtype=dtype))
+        self.scale = 1.0 if no_scale else math.sqrt(self.model_dim)
+
+        self.char_length_regulator = HardUpsampling()
+
+        self.variance_adaptor = variance_adaptor
+
+        if layer_norm:
+            self.layer_norm = create_default_layer_norm(
+                self.model_dim, device=device, dtype=dtype
+            )
+        else:
+            self.register_module("layer_norm", None)
+
+        if dropout_p > 0.0:
+            self.dropout = Dropout(dropout_p)
+        else:
+            self.register_module("dropout", None)
+
+    def indices_to_subwords(self, text_seqs: Tensor) -> List[List[str]]:
+        # TODO: To be replaced with fairseq2's indices_to_tokens SPM model method
+        # once implemented.
+        N, seq_len = text_seqs.shape
+        subwords_batch = []
+        for b in range(N):
+            subwords = []
+            for i in range(seq_len):
+                subword = self.text_tokenizer.model.index_to_token(int(text_seqs[b, i]))
+                subwords.append(str(subword))
+            subwords_batch.append(subwords)
+        return subwords_batch
+
+    def text_to_char_seqs(self, text_seqs: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        text_seqs = self.tag_manager.preprocess_text_seqs(text_seqs)
+
+        subwords_batch = self.indices_to_subwords(text_seqs)
+
+        char_lens = self.count_character_length_in_subword(text_seqs, subwords_batch)
+
+        char_lens = self.tag_manager.postprocess_dur_or_len(char_lens)
+
+        char_seqs, char_seq_lens = self.get_char_seqs(
+            text_seqs, subwords_batch, char_lens
+        )
+
+        return char_seqs, char_seq_lens, char_lens
+
+    def count_character_length_in_subword(
+        self,
+        text_seqs: Tensor,
+        subwords_batch: List[List[str]],
+        merge_space_with_prev_subword: bool = False,
+    ) -> Tensor:
+        N, _ = text_seqs.shape
+
+        char_lens = text_seqs.new_zeros(text_seqs.size())
+
+        assert self.pad_idx is not None
+        subword_lens = text_seqs.ne(self.pad_idx).sum(1)
+
+        for b in range(N):
+            # Slice out the non-padded prefix of the sequence.
+            subword_indices = text_seqs[b, : subword_lens[b]]
+            subwords = subwords_batch[b][: subword_lens[b]]
+
+            assert subword_indices.shape[0] == len(subwords)
+
+            is_next_start_with_space = [
+                len(subwords[i + 1]) > 1 and subwords[i + 1][0] == SPACE
+                if i < len(subwords) - 1
+                else False
+                for i in range(len(subwords))
+            ]
+            is_punc = [
+                len(subwords[i]) == 1
+                and not subwords[i].isalpha()
+                and not subwords[i].isnumeric()
+                and subwords[i] != SPACE
+                for i in range(len(subwords))
+            ]
+            for i, (subword_idx, subword) in enumerate(zip(subword_indices, subwords)):
+                if subword_idx == self.pad_idx:
+                    break
+
+                if subword_idx == self.unk_idx:
+                    # We set char_len to 1 for an unk token.
+                    char_len = 1
+
+                    if merge_space_with_prev_subword and is_next_start_with_space[i]:
+                        char_len += 1
+                else:
+                    # By default, spaces are merged with the next subword.
+                    # char_len includes the space.
+                    char_len = len(subword)
+
+                    if merge_space_with_prev_subword:
+                        # Add the space for the next subword.
+                        if is_next_start_with_space[i]:
+                            char_len += 1
+                        # Subtract the space for the current subword.
+                        if i > 0 and is_next_start_with_space[i - 1]:
+                            char_len -= 1
+                    else:
+                        # Merge space with punctuation mark by default.
+                        if is_punc[i] and is_next_start_with_space[i]:
+                            char_len += 1
+                        # Subtract the space for the subword succeeding the punctuation mark.
+                        elif (
+                            i > 0 and is_punc[i - 1] and is_next_start_with_space[i - 1]
+                        ):
+                            char_len -= 1
+
+                char_lens[b, i] = char_len
+
+        return char_lens
+
+    def get_char_seqs(
+        self, text_seqs: Tensor, subwords_batch: List[List[str]], char_lens: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        N = text_seqs.shape[0]
+        max_len = int(char_lens.sum(1).max().item())
+
+        assert self.pad_idx is not None
+        char_seqs = text_seqs.new_zeros((N, max_len)).fill_(self.pad_idx)
+        char_seq_lens = char_seqs.new_zeros(N)
+
+        subword_lens = text_seqs.ne(self.pad_idx).sum(1)
+
+        for b in range(N):
+            total = 0
+            subword_indices = text_seqs[b, : subword_lens[b]]
+            subwords = subwords_batch[b][: subword_lens[b]]
+            for subword_idx, subword in zip(subword_indices, subwords):
+                if subword_idx == self.unk_idx:
+                    char_ids = [self.unk_idx]
+                else:
+                    # Get char token indices corresponding to the subwords.
+                    char_ids = [
+                        self.char_tokenizer.model.token_to_index(ch)
+                        for ch in list(subword)
+                    ]
+                char_seq_len = len(char_ids)
+                char_seqs[b, total : total + char_seq_len] = torch.tensor(char_ids).to(
+                    char_seqs
+                )
+                total += char_seq_len
+            char_seq_lens[b] = total
+        return char_seqs, char_seq_lens
+
+    def character_level_upsampling(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[Tensor],
+        char_seqs: Tensor,
+        char_lens: Tensor,
+    ) -> Tensor:
+        seqs, _ = self.char_length_regulator(seqs, char_lens)
+
+        pos_embeds = self.pos_emb_alpha_char * (
+            self.char_pos_encoder(seqs, padding_mask) - seqs
+        )
+
+        char_embeds = self.embed_char(char_seqs)
+
+        if self.scale != 1.0:
+            char_embeds *= self.scale
+
+        pos_embeds += char_embeds
+
+        seqs += pos_embeds
+
+        return seqs
+
+    def forward_unit_pos_embedding(
+        self, seqs: Tensor, padding_mask: Optional[Tensor]
+    ) -> Tensor:
+        pos_embeds = self.pos_emb_alpha * (
+            self.unit_pos_encoder(seqs, padding_mask) - seqs
+        )
+
+        seqs += pos_embeds
+
+        if self.dropout is not None:
+            seqs = self.dropout(seqs)
+
+        return seqs
+
+    @finaloverride
+    def forward(
+        self,
+        target_seqs: Optional[Tensor],
+        target_seq_lens: Optional[Tensor],
+        encoder_output: Tensor,
+        encoder_padding_mask: Optional[Tensor],
+        text_seqs: Optional[Tensor],
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        assert text_seqs is not None
+
+        # text_seqs: (N, S_text)
+        char_seqs, char_seq_lens, char_lens = self.text_to_char_seqs(text_seqs)
+
+        # char_seqs: (N, S_char)
+        encoder_padding_mask = to_padding_mask(char_seqs, char_seq_lens)
+
+        # (N, S_text, M) -> (N, S_char, M)
+        seqs = self.character_level_upsampling(
+            encoder_output, encoder_padding_mask, char_seqs, char_lens
+        )
+
+        # (N, S_char, M) -> (N, S_unit, M)
+        seqs, seq_lens = self.variance_adaptor(
+            seqs,
+            encoder_padding_mask,
+            min_duration=1,
+        )
+
+        decoder_padding_mask = to_padding_mask(seqs, seq_lens)
+
+        seqs = self.forward_unit_pos_embedding(seqs, decoder_padding_mask)
+
+        return seqs, decoder_padding_mask
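
A minimal, dependency-free sketch of the default space-handling rule that `count_character_length_in_subword` implements (a plain-Python illustration, not the fairseq2 code path): a leading SPM space ("▁") is counted with its own subword, except that a space following a punctuation mark is counted with the punctuation instead.

SPACE = "▁"

def char_lengths(subwords):
    starts_space = [
        i + 1 < len(subwords)
        and len(subwords[i + 1]) > 1
        and subwords[i + 1][0] == SPACE
        for i in range(len(subwords))
    ]
    is_punc = [
        len(sw) == 1 and not sw.isalpha() and not sw.isnumeric() and sw != SPACE
        for sw in subwords
    ]
    lens = []
    for i, sw in enumerate(subwords):
        n = len(sw)
        if is_punc[i] and starts_space[i]:
            n += 1  # punctuation absorbs the following space ...
        elif i > 0 and is_punc[i - 1] and starts_space[i - 1]:
            n -= 1  # ... which is then not re-counted by the next subword
        lens.append(n)
    return lens

subwords = ["▁Hello", ",", "▁world", "."]
assert char_lengths(subwords) == [6, 2, 5, 1]
assert sum(char_lengths(subwords)) == len("".join(subwords))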

+ 218 - 0
src/seamless_communication/models/unity/nar_decoder_layer.py

@@ -0,0 +1,218 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, final, Tuple
+
+from torch import Tensor
+from torch.nn import Conv1d, Dropout, Module, ReLU
+
+from fairseq2.nn.normalization import LayerNorm
+from fairseq2.nn.transformer import (
+    TransformerDecoderLayer,
+    MultiheadAttention,
+)
+from fairseq2.nn.incremental_state import IncrementalStateBag
+from fairseq2.nn.transformer import create_default_layer_norm
+from fairseq2.nn.utils.mask import apply_padding_mask
+from fairseq2.nn.utils.module import check_model_dim
+from fairseq2.typing import DataType, Device, finaloverride
+
+
+@final
+class Conv1dBlock(Module):
+    """Represents the Conv1d block within the FFT Block as described in
+    :cite:t:`https://arxiv.org/pdf/1905.09263.pdf`."""
+
+    conv1: Conv1d
+    activation: ReLU
+    conv2: Conv1d
+
+    def __init__(
+        self,
+        model_dim: int,
+        inner_dim: int,
+        kernel_size: int,
+        bias: bool = True,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param model_dim:
+            The dimensionality of the model.
+        :param inner_dim:
+            The inner dimensionality between the two convolutional layers.
+        :param kernel_size:
+            The kernel size of the Conv1d layers.
+        :param bias:
+            If ``True``, both the inner and output projections learn an additive
+            bias.
+        """
+        super().__init__()
+
+        self.conv1 = Conv1d(
+            model_dim,
+            inner_dim,
+            kernel_size,
+            stride=1,
+            padding="same",
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+        self.activation = ReLU()
+
+        self.conv2 = Conv1d(
+            inner_dim,
+            model_dim,
+            kernel_size,
+            stride=1,
+            padding="same",
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+    @finaloverride
+    def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
+        # Ensure that we do not leak padded positions in the convolution layer.
+        seqs = apply_padding_mask(seqs, padding_mask)
+
+        # (N, S, M) -> (N, M, S)
+        seqs = seqs.transpose(1, 2)
+
+        # (N, M, S) -> (N, inner_dim, S)
+        seqs = self.conv1(seqs)
+
+        # (N, inner_dim, S) -> (N, S, inner_dim)
+        seqs = seqs.transpose(1, 2)
+
+        seqs = apply_padding_mask(seqs, padding_mask)
+
+        seqs = self.activation(seqs)
+
+        # (N, S, inner_dim) -> (N, inner_dim, S)
+        seqs = seqs.transpose(1, 2)
+
+        # (N, inner_dim, S) -> (N, M, S)
+        seqs = self.conv2(seqs)
+
+        # (N, M, S) -> (N, S, M)
+        seqs = seqs.transpose(1, 2)
+
+        return seqs
+
+
+@final
+class NARTransformerDecoderLayer(TransformerDecoderLayer):
+    """Represents the FFT Block as described in
+    :cite:t:`https://arxiv.org/pdf/1905.09263.pdf`."""
+
+    self_attn: MultiheadAttention
+    self_attn_dropout: Optional[Dropout]
+    self_attn_layer_norm: LayerNorm
+    conv1d: Conv1dBlock
+    conv1d_dropout: Optional[Dropout]
+    conv1d_layer_norm: LayerNorm
+
+    def __init__(
+        self,
+        self_attn: MultiheadAttention,
+        conv1d: Conv1dBlock,
+        dropout_p: float = 0.1,
+        conv1d_dropout_p: float = 0.1,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param self_attn:
+            The self attention layer.
+        :param conv1d:
+            The conv1d block.
+        :param dropout_p:
+            The dropout probability on the outputs of the self attention layer.
+        :param conv1d_dropout_p:
+            The dropout probability on the outputs of the conv1d block.
+        """
+        model_dim = self_attn.model_dim
+
+        super().__init__(model_dim)
+
+        self.self_attn = self_attn
+
+        if dropout_p > 0.0:
+            self.self_attn_dropout = Dropout(dropout_p)
+        else:
+            self.register_module("self_attn_dropout", None)
+
+        layer_norm_fn = create_default_layer_norm
+
+        self.self_attn_layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
+
+        self.conv1d = conv1d
+
+        if conv1d_dropout_p > 0.0:
+            self.conv1d_dropout = Dropout(conv1d_dropout_p)
+        else:
+            self.register_module("conv1d_dropout", None)
+
+        self.conv1d_layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
+
+        check_model_dim(self)
+
+    @finaloverride
+    def forward(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[Tensor],
+        self_attn_mask: Optional[Tensor] = None,
+        encoder_output: Optional[Tensor] = None,
+        encoder_padding_mask: Optional[Tensor] = None,
+        state_bag: Optional[IncrementalStateBag] = None,
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        seqs = self._forward_self_attn(seqs, padding_mask)
+
+        seqs = self._forward_conv1d(seqs, padding_mask)
+
+        return seqs, padding_mask
+
+    def _forward_self_attn(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[Tensor],
+    ) -> Tensor:
+        residual = seqs
+
+        seqs = self.self_attn(
+            seqs,
+            padding_mask,
+            keys=seqs,
+            values=seqs,
+            key_padding_mask=padding_mask,
+        )
+
+        if self.self_attn_dropout is not None:
+            seqs = self.self_attn_dropout(seqs)
+
+        seqs = seqs + residual
+
+        seqs = self.self_attn_layer_norm(seqs)
+
+        return seqs
+
+    def _forward_conv1d(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
+        residual = seqs
+
+        seqs = self.conv1d(seqs, padding_mask)
+
+        if self.conv1d_dropout is not None:
+            seqs = self.conv1d_dropout(seqs)
+
+        seqs = seqs + residual
+
+        seqs = self.conv1d_layer_norm(seqs)
+
+        return seqs
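
For intuition, a compact PyTorch-only sketch of the FFT block dataflow (plain torch.nn modules, not the fairseq2 building blocks used above): post-LN residual self-attention followed by a post-LN residual two-convolution block. Padding-mask handling is omitted for brevity.

import torch
from torch import nn

class ToyFFTBlock(nn.Module):
    def __init__(self, model_dim=8, inner_dim=16, kernel_size=3, num_heads=2):
        super().__init__()
        self.attn = nn.MultiheadAttention(model_dim, num_heads, batch_first=True)
        self.attn_norm = nn.LayerNorm(model_dim)
        self.conv1 = nn.Conv1d(model_dim, inner_dim, kernel_size, padding="same")
        self.conv2 = nn.Conv1d(inner_dim, model_dim, kernel_size, padding="same")
        self.conv_norm = nn.LayerNorm(model_dim)

    def forward(self, seqs):  # seqs: (N, S, M)
        residual = seqs
        seqs, _ = self.attn(seqs, seqs, seqs, need_weights=False)
        seqs = self.attn_norm(seqs + residual)

        residual = seqs
        # Conv1d operates over (N, M, S), hence the transposes.
        seqs = seqs.transpose(1, 2)
        seqs = self.conv2(torch.relu(self.conv1(seqs)))
        seqs = seqs.transpose(1, 2)
        return self.conv_norm(seqs + residual)

assert ToyFFTBlock()(torch.randn(2, 5, 8)).shape == (2, 5, 8)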

+ 492 - 0
src/seamless_communication/models/unity/t2u_builder.py

@@ -0,0 +1,492 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
+from typing import Literal, Optional, Union
+
+from fairseq2.assets import download_manager
+from fairseq2.assets.card import AssetCard
+from fairseq2.data import VocabularyInfo
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.nn.embedding import Embedding
+from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
+from fairseq2.nn.projection import TiedProjection
+from fairseq2.nn.transformer import (
+    FeedForwardNetwork,
+    MultiheadAttention,
+    StandardFeedForwardNetwork,
+    StandardMultiheadAttention,
+    StandardTransformerDecoder,
+    StandardTransformerDecoderLayer,
+    StandardTransformerEncoder,
+    StandardTransformerEncoderLayer,
+    TransformerDecoder,
+    TransformerDecoderLayer,
+    TransformerEncoder,
+    TransformerEncoderLayer,
+    TransformerNormOrder,
+    create_default_sdpa,
+)
+from fairseq2.typing import DataType, Device
+from fairseq2.models.transformer import (
+    TransformerEmbeddingFrontend,
+    TransformerFrontend,
+)
+from fairseq2.models.nllb.loader import NllbTokenizerLoader
+
+
+from seamless_communication.assets import asset_store
+from seamless_communication.models.unity.nar_decoder_layer import (
+    NARTransformerDecoderLayer,
+    Conv1dBlock,
+)
+from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
+from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
+from seamless_communication.models.unity.model import UnitYT2UModel, UnitYNART2UModel
+from seamless_communication.models.unity.length_regulator import (
+    VariancePredictor,
+    VarianceAdaptor,
+)
+
+
+@dataclass
+class VariancePredictorConfig:
+    var_pred_hidden_dim: int
+    var_pred_kernel_size: int
+    var_pred_dropout: float
+
+
+@dataclass
+class NARDecoderFrontendConfig:
+    subword_to_unit_upsampling_type: Literal["gaussian", "hard"]
+    duration_predictor_config: VariancePredictorConfig
+    pitch_predictor_config: Optional[VariancePredictorConfig]
+    energy_predictor_config: Optional[VariancePredictorConfig]
+
+
+@dataclass
+class NARDecoderConfig:
+    model_name_or_card: Union[str, AssetCard]
+    char_vocabulary_size: int
+    char_max_seq_len: int
+    conv1d_kernel_size: int
+    conv1d_inner_dim: int
+    conv1d_dropout_p: float
+
+
+@dataclass
+class UnitYT2UConfig:
+    """Holds the configuration of a UnitY T2U model as described in
+    :cite:t:`https://doi.org/10.48550/arxiv.2212.08055`"""
+
+    model_dim: int
+    """The dimensionality of the model."""
+
+    unit_max_seq_len: int
+    """The expected maximum unit sequence length."""
+
+    unit_vocabulary_size: int
+    """The size of the unit vocabulary."""
+
+    unit_pad_idx: Optional[int]
+    """The index of the pad symbol in the unit vocabulary."""
+
+    num_encoder_layers: int
+    """The number of Transformer encoder layers."""
+
+    num_decoder_layers: int
+    """The number of Transformer decoder layers."""
+
+    nar_decoder_frontend_config: Optional[NARDecoderFrontendConfig]
+    """Non-autoregressive decoder front-end config."""
+
+    nar_decoder_config: Optional[NARDecoderConfig]
+    """Non-autoregressive decoder config."""
+
+    num_encoder_attn_heads: int
+    """The number of attention heads in Transformer encoder layers."""
+
+    num_decoder_attn_heads: int
+    """The number of attention heads in Transformer decoder layers."""
+
+    ffn_inner_dim: int
+    """The inner dimensionality of Transformer feed-forward networks."""
+
+    dropout_p: float
+    """The dropout probability in Transformer layers."""
+
+    def update_unit_vocabulary(self, info: VocabularyInfo) -> None:
+        """Update unit vocabulary configuration from ``info``."""
+        self.unit_vocabulary_size, self.unit_pad_idx = info.size, info.pad_idx
+
+
+unity_t2u_archs = ArchitectureRegistry[UnitYT2UConfig]("unity_t2u")
+
+
+unity_t2u_arch = unity_t2u_archs.marker
+
+
+@unity_t2u_arch("base")
+def _base_t2u() -> UnitYT2UConfig:
+    return UnitYT2UConfig(
+        model_dim=1024,
+        unit_max_seq_len=2048,
+        unit_vocabulary_size=10082,
+        unit_pad_idx=1,
+        num_encoder_layers=6,
+        num_decoder_layers=6,
+        nar_decoder_frontend_config=None,
+        nar_decoder_config=None,
+        num_encoder_attn_heads=16,
+        num_decoder_attn_heads=16,
+        ffn_inner_dim=1024 * 8,
+        dropout_p=0.1,
+    )
+
+
+@unity_t2u_arch("medium")
+def _medium_t2u() -> UnitYT2UConfig:
+    return UnitYT2UConfig(
+        model_dim=1024,
+        unit_max_seq_len=2048,
+        unit_vocabulary_size=10082,
+        unit_pad_idx=1,
+        num_encoder_layers=4,
+        num_decoder_layers=4,
+        nar_decoder_frontend_config=None,
+        nar_decoder_config=None,
+        num_encoder_attn_heads=16,
+        num_decoder_attn_heads=16,
+        ffn_inner_dim=1024 * 8,
+        dropout_p=0.1,
+    )
+
+
+@unity_t2u_arch("nar_multilingual")
+def _nar_multilingual_t2u() -> UnitYT2UConfig:
+    duration_predictor_config = VariancePredictorConfig(
+        var_pred_hidden_dim=256,
+        var_pred_kernel_size=3,
+        var_pred_dropout=0.5,
+    )
+
+    nar_decoder_frontend_config = NARDecoderFrontendConfig(
+        subword_to_unit_upsampling_type="hard",
+        duration_predictor_config=duration_predictor_config,
+        pitch_predictor_config=None,
+        energy_predictor_config=None,
+    )
+
+    nar_decoder_config = NARDecoderConfig(
+        model_name_or_card="unity_nar_multilingual",
+        char_vocabulary_size=10904,
+        char_max_seq_len=4096,
+        conv1d_kernel_size=7,
+        conv1d_inner_dim=1024,
+        conv1d_dropout_p=0.1,
+    )
+
+    return UnitYT2UConfig(
+        model_dim=1024,
+        unit_max_seq_len=2048,
+        unit_vocabulary_size=10020,
+        unit_pad_idx=1,
+        num_encoder_layers=6,
+        num_decoder_layers=6,
+        nar_decoder_frontend_config=nar_decoder_frontend_config,
+        nar_decoder_config=nar_decoder_config,
+        num_encoder_attn_heads=16,
+        num_decoder_attn_heads=16,
+        ffn_inner_dim=1024 * 8,
+        dropout_p=0.0,
+    )
+
+
+class UnitYT2UBuilder:
+    """Builds modules of an AR or NAR UnitY T2U model.
+
+    To tweak the architecture, you can derive from this class and override the
+    corresponding methods.
+    """
+
+    config: UnitYT2UConfig
+    device: Optional[Device]
+    dtype: Optional[DataType]
+
+    def __init__(
+        self,
+        config: UnitYT2UConfig,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+        self.device = device
+        self.dtype = dtype
+
+    def build_model(self) -> Union[UnitYT2UModel, UnitYNART2UModel]:
+        """Build a model."""
+        embed_unit = self.build_unit_embedding()
+
+        encoder = self.build_encoder()
+
+        decoder = self.build_decoder()
+
+        final_proj = TiedProjection(embed_unit.weight, bias=None)
+
+        if self.config.nar_decoder_config is None:
+            decoder_frontend = self.build_decoder_frontend(embed_unit)
+            return UnitYT2UModel(
+                encoder,
+                decoder_frontend,
+                decoder,
+                final_proj,
+                self.config.unit_pad_idx,
+            )
+        else:
+            nar_decoder_frontend = self.build_nar_decoder_frontend(embed_unit)
+            return UnitYNART2UModel(
+                encoder,
+                nar_decoder_frontend,
+                decoder,
+                final_proj,
+                self.config.unit_pad_idx,
+            )
+
+    def build_unit_embedding(self) -> Embedding:
+        """Build a unit embedding table."""
+        return Embedding(
+            num_embeddings=self.config.unit_vocabulary_size,
+            embedding_dim=self.config.model_dim,
+            pad_idx=self.config.unit_pad_idx,
+            scaled=True,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_encoder(self) -> Optional[TransformerEncoder]:
+        """Build a Transformer encoder."""
+        num_layers = self.config.num_encoder_layers
+        if num_layers == 0:
+            return None
+
+        layers = [self.build_encoder_layer() for _ in range(num_layers)]
+
+        return StandardTransformerEncoder(
+            layers,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_encoder_layer(self) -> TransformerEncoderLayer:
+        """Build a Transformer encoder layer."""
+        self_attn = self.build_attention(self.config.num_encoder_attn_heads)
+
+        ffn = self.build_ffn()
+
+        return StandardTransformerEncoderLayer(
+            self_attn,
+            ffn,
+            dropout_p=self.config.dropout_p,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_variance_adaptor(
+        self, nar_decoder_frontend_config: NARDecoderFrontendConfig
+    ) -> VarianceAdaptor:
+        duration_predictor_config = (
+            nar_decoder_frontend_config.duration_predictor_config
+        )
+        duration_predictor = VariancePredictor(
+            self.config.model_dim,
+            duration_predictor_config.var_pred_hidden_dim,
+            duration_predictor_config.var_pred_kernel_size,
+            duration_predictor_config.var_pred_dropout,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        variance_adaptor = VarianceAdaptor(
+            duration_predictor,
+            pitch_predictor=None,
+            energy_predictor=None,
+        )
+
+        return variance_adaptor
+
+    def build_decoder_frontend(self, embed_unit: Embedding) -> TransformerFrontend:
+        """Build a Transformer decoder front-end."""
+        pos_encoder = SinusoidalPositionEncoder(
+            self.config.model_dim,
+            self.config.unit_max_seq_len,
+            _legacy_pad_idx=self.config.unit_pad_idx,
+            device=self.device,
+            dtype=self.dtype,
+        )
+        return TransformerEmbeddingFrontend(
+            embed_unit,
+            pos_encoder,
+            dropout_p=self.config.dropout_p,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_nar_decoder_frontend(self, embed_unit: Embedding) -> NARDecoderFrontend:
+        """Build a non-autoregressive decoder front-end."""
+        assert self.config.nar_decoder_config is not None
+        assert self.config.nar_decoder_frontend_config is not None
+
+        unit_pos_encoder = SinusoidalPositionEncoder(
+            self.config.model_dim,
+            self.config.unit_max_seq_len,
+            _legacy_pad_idx=self.config.unit_pad_idx,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        char_tokenizer = load_unity_char_tokenizer(
+            self.config.nar_decoder_config.model_name_or_card
+        )
+
+        variance_adaptor = self.build_variance_adaptor(
+            self.config.nar_decoder_frontend_config
+        )
+
+        nllb_tokenizer = NllbTokenizerLoader(asset_store, download_manager)(
+            self.config.nar_decoder_config.model_name_or_card
+        )
+        text_pad_idx = nllb_tokenizer.vocab_info.pad_idx
+
+        char_pos_encoder = SinusoidalPositionEncoder(
+            self.config.model_dim,
+            self.config.nar_decoder_config.char_max_seq_len,
+            _legacy_pad_idx=text_pad_idx,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        embed_char = Embedding(
+            num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
+            embedding_dim=self.config.model_dim,
+            pad_idx=text_pad_idx,
+            scaled=True,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        return NARDecoderFrontend(
+            embed_unit,
+            embed_char,
+            nllb_tokenizer,
+            char_tokenizer,
+            unit_pos_encoder,
+            char_pos_encoder,
+            variance_adaptor,
+            dropout_p=self.config.dropout_p,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder(self) -> TransformerDecoder:
+        """Build a Transformer decoder."""
+        num_layers = self.config.num_decoder_layers
+
+        layers = [self.build_decoder_layer() for _ in range(num_layers)]
+
+        return StandardTransformerDecoder(
+            layers,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder_layer(self) -> TransformerDecoderLayer:
+        """Build a Transformer decoder layer."""
+        self_attn = self.build_attention(self.config.num_decoder_attn_heads)
+
+        if self.config.nar_decoder_config:
+            conv1d = Conv1dBlock(
+                self.config.model_dim,
+                self.config.nar_decoder_config.conv1d_inner_dim,
+                self.config.nar_decoder_config.conv1d_kernel_size,
+                bias=True,
+                device=self.device,
+                dtype=self.dtype,
+            )
+
+            return NARTransformerDecoderLayer(
+                self_attn,
+                conv1d,
+                dropout_p=self.config.dropout_p,
+                conv1d_dropout_p=self.config.nar_decoder_config.conv1d_dropout_p,
+                device=self.device,
+                dtype=self.dtype,
+            )
+        else:
+            encoder_decoder_attn = self.build_attention(
+                self.config.num_decoder_attn_heads
+            )
+
+            ffn = self.build_ffn()
+
+            return StandardTransformerDecoderLayer(
+                self_attn,
+                encoder_decoder_attn,
+                ffn,
+                dropout_p=self.config.dropout_p,
+                norm_order=TransformerNormOrder.PRE,
+                device=self.device,
+                dtype=self.dtype,
+            )
+
+    def build_attention(self, num_heads: int) -> MultiheadAttention:
+        """Build a Transformer multi-head attention layer."""
+        sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
+
+        return StandardMultiheadAttention(
+            self.config.model_dim,
+            num_heads,
+            sdpa=sdpa,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_ffn(self) -> FeedForwardNetwork:
+        """Build a Transformer feed-forward network."""
+        return StandardFeedForwardNetwork(
+            self.config.model_dim,
+            self.config.ffn_inner_dim,
+            bias=True,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
+def create_unity_t2u_model(
+    config: UnitYT2UConfig,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> Union[UnitYT2UModel, UnitYNART2UModel]:
+    """Create a UnitY T2U model.
+
+    :param config:
+        The configuration to use.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+    return UnitYT2UBuilder(config, device, dtype).build_model()
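
A hedged usage sketch of the builder defined above, assuming fairseq2's `ArchitectureRegistry.get_config` lookup; the device/dtype values are illustrative, and building the NAR variant also requires the tokenizer assets referenced by the `unity_nar_multilingual` card to be reachable.

import torch

config = unity_t2u_archs.get_config("nar_multilingual")

# `nar_decoder_config` is set for this arch, so `build_model` returns a
# `UnitYNART2UModel`; the "base" and "medium" archs return a `UnitYT2UModel`.
model = create_unity_t2u_model(config, device=torch.device("cpu"), dtype=torch.float32)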

+ 41 - 14
src/seamless_communication/models/unity/unit_tokenizer.py

@@ -1,4 +1,4 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright (c) Meta Platforms, Inc. and affiliates
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
@@ -19,22 +19,31 @@ class UnitTokenizer:
     langs: Sequence[str]
     lang_map: Dict[str, int]
 
-    def __init__(self, num_units: int, langs: Sequence[str]) -> None:
+    def __init__(self, num_units: int, langs: Sequence[str], model_arch: str) -> None:
         """
         :param num_units:
             The number of speech units.
         :param langs:
             The list of supported languages.
+        :param model_arch:
+            The type of UnitY model architecture.
         """
         self.num_units = num_units
 
         self.langs = langs
 
+        self.model_arch = model_arch
+
         self.lang_map = {lang: idx for idx, lang in enumerate(langs)}
 
-        # For legacy reasons, we have to repeat the language symbols twice,
-        # along with a placeholder `<mask>` token.
-        vocab_size = num_units + (2 * (len(langs) + 1)) + 4
+        if self.model_arch == "nar_multilingual":
+            self.lang_symbol_repetitions = 1
+        else:
+            # For legacy reasons, we have to repeat the language symbols twice,
+            # along with a placeholder `<mask>` token for UnitY AR models.
+            self.lang_symbol_repetitions = 2
+
+        vocab_size = num_units + self.lang_symbol_repetitions * (len(langs) + 1) + 4
 
         # We use fairseq's control symbol order.
         self.vocab_info = VocabularyInfo(
@@ -45,7 +54,12 @@ class UnitTokenizer:
         """Return the symbol index of the specified language."""
         # +4 for PAD/EOS/BOS/UNK, and +1 for the `<mask>` token.
         try:
-            return self.num_units + len(self.langs) + self.lang_map[lang] + 5
+            return (
+                self.num_units
+                + (self.lang_symbol_repetitions - 1) * len(self.langs)
+                + self.lang_map[lang]
+                + 5
+            )
         except KeyError:
             langs = ", ".join(self.langs)
 
@@ -55,7 +69,12 @@ class UnitTokenizer:
 
     def index_to_lang(self, idx: int) -> str:
         """Return the language of the specified language symbol index."""
-        relative_idx = idx - self.num_units - len(self.langs) - 5
+        relative_idx = (
+            idx
+            - self.num_units
+            - (self.lang_symbol_repetitions - 1) * len(self.langs)
+            - 5
+        )
 
         if relative_idx < 0 or relative_idx >= len(self.langs):
             raise ValueError(
@@ -76,7 +95,7 @@ class UnitTokenizer:
 
     def create_decoder(self) -> "UnitTokenDecoder":
         """Create a token decoder."""
-        return UnitTokenDecoder(self)
+        return UnitTokenDecoder(self, self.model_arch)
 
 
 class UnitTokenEncoder:
@@ -158,16 +177,19 @@ class UnitTokenDecoder:
     eos_idx: int
     pad_idx: int
 
-    def __init__(self, tokenizer: UnitTokenizer) -> None:
+    def __init__(self, tokenizer: UnitTokenizer, model_arch: str) -> None:
         """
         :param tokenizer:
             The unit tokenizer to use.
+        :param model_arch:
+            The type of UnitY model architecture.
         """
         assert tokenizer.vocab_info.eos_idx is not None
         assert tokenizer.vocab_info.pad_idx is not None
 
         self.eos_idx = tokenizer.vocab_info.eos_idx
         self.pad_idx = tokenizer.vocab_info.pad_idx
+        self.model_arch = model_arch
 
     def __call__(self, token_indices: Tensor) -> Tensor:
         """Decode ``token_indices`` to speech units.
@@ -184,16 +206,21 @@ class UnitTokenDecoder:
         if token_indices.size(1) == 0:
             return token_indices
 
-        # Remove the prefix EOS symbol. The language symbol is still expected to
-        # be part of the decoded output.
-        units = token_indices[:, 1:].clone().detach()
+        units = token_indices.clone().detach()
+
+        # Remove the prefix EOS symbol from the decoded output for AR UnitY.
+        if self.model_arch != "nar_multilingual":
+            units = units[:, 1:]
 
         # Also, replace EOS with PAD at sequence ends.
         units[units == self.eos_idx] = self.pad_idx
 
         units[units == self.pad_idx] = self.pad_idx + 4
 
-        # Remove offset of control symbols (exclude language symbol).
-        units[:, 1:] -= 4
+        # Remove offset of control symbols (exclude language symbol for AR UnitY).
+        if self.model_arch == "nar_multilingual":
+            units -= 4
+        else:
+            units[:, 1:] -= 4
 
         return units
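
As a sanity check on the new `lang_symbol_repetitions` arithmetic, a quick plain-Python sketch using the `unity_nar_multilingual` card values (num_units=10000, 15 unit languages):

num_units, num_langs = 10000, 15

ar_vocab = num_units + 2 * (num_langs + 1) + 4   # legacy: doubled lang symbols + <mask>
nar_vocab = num_units + 1 * (num_langs + 1) + 4  # NAR: a single set of lang symbols

assert nar_vocab == 10020  # matches `unit_vocabulary_size` of the nar_multilingual arch
assert ar_vocab == 10036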

+ 2 - 43
src/seamless_communication/models/vocoder/codehifigan.py

@@ -8,50 +8,9 @@ from typing import Any, Dict, List, Optional
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.nn import Dropout
 
 from seamless_communication.models.vocoder.hifigan import Generator
-
-
-class VariancePredictor(nn.Module):
-    def __init__(
-        self,
-        encoder_embed_dim: int,
-        var_pred_hidden_dim: int,
-        var_pred_kernel_size: int,
-        var_pred_dropout: float,
-    ):
-        super().__init__()
-        self.conv1 = nn.Sequential(
-            nn.Conv1d(
-                encoder_embed_dim,
-                var_pred_hidden_dim,
-                kernel_size=var_pred_kernel_size,
-                padding=(var_pred_kernel_size - 1) // 2,
-            ),
-            nn.ReLU(),
-        )
-        self.ln1 = nn.LayerNorm(var_pred_hidden_dim)
-        self.dropout_module = Dropout(p=var_pred_dropout)
-        self.conv2 = nn.Sequential(
-            nn.Conv1d(
-                var_pred_hidden_dim,
-                var_pred_hidden_dim,
-                kernel_size=var_pred_kernel_size,
-                padding=1,
-            ),
-            nn.ReLU(),
-        )
-        self.ln2 = nn.LayerNorm(var_pred_hidden_dim)
-        self.proj = nn.Linear(var_pred_hidden_dim, 1)
-
-    def forward(self, x: Tensor) -> Any:
-        # Input: B x T x C; Output: B x T
-        x = self.conv1(x.transpose(1, 2)).transpose(1, 2)
-        x = self.dropout_module(self.ln1(x))
-        x = self.conv2(x.transpose(1, 2)).transpose(1, 2)
-        x = self.dropout_module(self.ln2(x))
-        return self.proj(x).squeeze(dim=2)
+from seamless_communication.models.unity import VariancePredictor
 
 
 class CodeGenerator(Generator):
@@ -119,7 +78,7 @@ class CodeGenerator(Generator):
 
         if self.dur_predictor and dur_prediction:
             assert x.size(0) == 1, "only support single sample"
-            log_dur_pred = self.dur_predictor(x.transpose(1, 2))
+            log_dur_pred = self.dur_predictor(x.transpose(1, 2), None)
             dur_out = torch.clamp(
                 torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1
             )
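
A small sketch of the duration post-processing shown above: the predictor emits log-durations, which are exponentiated, shifted by one, rounded, and clamped to at least one frame per unit (the input values below are illustrative).

import torch

log_dur_pred = torch.tensor([[-0.2, 0.0, 1.3, 2.1]])
dur_out = torch.clamp(torch.round(torch.exp(log_dur_pred) - 1).long(), min=1)
assert dur_out.tolist() == [[1, 1, 3, 7]]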