
Adds asset for the `m4t_v2_x2t` model and makes the translator compatible with x2t models. (#47)

* Make the translator compatible with x2t models; add an asset card for m4t_v2_x2t.

* Get the unit_extraction pipeline working and make it compatible with fairseq2 API changes.

* Address nit comments.

---------

Co-authored-by: Can Balioglu <cbalioglu@users.noreply.github.com>
Kaushik Ram Sadagopan 1 year ago
parent
commit
cb9096fed3

+ 6 - 6
scripts/m4t/finetune/trainer.py

@@ -91,14 +91,14 @@ class UnitYFinetuneWrapper(nn.Module):
         dummy_context = contextmanager(lambda: iter([None]))()
         with torch.no_grad() if self.freeze_s2t else dummy_context:  # type:ignore
             assert batch.speech_to_text.src_tokens is not None
-            seqs=batch.speech_to_text.src_tokens.to(self.device)
-            seq_lens=batch.speech_to_text.src_lengths.to(self.device)
+            seqs = batch.speech_to_text.src_tokens.to(self.device)
+            seq_lens = batch.speech_to_text.src_lengths.to(self.device)
             speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
                 seqs=seqs, padding_mask=PaddingMask(seq_lens, seqs.size(1))
             )
             assert batch.speech_to_text.prev_output_tokens is not None
-            seqs=batch.speech_to_text.prev_output_tokens.to(self.device)
-            seq_lens=batch.speech_to_text.target_lengths.to(self.device)
+            seqs = batch.speech_to_text.prev_output_tokens.to(self.device)
+            seq_lens = batch.speech_to_text.target_lengths.to(self.device)
             text_decoder_out, text_decoder_padding_mask = self.model.decode(
                 seqs=seqs,
                 padding_mask=PaddingMask(seq_lens, seqs.size(1)),
@@ -117,8 +117,8 @@ class UnitYFinetuneWrapper(nn.Module):
                 text_decoder_output=text_decoder_out,
                 text_decoder_padding_mask=text_decoder_padding_mask,
             )
-            seqs=batch.text_to_units.prev_output_tokens.to(self.device)
-            seq_lens=batch.text_to_units.target_lengths.to(self.device)
+            seqs = batch.text_to_units.prev_output_tokens.to(self.device)
+            seq_lens = batch.text_to_units.target_lengths.to(self.device)
             unit_decoder_out, _ = self.model.t2u_model.decode(
                 seqs=seqs,
                 padding_mask=PaddingMask(seq_lens, seqs.size(1)),
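
For reference, a minimal sketch of the fairseq2 `PaddingMask` construction these hunks rely on; the shapes are toy values, but the positional call matches the code above.

```python
import torch

from fairseq2.nn.padding import PaddingMask

seqs = torch.zeros(2, 10, 80)     # (batch, padded_len, feat), toy shapes
seq_lens = torch.tensor([10, 7])  # true length of each sequence in the batch
padding_mask = PaddingMask(seq_lens, seqs.size(1))
```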

+ 10 - 0
src/seamless_communication/assets/cards/m4t_v2_x2t.yaml

@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: m4t_v2_x2t
+base: unity_nllb-100
+model_arch: m4t_v2_x2t
+checkpoint: "file://large_experiments/seamless/ust/elbayadm/multitasking_models/m4t_v2_x2t.pt"
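
A hedged loading sketch: once this card is registered, the model should be loadable by name like any other asset. `load_unity_model` is the same loader the translator uses below; the device and dtype are illustrative, and the card's checkpoint URI points to an internal path that must be reachable.

```python
import torch

from seamless_communication.models.unity import load_unity_model

# "m4t_v2_x2t" resolves to the asset card above. The checkpoint URI in the
# card is an internal path; loading only works where that file exists.
model = load_unity_model("m4t_v2_x2t", device=torch.device("cpu"), dtype=torch.float32)
assert model.t2u_model is None  # x2t models carry no T2U sub-model
```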

+ 6 - 3
src/seamless_communication/models/inference/translator.py

@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -15,7 +15,7 @@ from fairseq2.data.text.text_tokenizer import TextTokenizer
 from fairseq2.data.typing import StringLike
 from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
 from fairseq2.memory import MemoryBlock
-from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
+from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from torch import Tensor
 from enum import Enum, auto
@@ -65,7 +65,10 @@ class Translator(nn.Module):
             load_unity_model, model_name_or_card, device, dtype
         )
         self.text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
-        self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
+        if self.model.t2u_model is not None:
+            self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
+        else:
+            self.unit_tokenizer = None
         self.device = device
         self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
         self.convert_to_fbank = WaveformToFbankConverter(
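
A hedged usage sketch of the new behavior: with an x2t checkpoint the translator now loads without a unit tokenizer and serves text-only tasks. The constructor and `predict` signatures below are assumptions based on this repo's public API at the time; in particular, passing `None` for the vocoder and the `(text, wav, sr)` return shape are not confirmed by this diff.

```python
import torch

from seamless_communication.models.inference.translator import Translator

# Hypothetical text-only use of an x2t model. `None` for the vocoder and the
# return shape of predict() are assumptions, not confirmed by this change.
translator = Translator("m4t_v2_x2t", None, torch.device("cuda:0"), torch.float16)
text, wav, sr = translator.predict("Hello, world!", "t2tt", tgt_lang="fra", src_lang="eng")
print(text)  # wav/sr should be empty for a text-only task
```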

+ 16 - 11
src/seamless_communication/models/unit_extraction/unit_extraction.py

@@ -7,17 +7,20 @@
 from typing import Union
 from pathlib import Path
 import torch
+import torch.nn.functional as F
 
 from itertools import groupby
-from fairseq2.typing import DataType, Device
 from torch import Tensor, nn
-from fairseq2.data.audio import AudioDecoder
+
+from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater
-import torch.nn.functional as F
+from fairseq2.data.audio import AudioDecoder
 from fairseq2.memory import MemoryBlock
-from fairseq2.assets.card import AssetCard
+from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.models.sequence import SequenceBatch
 from fairseq2.models.wav2vec2 import Wav2Vec2Model
+from fairseq2.typing import DataType, Device
+
 from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
     load_wav2vec2_model,
     Wav2Vec2LayerOutputModel,
@@ -38,10 +41,11 @@ class UnitExtractor(nn.Module):
         dtype: DataType = torch.float32,
     ):
         super().__init__()
-        self.wav2vec2_model: Wav2Vec2Model = Translator.load_model_for_inference(
+        wav2vec2_model = Translator.load_model_for_inference(
             load_wav2vec2_model, model_name_or_card, device, dtype
         )
-        self.model = Wav2Vec2LayerOutputModel(self.wav2vec2_model)
+        assert isinstance(wav2vec2_model, Wav2Vec2Model)
+        self.model = Wav2Vec2LayerOutputModel(wav2vec2_model)
         self.device = device
         self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
         self.collate = Collater(pad_idx=2, pad_to_multiple=2)
@@ -65,10 +69,10 @@ class UnitExtractor(nn.Module):
                 "format": -1,
             }
         src = self.collate(decoded_audio)["waveform"]
-        x = src["seqs"]
-        x = x.view(1, -1)
-        x = F.layer_norm(x, x.shape)
-        batch = SequenceBatch(seqs=x, seq_lens=src["seq_lens"])
+        seqs, padding_mask = get_seqs_and_padding_mask(src)
+        seqs = seqs.view(1, -1)
+        seqs = F.layer_norm(seqs, seqs.shape)
+        batch = SequenceBatch(seqs=seqs, padding_mask=padding_mask)
         features = self.model(batch, out_layer_idx).squeeze(0)
         units = self.kmeans_model(features)
         return units
@@ -85,8 +89,9 @@ class UnitExtractor(nn.Module):
 
         reduced_units = reduce_list(units.cpu().tolist())
 
-        vocoder: Vocoder = Translator.load_model_for_inference(
+        vocoder = Translator.load_model_for_inference(
             load_vocoder_model, vocoder_name, device, torch.float32
         )
+        assert isinstance(vocoder, Vocoder)
         wav = vocoder(reduced_units, src_lang, spkr=-1, dur_prediction=True)
         return wav
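
A minimal sketch of the fairseq2 padding API the pipeline now uses: collated sequence data is split into a dense tensor plus an optional `PaddingMask`, and `SequenceBatch` carries the mask rather than raw lengths. The `"is_ragged"` field name is an assumption about fairseq2's `SequenceData` dict.

```python
import torch

from fairseq2.models.sequence import SequenceBatch
from fairseq2.nn.padding import get_seqs_and_padding_mask

# Stand-in for the Collater output; the "is_ragged" key is assumed from
# fairseq2's SequenceData. Real audio would come from the decode/collate steps.
data = {
    "seqs": torch.randn(2, 100),
    "seq_lens": torch.tensor([100, 60]),
    "is_ragged": True,
}
seqs, padding_mask = get_seqs_and_padding_mask(data)
batch = SequenceBatch(seqs=seqs, padding_mask=padding_mask)
```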

+ 1 - 1
src/seamless_communication/models/unit_extraction/wav2vec2_layer_output.py

@@ -130,7 +130,7 @@ class Wav2Vec2LayerOutputModel(nn.Module):
 
             return True
 
-        _, _ = self.encoder(seqs, padding_mask, layer_output_hook)
+        _, _ = self.encoder(seqs, padding_mask, layer_output_hook=layer_output_hook)
 
         assert w2v2_layer_output is not None
         return w2v2_layer_output
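
For context, a hedged sketch of the hook passed above. The `(layer_idx, layer_output, padding_mask, num_layers) -> bool` shape is assumed from fairseq2's `EncoderLayerOutputHook`; returning `False` is taken to stop encoding once the target layer is reached.

```python
from typing import Optional

from torch import Tensor
from fairseq2.nn.padding import PaddingMask

out_layer_idx = 34               # illustrative target layer
captured: Optional[Tensor] = None

def layer_output_hook(
    layer_idx: int,
    layer_output: Tensor,
    padding_mask: Optional[PaddingMask],
    num_layers: int,
) -> bool:
    """Capture the requested layer's output, then halt deeper layers."""
    global captured
    if layer_idx == out_layer_idx:
        captured = layer_output
        return False  # assumed to short-circuit the encoder loop
    return True
```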

+ 12 - 4
src/seamless_communication/models/unity/adaptor_block.py

@@ -72,7 +72,9 @@ class UnitYEncoderAdaptor(TransformerEncoder):
         self.inner = inner
 
         if inner_layer_norm:
-            self.inner_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
+            self.inner_layer_norm = layer_norm_factory(
+                model_dim, device=device, dtype=dtype
+            )
         else:
             self.register_module("inner_layer_norm", None)
 
@@ -191,7 +193,9 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
         self.kernel_size = kernel_size
         self.stride = stride
 
-        self.residual_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
+        self.residual_layer_norm = layer_norm_factory(
+            model_dim, device=device, dtype=dtype
+        )
 
         self.residual_conv = Conv1d(
             model_dim,
@@ -205,7 +209,9 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
 
         self.residual_activation = GLU(dim=1)
 
-        self.self_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
+        self.self_attn_layer_norm = layer_norm_factory(
+            model_dim, device=device, dtype=dtype
+        )
 
         self.self_attn_conv = Conv1d(
             model_dim,
@@ -373,7 +379,9 @@ class UnitYConformerAdaptorLayer(TransformerEncoderLayer):
         self.stride = stride
 
         if layer_norm:
-            self.layer_norm = layer_norm_factory(self.model_dim, device=device, dtype=dtype)
+            self.layer_norm = layer_norm_factory(
+                self.model_dim, device=device, dtype=dtype
+            )
         else:
             self.register_module("layer_norm", None)
 

+ 25 - 0
src/seamless_communication/models/unity/builder.py

@@ -139,6 +139,31 @@ def _medium() -> UnitYConfig:
     )
 
 
+@unity_arch("m4t_v2_x2t")
+def _m4t_v2_x2t() -> UnitYConfig:
+    w2v2_chunk_encoder_config = wav2vec2_chunk_archs.get_config("600m")
+
+    mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
+
+    mt_model_config.vocabulary_size = 256102  # NLLB-100
+
+    mt_model_config.max_seq_len = 4096
+
+    return UnitYConfig(
+        model_dim=1024,
+        w2v2_encoder_config=w2v2_chunk_encoder_config,
+        mt_model_config=mt_model_config,
+        t2u_config=None,
+        use_text_encoder=True,
+        use_conformer_adaptor=False,
+        num_adaptor_layers=1,
+        adaptor_kernel_size=8,
+        adaptor_stride=8,
+        adaptor_layer_norm=True,
+        adaptor_dropout_p=0.0,
+    )
+
+
 @unity_arch("m4t_v2_s2t")
 def _m4t_v2_s2t() -> UnitYConfig:
     w2v2_chunk_encoder_config = wav2vec2_chunk_archs.get_config("600m")
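
A hedged sketch of how the new arch is consumed: `@unity_arch("m4t_v2_x2t")` is assumed to register the config factory in this module's `unity_archs` registry, from which it can be fetched by name.

```python
from seamless_communication.models.unity.builder import unity_archs

# Fetch the config registered by @unity_arch("m4t_v2_x2t") above.
config = unity_archs.get_config("m4t_v2_x2t")
assert config.t2u_config is None  # x2t: no T2U sub-model
assert config.use_text_encoder    # text encoder enabled, so T2T works
```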

+ 11 - 9
src/seamless_communication/models/unity/generator.py

@@ -82,13 +82,6 @@ class UnitYGenerator:
         :param unit_generator_opts:
             The options to pass to the underlying unit :class:`Seq2SeqGenerator`.
         """
-        if model.t2u_model is None:
-            raise ValueError(
-                "`model` does not have a T2U sub-model. "
-                "For text generation only, "
-                "use `SequenceToTextGenerator` instead."
-            )
-
         model.eval()
 
         self.model = model
@@ -126,6 +119,11 @@ class UnitYGenerator:
         self.unit_decoder = None
         # Set up unit generator.
         if unit_tokenizer is not None:
+            if model.t2u_model is None:
+                raise ValueError(
+                    "`unit_tokenizer` is not None, but `model` has no T2U sub-model."
+                )
+
             self.unit_decoder = unit_tokenizer.create_decoder()
 
             unit_encoder = unit_tokenizer.create_encoder(
@@ -175,9 +173,13 @@ class UnitYGenerator:
         """
 
         if input_modality == "speech":
-            text_output = self.s2t_generator.generate_ex(source_seqs, source_padding_mask)
+            text_output = self.s2t_generator.generate_ex(
+                source_seqs, source_padding_mask
+            )
         elif input_modality == "text" and self.t2t_generator is not None:
-            text_output = self.t2t_generator.generate_ex(source_seqs, source_padding_mask)
+            text_output = self.t2t_generator.generate_ex(
+                source_seqs, source_padding_mask
+            )
         elif input_modality == "text" and self.t2t_generator is None:
             raise ValueError(
                 f"Please set use_text_encoder to True in your model config to encode text."

+ 117 - 103
src/seamless_communication/models/unity/loader.py

@@ -48,20 +48,29 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
 
         keys_to_delete = []
 
-        # Use the built-in version attribute of `torch.Module`.
-        if config.t2u_config is None:
-            keys_to_delete.append("decoder.version")
-            keys_to_delete.append("decoder.embed_positions._float_tensor")
+        # X2T/S2T + T2U model.
+        if config.t2u_config is not None:
+            encoder_key = "encoder"
+            decoder_key = "target_letter_decoder"
+        # X2T model.
+        elif config.use_text_encoder:
+            encoder_key = "speech_encoder"
+            decoder_key = "shared_decoder"
+        # S2T model.
         else:
-            keys_to_delete.append("target_letter_decoder.version")
-            keys_to_delete.append("target_letter_decoder.embed_positions._float_tensor")
+            encoder_key = "encoder"
+            decoder_key = "decoder"
+
+        # Use the built-in version attribute of `torch.Module`.
+        keys_to_delete.append(f"{decoder_key}.version")
+        keys_to_delete.append(f"{decoder_key}.embed_positions._float_tensor")
 
         if config.use_text_encoder:
             keys_to_delete.append("text_encoder.version")
             keys_to_delete.append("text_encoder.embed_positions._float_tensor")
 
         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
-        keys_to_delete.append("encoder.w2v_encoder.w2v_model.mask_emb")
+        keys_to_delete.append(f"{encoder_key}.w2v_encoder.w2v_model.mask_emb")
 
         # Delete AlignmentEncoder keys for inference.
         alignment_encoder_keys = [
@@ -126,46 +135,59 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
 
     @staticmethod
     def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
+        # X2T/S2T + T2U model.
+        if config.t2u_config is not None:
+            encoder_key = "encoder"
+            decoder_key = "target_letter_decoder"
+        # X2T model.
+        elif config.use_text_encoder:
+            encoder_key = "speech_encoder"
+            decoder_key = "shared_decoder"
+        # S2T model.
+        else:
+            encoder_key = "encoder"
+            decoder_key = "decoder"
+
         key_map = {
             # fmt: off
 
             # Speech Encoder
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.pos_conv\.0\.":                                    r"speech_encoder_frontend.pos_encoder.conv.",
-            r"^encoder\.w2v_encoder\.w2v_model\.layer_norm\.":                                              r"speech_encoder_frontend.post_extract_layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.post_extract_proj\.":                                       r"speech_encoder_frontend.model_dim_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.0\.":             r"speech_encoder_frontend.feature_extractor.layers.\1.conv.",
-            r"^encoder\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.":          r"speech_encoder_frontend.feature_extractor.layers.\1.layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.":                    r"speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
-
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.inner.layers.\1.conv.batch_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.":     r"speech_encoder.inner.layers.\1.conv.layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.inner.layers.\1.conv.depthwise_conv.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.inner.layers.\1.conv_layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv1.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv2.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         r"speech_encoder.inner.layers.\1.ffn\2_layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                r"speech_encoder.inner.layers.\1.ffn\2.inner_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                r"speech_encoder.inner.layers.\1.ffn\2.output_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":         r"speech_encoder.inner.layers.\1.self_attn_layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.":          r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.":          r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.":          r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.":        r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.":   r"speech_encoder.inner.layers.\1.self_attn.sdpa.rel_k_embed.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":          r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.":        r"speech_encoder.inner.layers.\1.self_attn.sdpa.r_proj.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.u_bias",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.v_bias",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.final_layer_norm\.":             r"speech_encoder.inner.layers.\1.layer_norm.",
-            r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.":                                     r"speech_encoder.inner.layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.pos_conv\.0\.":                                    r"speech_encoder_frontend.pos_encoder.conv.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.layer_norm\.":                                              r"speech_encoder_frontend.post_extract_layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.post_extract_proj\.":                                       r"speech_encoder_frontend.model_dim_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.0\.":             r"speech_encoder_frontend.feature_extractor.layers.\1.conv.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.":          r"speech_encoder_frontend.feature_extractor.layers.\1.layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.feature_extractor\.conv_layers\.0\.2\.":                    r"speech_encoder_frontend.feature_extractor.layers.0.group_norm.",
+
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.inner.layers.\1.conv.batch_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.":     r"speech_encoder.inner.layers.\1.conv.layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.inner.layers.\1.conv.depthwise_conv.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.inner.layers.\1.conv_layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv1.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.inner.layers.\1.conv.pointwise_conv2.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         r"speech_encoder.inner.layers.\1.ffn\2_layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                r"speech_encoder.inner.layers.\1.ffn\2.inner_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                r"speech_encoder.inner.layers.\1.ffn\2.output_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":         r"speech_encoder.inner.layers.\1.self_attn_layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.":          r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.":          r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.":          r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.":        r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.q_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.k_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.":            r"speech_encoder.inner.layers.\1.self_attn.v_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.":   r"speech_encoder.inner.layers.\1.self_attn.sdpa.rel_k_embed.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":          r"speech_encoder.inner.layers.\1.self_attn.output_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.":        r"speech_encoder.inner.layers.\1.self_attn.sdpa.r_proj.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.u_bias",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v":          r"speech_encoder.inner.layers.\1.self_attn.sdpa.v_bias",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layers\.([0-9]+)\.final_layer_norm\.":             r"speech_encoder.inner.layers.\1.layer_norm.",
+            fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.":                                     r"speech_encoder.inner.layer_norm.",
 
             # Speech Encoder Adaptor
-            r"^encoder\.adaptor\.proj\.0\.": r"speech_encoder.proj1.",
-            r"^encoder\.adaptor\.proj\.2\.": r"speech_encoder.proj2.",
-            r"^encoder\.adaptor\.out_ln\.":  r"speech_encoder.layer_norm.",
+            fr"^{encoder_key}\.adaptor\.proj\.0\.": r"speech_encoder.proj1.",
+            fr"^{encoder_key}\.adaptor\.proj\.2\.": r"speech_encoder.proj2.",
+            fr"^{encoder_key}\.adaptor\.out_ln\.":  r"speech_encoder.layer_norm.",
 
             # Text Encoder
             r"^text_encoder\.embed_tokens\.":                              r"text_encoder_frontend.embed.",
@@ -188,90 +210,82 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         # a redundant `LayerNorm` right after the Conformer blocks. We mitigate
         # that issue here by moving that `LayerNorm` to the adaptor block.
         if config.w2v2_encoder_config.use_conformer:
+            # fmt: off
             key_map.update(
                 {
-                    r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner_layer_norm."
+                    fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner_layer_norm."
                 }
             )
         else:
             key_map.update(
                 {
-                    r"^encoder\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm."
+                    rf"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm."
                 }
             )
+            # fmt: on
 
-        # fmt: off
         if config.use_conformer_adaptor:
             key_map.update(
                 {
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":          r"speech_encoder.adaptor_layers.\1.block.self_attn.output_proj.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.self_attn\.":                    r"speech_encoder.adaptor_layers.\1.block.self_attn.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.":         r"speech_encoder.adaptor_layers.\1.block.self_attn_layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         r"speech_encoder.adaptor_layers.\1.block.ffn\2_layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                r"speech_encoder.adaptor_layers.\1.block.ffn\2.inner_proj.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                r"speech_encoder.adaptor_layers.\1.block.ffn\2.output_proj.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.adaptor_layers.\1.block.conv.batch_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.adaptor_layers.\1.block.conv.depthwise_conv.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.adaptor_layers.\1.block.conv_layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv1.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv2.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":             r"speech_encoder.adaptor_layers.\1.block.layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_ln\.":                      r"speech_encoder.adaptor_layers.\1.layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.conv_pool\.1\.":                 r"speech_encoder.adaptor_layers.\1.conv.",
+                    # fmt: off
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":          r"speech_encoder.adaptor_layers.\1.block.self_attn.output_proj.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":                    r"speech_encoder.adaptor_layers.\1.block.self_attn.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.":         r"speech_encoder.adaptor_layers.\1.block.self_attn_layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":         r"speech_encoder.adaptor_layers.\1.block.ffn\2_layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                r"speech_encoder.adaptor_layers.\1.block.ffn\2.inner_proj.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                r"speech_encoder.adaptor_layers.\1.block.ffn\2.output_proj.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.batch_norm\.":      r"speech_encoder.adaptor_layers.\1.block.conv.batch_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":  r"speech_encoder.adaptor_layers.\1.block.conv.depthwise_conv.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.layer_norm\.":      r"speech_encoder.adaptor_layers.\1.block.conv_layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv1.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"speech_encoder.adaptor_layers.\1.block.conv.pointwise_conv2.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":             r"speech_encoder.adaptor_layers.\1.block.layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_ln\.":                      r"speech_encoder.adaptor_layers.\1.layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.conv_pool\.1\.":                 r"speech_encoder.adaptor_layers.\1.conv.",
+                    # fmt: on
                 }
             )
         else:
             key_map.update(
                 {
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.residual_layer_norm\.":  r"speech_encoder.adaptor_layers.\1.residual_layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.residual_pool\.1\.":     r"speech_encoder.adaptor_layers.\1.residual_conv.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.attn_pool\.1\.":         r"speech_encoder.adaptor_layers.\1.self_attn_conv.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":  r"speech_encoder.adaptor_layers.\1.self_attn.output_proj.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.self_attn\.":            r"speech_encoder.adaptor_layers.\1.self_attn.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.adaptor_layers.\1.self_attn_layer_norm.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.fc1\.":                  r"speech_encoder.adaptor_layers.\1.ffn.inner_proj.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.fc2\.":                  r"speech_encoder.adaptor_layers.\1.ffn.output_proj.",
-                    r"^encoder\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":     r"speech_encoder.adaptor_layers.\1.ffn_layer_norm.",
+                    # fmt: off
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_layer_norm\.":  r"speech_encoder.adaptor_layers.\1.residual_layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.residual_pool\.1\.":     r"speech_encoder.adaptor_layers.\1.residual_conv.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.attn_pool\.1\.":         r"speech_encoder.adaptor_layers.\1.self_attn_conv.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.out_proj\.":  r"speech_encoder.adaptor_layers.\1.self_attn.output_proj.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn\.":            r"speech_encoder.adaptor_layers.\1.self_attn.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"speech_encoder.adaptor_layers.\1.self_attn_layer_norm.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc1\.":                  r"speech_encoder.adaptor_layers.\1.ffn.inner_proj.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.fc2\.":                  r"speech_encoder.adaptor_layers.\1.ffn.output_proj.",
+                    fr"^{encoder_key}\.adaptor\.layers\.([0-9]+)\.final_layer_norm\.":     r"speech_encoder.adaptor_layers.\1.ffn_layer_norm.",
+                    # fmt: on
                 }
             )
 
-        # S2T model.
-        if config.t2u_config is None:
-            key_map.update(
-                {
-                    # Text Decoder
-                    r"^decoder\.embed_tokens\.":                              r"text_decoder_frontend.embed.",
-                    r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"text_decoder.layers.\1.self_attn.output_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.self_attn\.":               r"text_decoder.layers.\1.self_attn.",
-                    r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"text_decoder.layers.\1.self_attn_layer_norm.",
-                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":            r"text_decoder.layers.\1.encoder_decoder_attn.",
-                    r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
-                    r"^decoder\.layers\.([0-9]+)\.fc1\.":                     r"text_decoder.layers.\1.ffn.inner_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.fc2\.":                     r"text_decoder.layers.\1.ffn.output_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"text_decoder.layers.\1.ffn_layer_norm.",
-                    r"^decoder\.layer_norm\.":                                r"text_decoder.layer_norm.",
-                    r"^decoder\.output_projection\.":                         r"final_proj.",
-                }
-            )
-        # S2T + T2U model.
-        else:
+        key_map.update(
+            {
+                # fmt: off
+                # Text Decoder
+                fr"^{decoder_key}\.embed_tokens\.":                              r"text_decoder_frontend.embed.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"text_decoder.layers.\1.self_attn.output_proj.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.self_attn\.":               r"text_decoder.layers.\1.self_attn.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"text_decoder.layers.\1.self_attn_layer_norm.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.encoder_attn\.":            r"text_decoder.layers.\1.encoder_decoder_attn.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.fc1\.":                     r"text_decoder.layers.\1.ffn.inner_proj.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.fc2\.":                     r"text_decoder.layers.\1.ffn.output_proj.",
+                fr"^{decoder_key}\.layers\.([0-9]+)\.final_layer_norm\.":        r"text_decoder.layers.\1.ffn_layer_norm.",
+                fr"^{decoder_key}\.layer_norm\.":                                r"text_decoder.layer_norm.",
+                fr"^{decoder_key}\.output_projection\.":                         r"final_proj.",
+                # fmt: on
+            }
+        )
+        # X2T/S2T + T2U model.
+        if config.t2u_config is not None:
             key_map.update(
                 {
-                    # Text Decoder
-                    r"^target_letter_decoder\.embed_tokens\.":                              r"text_decoder_frontend.embed.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"text_decoder.layers.\1.self_attn.output_proj.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn\.":               r"text_decoder.layers.\1.self_attn.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"text_decoder.layers.\1.self_attn_layer_norm.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn\.":            r"text_decoder.layers.\1.encoder_decoder_attn.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.fc1\.":                     r"text_decoder.layers.\1.ffn.inner_proj.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.fc2\.":                     r"text_decoder.layers.\1.ffn.output_proj.",
-                    r"^target_letter_decoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"text_decoder.layers.\1.ffn_layer_norm.",
-                    r"^target_letter_decoder\.layer_norm\.":                                r"text_decoder.layer_norm.",
-                    r"^target_letter_decoder\.output_projection\.":                         r"final_proj.",
-
+                    # fmt: off
                     # T2U Encoder
                     r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.encoder.layers.\1.self_attn.output_proj.",
                     r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.encoder.layers.\1.self_attn.",
@@ -305,9 +319,9 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
                     r"^decoder\.layers\.([0-9]+)\.ffn\.layer_norm\.":         r"t2u_model.decoder.layers.\1.conv1d_layer_norm.",
                     r"^decoder\.layer_norm\.":                                r"t2u_model.decoder.layer_norm.",
                     r"^decoder\.output_projection\.":                         r"t2u_model.final_proj.",
+                    # fmt: on
                 }
             )
-        # fmt: on
 
         return key_map
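
A minimal standalone sketch of what this key map drives: each fairseq checkpoint key is renamed by the first regex that matches, with `decoder_key` selecting the naming scheme of the checkpoint flavor. Two toy entries only; the real loader applies the full table above.

```python
import re

# x2t checkpoints store the text decoder under "shared_decoder";
# S2T+T2U checkpoints use "target_letter_decoder" (see the branches above).
decoder_key = "shared_decoder"

key_map = {
    fr"^{decoder_key}\.embed_tokens\.":          r"text_decoder_frontend.embed.",
    fr"^{decoder_key}\.layers\.([0-9]+)\.fc1\.": r"text_decoder.layers.\1.ffn.inner_proj.",
}

def remap(key: str) -> str:
    """Rename a checkpoint key using the first matching pattern."""
    for pattern, repl in key_map.items():
        new_key, num_subs = re.subn(pattern, repl, key)
        if num_subs:
            return new_key
    return key

assert remap("shared_decoder.layers.3.fc1.weight") == (
    "text_decoder.layers.3.ffn.inner_proj.weight"
)
```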
 

+ 9 - 3
src/seamless_communication/models/unity/nar_decoder_layer.py

@@ -151,7 +151,9 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
 
         layer_norm_factory = create_standard_layer_norm
 
-        self.self_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
+        self.self_attn_layer_norm = layer_norm_factory(
+            model_dim, device=device, dtype=dtype
+        )
 
         self.conv1d = conv1d
 
@@ -160,7 +162,9 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
         else:
             self.register_module("conv1d_dropout", None)
 
-        self.conv1d_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
+        self.conv1d_layer_norm = layer_norm_factory(
+            model_dim, device=device, dtype=dtype
+        )
 
         check_model_dim(self)
 
@@ -204,7 +208,9 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
 
         return seqs
 
-    def _forward_conv1d(self, seqs: Tensor, padding_mask: Optional[PaddingMask]) -> Tensor:
+    def _forward_conv1d(
+        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
+    ) -> Tensor:
         residual = seqs
 
         seqs = self.conv1d(seqs, padding_mask)

+ 0 - 1
src/seamless_communication/models/wav2vec2_chunk/encoder.py

@@ -15,7 +15,6 @@ from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask
 
 from fairseq2.nn.transformer import (
-    AttentionMaskFactory,
     EncoderLayerOutputHook,
     TransformerEncoder,
     TransformerEncoderLayer,