
Introduce `wav2vec2_chunk` offline streaming-compatible speech encoder. (#42)

* Make seamless_communication (SC) compatible with the fairseq2 `TransformerEncoder` API change.

* Introduce wav2vec2_chunk offline streaming-compatible speech encoder.

* Make the depthwise convolution causal within the Conformer convolution module.

* Add an asset and arch for m4t_v2_s2t.

* Change variable names to be compatible with fairseq2 changes.

* Refactor the wav2vec2 encoder config and address review comments.

* Specify pos_encoder_type for wav2vec2_chunk_arch.
Kaushik Ram Sadagopan, 1 year ago
Commit c7f576a749

+ 10 - 0
src/seamless_communication/assets/cards/m4t_v2_s2t.yaml

@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: m4t_v2_s2t
+base: unity_nllb-100
+model_arch: m4t_v2_s2t
+checkpoint: "file://large_experiments/seamless/ust/elbayadm/multitasking_models/m4t_v2_s2t.pt"

+ 10 - 0
src/seamless_communication/assets/cards/s2t_chunk_conformer.yaml

@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: s2t_chunk_conformer
+base: unity_nllb-200
+model_arch: s2t_chunk_conformer
+checkpoint: "file://checkpoint/andyyuan/ckpt_from_rsc/w2vbert-2.0/S2T/avg_last_5_checkpoint.pt"

+ 15 - 5
src/seamless_communication/models/unity/adaptor_block.py

@@ -238,16 +238,22 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
 
     @finaloverride
     def forward(
-        self, seqs: Tensor, padding_mask: Optional[Tensor]
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[Tensor],
+        self_attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
-        seqs, padding_mask = self._forward_self_attn(seqs, padding_mask)
+        seqs, padding_mask = self._forward_self_attn(seqs, padding_mask, self_attn_mask)
 
         seqs = self._forward_ffn(seqs)
 
         return seqs, padding_mask
 
     def _forward_self_attn(
-        self, seqs: Tensor, padding_mask: Optional[Tensor]
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[Tensor],
+        self_attn_mask: Optional[Tensor],
     ) -> Tuple[Tensor, Optional[Tensor]]:
         residual = self.residual_layer_norm(seqs)
 
@@ -287,6 +293,7 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
             padding_mask,
             keys=seqs,
             values=seqs,
+            attn_mask=self_attn_mask,
             key_padding_mask=padding_mask,
         )
 
@@ -385,7 +392,10 @@ class UnitYConformerAdaptorLayer(TransformerEncoderLayer):
 
     @finaloverride
     def forward(
-        self, seqs: Tensor, padding_mask: Optional[Tensor]
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[Tensor],
+        self_attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         if self.layer_norm is not None:
             seqs = self.layer_norm(seqs)
@@ -405,7 +415,7 @@ class UnitYConformerAdaptorLayer(TransformerEncoderLayer):
             seqs, padding_mask, self.kernel_size, self.stride
         )
 
-        return self.block(seqs, padding_mask)  # type: ignore[no-any-return]
+        return self.block(seqs, padding_mask, self_attn_mask)  # type: ignore[no-any-return]
 
     def extra_repr(self) -> str:
         """:meta private:"""

+ 64 - 3
src/seamless_communication/models/unity/builder.py

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
+from torch.nn import Parameter
 from typing import Optional
 
 from fairseq2.models.conformer import ConformerBlock, ConformerConvolution
@@ -35,6 +36,11 @@ from seamless_communication.models.unity.t2u_builder import (
     UnitYT2UConfig,
     unity_t2u_archs,
 )
+from seamless_communication.models.wav2vec2_chunk import (
+    wav2vec2_chunk_archs,
+    Wav2Vec2ChunkEncoderBuilder,
+    Wav2Vec2ChunkEncoderConfig,
+)
 
 
 @dataclass
@@ -133,6 +139,54 @@ def _medium() -> UnitYConfig:
     )
 
 
+@unity_arch("m4t_v2_s2t")
+def _m4t_v2_s2t() -> UnitYConfig:
+    w2v2_chunk_encoder_config = wav2vec2_chunk_archs.get_config("600m")
+
+    mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
+
+    mt_model_config.vocabulary_size = 256102  # NLLB-100
+
+    mt_model_config.max_seq_len = 4096
+
+    return UnitYConfig(
+        model_dim=1024,
+        w2v2_encoder_config=w2v2_chunk_encoder_config,
+        mt_model_config=mt_model_config,
+        t2u_config=None,
+        use_text_encoder=False,
+        use_conformer_adaptor=False,
+        num_adaptor_layers=1,
+        adaptor_kernel_size=8,
+        adaptor_stride=8,
+        adaptor_layer_norm=True,
+        adaptor_dropout_p=0.0,
+    )
+
+
+@unity_arch("s2t_chunk_conformer")
+def _s2t_chunk_conformer() -> UnitYConfig:
+    w2v2_chunk_encoder_config = wav2vec2_chunk_archs.get_config("600m")
+
+    mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
+
+    mt_model_config.max_seq_len = 4096
+
+    return UnitYConfig(
+        model_dim=1024,
+        w2v2_encoder_config=w2v2_chunk_encoder_config,
+        mt_model_config=mt_model_config,
+        t2u_config=None,
+        use_text_encoder=False,
+        use_conformer_adaptor=False,
+        num_adaptor_layers=1,
+        adaptor_kernel_size=8,
+        adaptor_stride=8,
+        adaptor_layer_norm=True,
+        adaptor_dropout_p=0.0,
+    )
+
+
 @unity_arch("nar_multilingual")
 def _nar_multilingual() -> UnitYConfig:
     w2vbert_config = w2vbert_archs.get_config("600m")
@@ -239,6 +293,8 @@ class UnitYBuilder:
             text_encoder_frontend = None
             text_encoder = None
 
+        assert isinstance(text_embed.weight, Parameter)
+
         final_proj = TiedProjection(text_embed.weight, bias=None)
 
         if self.t2u_builder is None:
@@ -374,9 +430,14 @@ def create_unity_model(
     :param dtype:
         The data type of module parameters and buffers.
     """
-    w2v2_encoder_builder = Wav2Vec2EncoderBuilder(
-        config.w2v2_encoder_config, device=device, dtype=dtype
-    )
+    if isinstance(config.w2v2_encoder_config, Wav2Vec2ChunkEncoderConfig):
+        w2v2_encoder_builder: Wav2Vec2EncoderBuilder = Wav2Vec2ChunkEncoderBuilder(
+            config.w2v2_encoder_config, device=device, dtype=dtype
+        )
+    else:
+        w2v2_encoder_builder = Wav2Vec2EncoderBuilder(
+            config.w2v2_encoder_config, device=device, dtype=dtype
+        )
 
     if config.t2u_config is None:
         t2u_builder = None
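Because `Wav2Vec2ChunkEncoderConfig` subclasses `Wav2Vec2EncoderConfig`, the `isinstance` check above is what routes the new archs to the chunk builder. A sketch, assuming `unity_archs` and `create_unity_model` are exported as in the existing builder module:

```python
from seamless_communication.models.unity import create_unity_model, unity_archs

config = unity_archs.get_config("m4t_v2_s2t")

# config.w2v2_encoder_config is a Wav2Vec2ChunkEncoderConfig, so the dispatch above
# selects Wav2Vec2ChunkEncoderBuilder and the inner speech encoder becomes a
# ChunkTransformerEncoder. Weights are randomly initialized here; load the asset
# card through the loader to get trained parameters.
model = create_unity_model(config)
```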

+ 99 - 61
src/seamless_communication/models/unity/loader.py

@@ -46,31 +46,38 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
 
         state_dict = checkpoint["model"]
 
+        keys_to_delete = []
+
         # Use the built-in version attribute of `torch.Module`.
-        del state_dict["target_letter_decoder.version"]
-        del state_dict["target_letter_decoder.embed_positions._float_tensor"]
+        if config.t2u_config is None:
+            keys_to_delete.append("decoder.version")
+            keys_to_delete.append("decoder.embed_positions._float_tensor")
+        else:
+            keys_to_delete.append("target_letter_decoder.version")
+            keys_to_delete.append("target_letter_decoder.embed_positions._float_tensor")
 
         if config.use_text_encoder:
-            if "text_encoder.version" in state_dict:
-                del state_dict["text_encoder.version"]
-            if "text_encoder.embed_positions._float_tensor" in state_dict:
-                del state_dict["text_encoder.embed_positions._float_tensor"]
+            keys_to_delete.append("text_encoder.version")
+            keys_to_delete.append("text_encoder.embed_positions._float_tensor")
 
         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
-        del state_dict["encoder.w2v_encoder.w2v_model.mask_emb"]
+        keys_to_delete.append("encoder.w2v_encoder.w2v_model.mask_emb")
 
         # Delete AlignmentEncoder keys for inference.
         alignment_encoder_keys = [
             key for key in state_dict if key.startswith("decoder.alignment_encoder.")
         ]
-        for key in alignment_encoder_keys:
-            del state_dict[key]
+        keys_to_delete.extend(alignment_encoder_keys)
 
         # Delete character-level projection for inference.
-        for key in [
-            "decoder_target_letter_decoder.proj.weight",
-            "decoder_target_letter_decoder.proj.bias",
-        ]:
+        keys_to_delete.extend(
+            [
+                "decoder_target_letter_decoder.proj.weight",
+                "decoder_target_letter_decoder.proj.bias",
+            ]
+        )
+
+        for key in keys_to_delete:
             if key in state_dict:
                 del state_dict[key]
 
@@ -131,6 +138,7 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             r"^encoder\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.":                    r"speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
 
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.inner.layers.\1.conv.batch_norm.",
+            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.":     r"speech_encoder.inner.layers.\1.conv.layer_norm.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.inner.layers.\1.conv.depthwise_conv.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.inner.layers.\1.conv_layer_norm.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv1.",
@@ -143,6 +151,11 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.":          r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.":          r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.":        r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
+            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
+            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
+            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
+            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.":   r"speech_encoder.inner.layers.\1.self_attn.sdpa.rel_k_embed.",
+            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":          r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.":        r"speech_encoder.inner.layers.\1.self_attn.sdpa.r_proj.",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.u_bias",
             r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.v_bias",
@@ -166,54 +179,6 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             r"^text_encoder\.layers\.([0-9]+)\.fc2\.":                     r"text_encoder.layers.\1.ffn.output_proj.",
             r"^text_encoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"text_encoder.layers.\1.ffn_layer_norm.",
             r"^text_encoder\.layer_norm\.":                                r"text_encoder.layer_norm.",
-
-            # Text Decoder
-            r"^target_letter_decoder\.embed_tokens\.":                              r"text_decoder_frontend.embed.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"text_decoder.layers.\1.self_attn.output_proj.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn\.":               r"text_decoder.layers.\1.self_attn.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"text_decoder.layers.\1.self_attn_layer_norm.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn\.":            r"text_decoder.layers.\1.encoder_decoder_attn.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.fc1\.":                     r"text_decoder.layers.\1.ffn.inner_proj.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.fc2\.":                     r"text_decoder.layers.\1.ffn.output_proj.",
-            r"^target_letter_decoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"text_decoder.layers.\1.ffn_layer_norm.",
-            r"^target_letter_decoder\.layer_norm\.":                                r"text_decoder.layer_norm.",
-            r"^target_letter_decoder\.output_projection\.":                         r"final_proj.",
-
-            # T2U Encoder
-            r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.encoder.layers.\1.self_attn.output_proj.",
-            r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.encoder.layers.\1.self_attn.",
-            r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.encoder.layers.\1.self_attn_layer_norm.",
-            r"^synthesizer_encoder\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.encoder.layers.\1.ffn.inner_proj.",
-            r"^synthesizer_encoder\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.encoder.layers.\1.ffn.output_proj.",
-            r"^synthesizer_encoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.encoder.layers.\1.ffn_layer_norm.",
-            r"^synthesizer_encoder\.layer_norm\.":                                r"t2u_model.encoder.layer_norm.",
-
-            # T2U Decoder frontend
-            r"^decoder\.embed_tokens_text\.":                           r"t2u_model.decoder_frontend.embed_char.",
-            r"^decoder\.embed_tokens_unit\.":                           r"t2u_model.decoder_frontend.embed.",
-            r"^decoder\.embed_tokens\.":                                r"t2u_model.decoder_frontend.embed.",
-            r"^decoder\.var_adaptor\.duration_predictor\.":             r"t2u_model.decoder_frontend.variance_adaptor.duration_predictor.",
-            r"^decoder\.dec_pos_emb_alpha":                             r"t2u_model.decoder_frontend.pos_emb_alpha",
-            r"^decoder\.dec_pos_emb_alpha_char":                        r"t2u_model.decoder_frontend.pos_emb_alpha_char",
-
-            # T2U Decoder
-            r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.decoder.layers.\1.self_attn.output_proj.",
-            r"^decoder\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.decoder.layers.\1.self_attn.",
-            r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
-            r"^decoder\.layers\.([0-9]+)\.layer_norm\.":              r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
-            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"t2u_model.decoder.layers.\1.encoder_decoder_attn.output_proj.",
-            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":            r"t2u_model.decoder.layers.\1.encoder_decoder_attn.",
-            r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn_layer_norm.",
-            r"^decoder\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.decoder.layers.\1.ffn.inner_proj.",
-            r"^decoder\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.decoder.layers.\1.ffn.output_proj.",
-            r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.decoder.layers.\1.ffn_layer_norm.",
-            r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"t2u_model.decoder.layers.\1.conv1d.conv1.",
-            r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"t2u_model.decoder.layers.\1.conv1d.conv2.",
-            r"^decoder\.layers\.([0-9]+)\.ffn\.layer_norm\.":         r"t2u_model.decoder.layers.\1.conv1d_layer_norm.",
-            r"^decoder\.layer_norm\.":                                r"t2u_model.decoder.layer_norm.",
-            r"^decoder\.output_projection\.":                         r"t2u_model.final_proj.",
             # fmt: on
         }
 
@@ -269,6 +234,79 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
                     r"^encoder\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":     r"speech_encoder.adaptor_layers.\1.ffn_layer_norm.",
                 }
             )
+
+        # S2T model.
+        if config.t2u_config is None:
+            key_map.update(
+                {
+                    # Text Decoder
+                    r"^decoder\.embed_tokens\.":                              r"text_decoder_frontend.embed.",
+                    r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"text_decoder.layers.\1.self_attn.output_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.self_attn\.":               r"text_decoder.layers.\1.self_attn.",
+                    r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"text_decoder.layers.\1.self_attn_layer_norm.",
+                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":            r"text_decoder.layers.\1.encoder_decoder_attn.",
+                    r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+                    r"^decoder\.layers\.([0-9]+)\.fc1\.":                     r"text_decoder.layers.\1.ffn.inner_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.fc2\.":                     r"text_decoder.layers.\1.ffn.output_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"text_decoder.layers.\1.ffn_layer_norm.",
+                    r"^decoder\.layer_norm\.":                                r"text_decoder.layer_norm.",
+                    r"^decoder\.output_projection\.":                         r"final_proj.",
+                }
+            )
+        # S2T + T2U model.
+        else:
+            key_map.update(
+                {
+                    # Text Decoder
+                    r"^target_letter_decoder\.embed_tokens\.":                              r"text_decoder_frontend.embed.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"text_decoder.layers.\1.self_attn.output_proj.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn\.":               r"text_decoder.layers.\1.self_attn.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"text_decoder.layers.\1.self_attn_layer_norm.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn\.":            r"text_decoder.layers.\1.encoder_decoder_attn.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.fc1\.":                     r"text_decoder.layers.\1.ffn.inner_proj.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.fc2\.":                     r"text_decoder.layers.\1.ffn.output_proj.",
+                    r"^target_letter_decoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"text_decoder.layers.\1.ffn_layer_norm.",
+                    r"^target_letter_decoder\.layer_norm\.":                                r"text_decoder.layer_norm.",
+                    r"^target_letter_decoder\.output_projection\.":                         r"final_proj.",
+
+                    # T2U Encoder
+                    r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.encoder.layers.\1.self_attn.output_proj.",
+                    r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.encoder.layers.\1.self_attn.",
+                    r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.encoder.layers.\1.self_attn_layer_norm.",
+                    r"^synthesizer_encoder\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.encoder.layers.\1.ffn.inner_proj.",
+                    r"^synthesizer_encoder\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.encoder.layers.\1.ffn.output_proj.",
+                    r"^synthesizer_encoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.encoder.layers.\1.ffn_layer_norm.",
+                    r"^synthesizer_encoder\.layer_norm\.":                                r"t2u_model.encoder.layer_norm.",
+
+                    # T2U Decoder frontend
+                    r"^decoder\.embed_tokens_text\.":                           r"t2u_model.decoder_frontend.embed_char.",
+                    r"^decoder\.embed_tokens_unit\.":                           r"t2u_model.decoder_frontend.embed.",
+                    r"^decoder\.embed_tokens\.":                                r"t2u_model.decoder_frontend.embed.",
+                    r"^decoder\.var_adaptor\.duration_predictor\.":             r"t2u_model.decoder_frontend.variance_adaptor.duration_predictor.",
+                    r"^decoder\.dec_pos_emb_alpha":                             r"t2u_model.decoder_frontend.pos_emb_alpha",
+                    r"^decoder\.dec_pos_emb_alpha_char":                        r"t2u_model.decoder_frontend.pos_emb_alpha_char",
+
+                    # T2U Decoder
+                    r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.decoder.layers.\1.self_attn.output_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.decoder.layers.\1.self_attn.",
+                    r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
+                    r"^decoder\.layers\.([0-9]+)\.layer_norm\.":              r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
+                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"t2u_model.decoder.layers.\1.encoder_decoder_attn.output_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":            r"t2u_model.decoder.layers.\1.encoder_decoder_attn.",
+                    r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+                    r"^decoder\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.decoder.layers.\1.ffn.inner_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.decoder.layers.\1.ffn.output_proj.",
+                    r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.decoder.layers.\1.ffn_layer_norm.",
+                    r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"t2u_model.decoder.layers.\1.conv1d.conv1.",
+                    r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"t2u_model.decoder.layers.\1.conv1d.conv2.",
+                    r"^decoder\.layers\.([0-9]+)\.ffn\.layer_norm\.":         r"t2u_model.decoder.layers.\1.conv1d_layer_norm.",
+                    r"^decoder\.layer_norm\.":                                r"t2u_model.decoder.layer_norm.",
+                    r"^decoder\.output_projection\.":                         r"t2u_model.final_proj.",
+                }
+            )
         # fmt: on
 
         return key_map
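The remapping mechanics are unchanged; only which patterns apply is now conditioned on `t2u_config`. A self-contained sketch of how one fairseq key is converted in the S2T-only branch (plain `re`, standing in for fairseq2's checkpoint-upgrade utilities):

```python
import re

# Two representative patterns from the S2T-only (t2u_config is None) branch above.
key_map = {
    r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"text_decoder.layers.\1.self_attn.output_proj.",
    r"^decoder\.output_projection\.":                      r"final_proj.",
}


def convert_key(key: str) -> str:
    for pattern, replacement in key_map.items():
        new_key, num_subs = re.subn(pattern, replacement, key)
        if num_subs:
            return new_key
    return key


assert convert_key("decoder.layers.3.self_attn.out_proj.weight") == (
    "text_decoder.layers.3.self_attn.output_proj.weight"
)
assert convert_key("decoder.output_projection.weight") == "final_proj.weight"
```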

+ 3 - 0
src/seamless_communication/models/unity/t2u_builder.py

@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 from dataclasses import dataclass
+from torch.nn import Parameter
 from typing import Literal, Optional, Union
 
 from fairseq2.assets import download_manager
@@ -241,6 +242,8 @@ class UnitYT2UBuilder:
 
         decoder = self.build_decoder()
 
+        assert isinstance(embed_unit.weight, Parameter)
+
         final_proj = TiedProjection(embed_unit.weight, bias=None)
 
         if self.config.nar_decoder_config is None:

+ 15 - 0
src/seamless_communication/models/wav2vec2_chunk/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from seamless_communication.models.wav2vec2_chunk.builder import (
+    wav2vec2_chunk_archs as wav2vec2_chunk_archs,
+)
+from seamless_communication.models.wav2vec2_chunk.builder import (
+    Wav2Vec2ChunkEncoderBuilder as Wav2Vec2ChunkEncoderBuilder,
+)
+from seamless_communication.models.wav2vec2_chunk.builder import (
+    Wav2Vec2ChunkEncoderConfig as Wav2Vec2ChunkEncoderConfig,
+)

+ 158 - 0
src/seamless_communication/models/wav2vec2_chunk/builder.py

@@ -0,0 +1,158 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, asdict
+from typing import Literal, Optional
+
+from fairseq2.models.conformer import ConformerConvolution
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.wav2vec2.builder import (
+    Wav2Vec2EncoderBuilder,
+    Wav2Vec2EncoderConfig,
+)
+from fairseq2.models.w2vbert import w2vbert_archs
+from fairseq2.nn.transformer import SDPA, ShawRelativePositionSDPA
+from fairseq2.typing import DataType, Device
+
+from seamless_communication.models.wav2vec2_chunk.encoder import ChunkTransformerEncoder
+
+
+@dataclass
+class ShawRelativePositionSDPAConfig:
+    """Holds the configuration of the :class:ShawRelativePositionSDPA module."""
+
+    max_left_rel_pos: int
+    """The left clipping value for relative positions."""
+
+    max_right_rel_pos: Optional[int]
+    """The right clipping value for relative positions."""
+
+    use_rel_pos_values: bool = False
+    """If True, also uses relative position values to compute relative attention."""
+
+
+@dataclass
+class Wav2Vec2ChunkEncoderConfig(Wav2Vec2EncoderConfig):
+    """Holds the configuration of a wav2vec2 chunk encoder."""
+
+    causal_depthwise_conv: bool
+    """If True, uses a causal depthwise convolution similar to that described in
+    Section 2.1 of :cite:t:`https://doi.org/10.48550/arxiv.1609.03499`."""
+
+    conv_norm_type: Literal["batch_norm", "layer_norm"]
+    """The type of normalization to use in the Conformer convolution module."""
+
+    shaw_rel_pos_sdpa_config: Optional[ShawRelativePositionSDPAConfig]
+    """The parameters for ShawRelativePositionSDPA."""
+
+    chunk_size: int
+    """The size of each chunk."""
+
+    left_chunk_num: int
+    """Number of chunks on the left up to which lookahead is allowed."""
+
+    right_chunk_num: int
+    """Number of chunks on the right up to which lookahead is allowed."""
+
+
+wav2vec2_chunk_archs = ArchitectureRegistry[Wav2Vec2ChunkEncoderConfig](
+    "wav2vec2_chunk"
+)
+
+wav2vec2_chunk_arch = wav2vec2_chunk_archs.marker
+
+
+@wav2vec2_chunk_arch("600m")
+def _encoder_600m() -> Wav2Vec2ChunkEncoderConfig:
+    w2vbert_config = w2vbert_archs.get_config("600m")
+    w2v2_encoder_config = w2vbert_config.w2v2_config.encoder_config
+    sdpa_config = ShawRelativePositionSDPAConfig(
+        max_left_rel_pos=64,
+        max_right_rel_pos=8,
+        use_rel_pos_values=False,
+    )
+    w2v2_chunk_encoder_config = Wav2Vec2ChunkEncoderConfig(
+        **asdict(w2v2_encoder_config),
+        causal_depthwise_conv=True,
+        conv_norm_type="layer_norm",
+        shaw_rel_pos_sdpa_config=sdpa_config,
+        chunk_size=10000,
+        left_chunk_num=128,
+        right_chunk_num=0,
+    )
+    w2v2_chunk_encoder_config.pos_encoder_type = "shaw_relative"
+    return w2v2_chunk_encoder_config
+
+
+class Wav2Vec2ChunkEncoderBuilder(Wav2Vec2EncoderBuilder):
+    config: Wav2Vec2ChunkEncoderConfig
+
+    def __init__(
+        self,
+        config: Wav2Vec2ChunkEncoderConfig,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        super().__init__(config, device=device, dtype=dtype)
+
+        assert (
+            self.config.use_conformer
+        ), "Currently we only support the ChunkConformerBlock."
+
+    def build_encoder(self) -> ChunkTransformerEncoder:
+        """Build a Transformer encoder."""
+        num_layers = self.config.num_encoder_layers
+
+        layers = [self.build_encoder_layer() for _ in range(num_layers)]
+
+        return ChunkTransformerEncoder(
+            layers,
+            self.config.chunk_size,
+            self.config.left_chunk_num,
+            self.config.right_chunk_num,
+            dropout_p=self.config.dropout_p,
+            layer_drop_p=self.config.layer_drop_p,
+        )
+
+    def build_sdpa(self) -> SDPA:
+        if self.config.pos_encoder_type == "shaw_relative":
+            if self.config.shaw_rel_pos_sdpa_config is None:
+                raise ValueError(
+                    "`shaw_rel_pos_sdpa_config` must be specified when `pos_encoder_type` is 'shaw_relative'."
+                )
+
+            sdpa_config = self.config.shaw_rel_pos_sdpa_config
+            return ShawRelativePositionSDPA(
+                self.config.model_dim,
+                self.config.num_encoder_attn_heads,
+                sdpa_config.max_left_rel_pos,
+                max_right_rel_pos=sdpa_config.max_right_rel_pos,
+                use_rel_pos_values=sdpa_config.use_rel_pos_values,
+                attn_dropout_p=self.config.attn_dropout_p,
+                device=self.device,
+                dtype=self.dtype,
+            )
+
+        return super().build_sdpa()
+
+    def build_conformer_conv(self) -> ConformerConvolution:
+        return ConformerConvolution(
+            self.config.model_dim,
+            self.config.depthwise_conv_kernel_size,
+            causal_depthwise_conv=self.config.causal_depthwise_conv,
+            norm_type=self.config.conv_norm_type,
+            device=self.device,
+            dtype=self.dtype,
+        )
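A quick way to inspect what the `600m` chunk arch overlays on top of the w2v-BERT 600m encoder settings (a sketch; the values come directly from the registration above):

```python
from seamless_communication.models.wav2vec2_chunk import wav2vec2_chunk_archs

config = wav2vec2_chunk_archs.get_config("600m")

print(config.pos_encoder_type)       # "shaw_relative"
print(config.causal_depthwise_conv)  # True
print(config.conv_norm_type)         # "layer_norm"
print(config.chunk_size, config.left_chunk_num, config.right_chunk_num)  # 10000 128 0
```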

+ 77 - 0
src/seamless_communication/models/wav2vec2_chunk/chunk_attention_mask.py

@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import Tensor
+
+from fairseq2.nn.utils.mask import to_float_mask
+
+
+class ChunkAttentionMaskGenerator:
+    """Generates a chunk attention mask for self attention.
+
+    .. note::
+        This class follows the :class:`AttentionMaskGenerator` protocol.
+    """
+
+    def __init__(
+        self, chunk_size: int, left_chunk_num: int, right_chunk_num: int
+    ) -> None:
+        self.chunk_size = chunk_size
+        self.left_chunk_num = left_chunk_num
+        self.right_chunk_num = right_chunk_num
+
+        if self.right_chunk_num != 0:
+            raise ValueError("We currently only support `right_chunk_num` == 0.")
+
+    def __call__(self, seqs: Tensor) -> Tensor:
+        """
+        :param seqs:
+            The sequences for which to generate the mask. *Shape:*
+            :math:`(N,S,M)`, where :math:`N` is the batch size, :math:`S` is the
+            sequence length, and :math:`M` is the dimensionality of the model.
+
+        :returns:
+            A chunk attention float mask for ``seqs``.
+            *Shape:* :math:`(S,S)`, where :math:`S` is the
+            sequence length.
+        """
+
+        seq_len = seqs.size(1)
+
+        chunk_indices = torch.div(
+            torch.arange(seq_len, device=seqs.device), self.chunk_size
+        ).long()
+
+        start_indices = (
+            ((chunk_indices - self.left_chunk_num).clamp_(min=0) * self.chunk_size).to(
+                seqs.device
+            )
+            if self.left_chunk_num >= 0
+            else torch.full_like(chunk_indices, 0)
+        )
+        start_indices = start_indices.unsqueeze(1).expand(-1, seq_len)
+
+        end_indices = (
+            ((chunk_indices + 1) * self.chunk_size).clamp_(max=seq_len).to(seqs.device)
+        )
+
+        end_indices = end_indices.unsqueeze(1).expand(-1, seq_len)
+
+        indices = (
+            torch.arange(seq_len, device=seqs.device).unsqueeze(0).expand(seq_len, -1)
+        )
+
+        bool_mask = (indices < start_indices) | (indices >= end_indices)
+
+        mask = to_float_mask(bool_mask, seqs.dtype)
+
+        mask = mask[:seq_len, :seq_len]
+
+        return mask
+
+    def __repr__(self) -> str:
+        return "ChunkAttentionMaskGenerator"

+ 110 - 0
src/seamless_communication/models/wav2vec2_chunk/encoder.py

@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Iterable, Optional, Tuple, final
+
+from torch import Tensor
+from torch.nn import Dropout
+
+from fairseq2.nn.utils.module import check_model_dim
+from fairseq2.nn.module_list import ModuleList
+from fairseq2.nn.normalization import LayerNorm
+
+from fairseq2.nn.transformer import (
+    AttentionMaskGenerator,
+    EncoderLayerOutputHook,
+    TransformerEncoder,
+    TransformerEncoderLayer,
+)
+
+from seamless_communication.models.wav2vec2_chunk.chunk_attention_mask import (
+    ChunkAttentionMaskGenerator,
+)
+
+from fairseq2.typing import finaloverride
+
+
+@final
+class ChunkTransformerEncoder(TransformerEncoder):
+    """Represents a Chunk Transformer encoder."""
+
+    preliminary_dropout: Optional[Dropout]
+    self_attn_mask_gen: AttentionMaskGenerator
+    layers: ModuleList
+    layer_norm: Optional[LayerNorm]
+
+    def __init__(
+        self,
+        layers: Iterable[TransformerEncoderLayer],
+        chunk_size: int,
+        left_chunk_num: int,
+        right_chunk_num: int,
+        *,
+        dropout_p: float = 0.0,
+        layer_drop_p: float = 0.0,
+    ) -> None:
+        """
+        :param layers:
+            The encoder layers.
+        :param chunk_size:
+            Size of each chunk.
+        :param left_chunk_num:
+            Number of chunks to the left of the current chunk that each position may attend to (look-back).
+        :param right_chunk_num:
+            Number of chunks on the right up to which lookahead is allowed.
+        :param dropout_p:
+            Used in the preliminary dropout.
+        :param layer_drop_p:
+            If greater than zero, applies LayerDrop to the encoder layers as
+            described in :cite:t:`https://doi.org/10.48550/arxiv.1909.11556`.
+        """
+        layer_list = ModuleList(layers, drop_p=layer_drop_p)
+        if not layer_list:
+            raise ValueError("`layers` must be non-empty.")
+
+        model_dim = layer_list[0].model_dim
+
+        super().__init__(model_dim)
+
+        if dropout_p > 0.0:
+            self.preliminary_dropout = Dropout(dropout_p)
+        else:
+            self.register_module("preliminary_dropout", None)
+
+        self.self_attn_mask_gen = ChunkAttentionMaskGenerator(
+            chunk_size * 2, left_chunk_num, right_chunk_num
+        )
+
+        self.layers = layer_list
+
+        check_model_dim(self)
+
+    @finaloverride
+    def forward(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[Tensor],
+        *,
+        layer_output_hook: Optional[EncoderLayerOutputHook] = None,
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        if layer_output_hook is not None and self.layers.drop_p > 0.0:
+            raise ValueError("`layer_hook` must be `None` when LayerDrop is enabled.")
+
+        if self.preliminary_dropout is not None:
+            seqs = self.preliminary_dropout(seqs)
+
+        self_attn_mask = self.self_attn_mask_gen(seqs)
+
+        num_layers = len(self.layers)
+
+        for layer_idx, layer in enumerate(self.layers.drop_iter()):
+            seqs, padding_mask = layer(seqs, padding_mask, self_attn_mask)
+
+            if layer_output_hook is not None:
+                if not layer_output_hook(layer_idx, seqs, padding_mask, num_layers):
+                    break
+
+        return seqs, padding_mask
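An end-to-end sketch wiring the builder and the chunk encoder together (randomly initialized weights, CPU; real inputs come from the wav2vec2 frontend):

```python
import torch

from seamless_communication.models.wav2vec2_chunk import (
    Wav2Vec2ChunkEncoderBuilder,
    wav2vec2_chunk_archs,
)

config = wav2vec2_chunk_archs.get_config("600m")
encoder = Wav2Vec2ChunkEncoderBuilder(config, dtype=torch.float32).build_encoder()

# The encoder generates one chunk mask per forward call and threads it through
# every Conformer layer as `self_attn_mask`.
seqs = torch.randn(2, 50, config.model_dim)  # (batch, frames, model_dim)
with torch.inference_mode():
    out, out_padding_mask = encoder(seqs, None)

print(out.shape)  # torch.Size([2, 50, 1024])
```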