
Open-sourcing conformer_shaw 600m speech encoder. (#252)

* Adding asset, loader, integ test for wav2vec2_chunk_600m.

* Replace asset with public S3 path, fix mypy issue.

* Remove all mentions of right_chunk_num since it's redundant.

* Rename wav2vec2_chunk to conformer_shaw, and remove ChunkTransformerEncoder and its attn mask.

* Update README.md adding the speech encoder to the list of models.

* Update README.md to address comments, specify licensing for the speech encoder.

* Remove speech encoder model card placeholder from README.

* Check std, mean of speech encoder features in integ test.

* Update README.md with detailed instructions on how to run a forward() through the speech encoder.

* Get parity with fairseq1, drop the extra layer_norm after encoder layers.

* Fix checkpoint link in README.
Kaushik Ram Sadagopan, 1 year ago
commit 901782b9f6

+ 46 - 1
README.md

@@ -55,7 +55,7 @@ The Seamless model is the unified model for expressive streaming speech-to-speec
 | HuggingFace Space Demo | [🤗 SeamlessM4T v2 Space](https://huggingface.co/spaces/facebook/seamless-m4t-v2-large)                                                | [🤗 SeamlessExpressive Space](https://huggingface.co/spaces/facebook/seamless-expressive)                                                         | [🤗 SeamlessStreaming Space](https://huggingface.co/spaces/facebook/seamless-streaming) |
 
 ## What's new
-
+- [12/18/2023] We are open-sourcing our Conformer-based [W2v-BERT 2.0 speech encoder](#w2v-bert-20-speech-encoder) as described in Section 3.2.1 of the [paper](https://arxiv.org/pdf/2312.05187.pdf), which is at the core of our Seamless models.
 
 
 # Quick Start
@@ -158,6 +158,50 @@ Please note that SeamlessExpressive is made available under its own [License](SE
 Seamless model is simply the SeamlessStreaming model with the non-expressive `vocoder_v2` swapped out with the expressive `vocoder_pretssel`.
 Please check out above [section](#seamlessexpressive-models) on how to acquire `vocoder_pretssel` checkpoint.
 
+### W2v-BERT 2.0 speech encoder
+| Model Name        | #params | checkpoint                                                                                                                                                                                                                                                                                                                                                                 |
+| ----------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| W2v-BERT 2.0 | 600M    | [checkpoint](https://dl.fbaipublicfiles.com/seamless/models/conformer_shaw.pt) |
+
+Here's how to run a forward pass through the speech encoder:
+
+```python
+import torch
+
+from fairseq2.data import Collater
+from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
+from fairseq2.memory import MemoryBlock
+from fairseq2.nn.padding import get_seqs_and_padding_mask
+from pathlib import Path
+from seamless_communication.models.conformer_shaw import load_conformer_shaw_model
+
+
+# Set these for your environment, e.g. the path to a 16kHz mono WAV file:
+audio_wav_path, device, dtype = ...
+audio_decoder = AudioDecoder(dtype=torch.float32, device=device)
+fbank_converter = WaveformToFbankConverter(
+    num_mel_bins=80,
+    waveform_scale=2**15,
+    channel_last=True,
+    standardize=True,
+    device=device,
+    dtype=dtype,
+)
+collater = Collater(pad_value=1)
+
+model = load_conformer_shaw_model("conformer_shaw", device=device, dtype=dtype)
+model.eval()
+
+with Path(audio_wav_path).open("rb") as fb:
+    block = MemoryBlock(fb.read())
+
+decoded_audio = audio_decoder(block)
+src = collater(fbank_converter(decoded_audio))["fbank"]
+seqs, padding_mask = get_seqs_and_padding_mask(src)
+
+with torch.inference_mode():
+    seqs, padding_mask = model.encoder_frontend(seqs, padding_mask)
+    seqs, padding_mask = model.encoder(seqs, padding_mask)
+```
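+
+For the 600M encoder, the final `seqs` holds the encoder output with an embedding dimension of 1024, i.e. shape `(batch, seq_len, 1024)`.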
+
 ## Evaluation
 
 ### SeamlessM4T Evaluation
@@ -237,6 +281,7 @@ If you use Seamless in your work or any models/datasets/artifacts published in S
 We have three license categories.
 
 The following non-generative components are MIT licensed as found in [MIT_LICENSE](MIT_LICENSE):
+- [W2v-BERT 2.0 speech encoder](#w2v-bert-20-speech-encoder)
 - Code
 - Text only part of the mExpresso dataset found in the [SeamlessExpressive README](docs/expressive/README.md).
 - UnitY2 forced alignment extractor found in the [UnitY2 Aligner README](docs/m4t/unity2_aligner_README.md).

+ 10 - 0
src/seamless_communication/cards/conformer_shaw.yaml

@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+name: conformer_shaw
+model_type: wav2vec2
+model_arch: conformer_shaw_600m
+checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/conformer_shaw.pt"
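
As a side note on how this card is consumed: fairseq2 resolves models by card name through its asset store. A minimal sketch (assuming `seamless_communication` has been imported so its cards directory is registered with the store, and that fairseq2's `AssetCard` field accessors behave as in fairseq2 0.2.x):

```python
import seamless_communication  # registers the cards directory with fairseq2's asset store

from fairseq2.assets import asset_store

# Look up the card defined in conformer_shaw.yaml by its `name` field.
card = asset_store.retrieve_card("conformer_shaw")

print(card.field("model_arch").as_(str))  # conformer_shaw_600m
print(card.field("checkpoint").as_uri())  # https://dl.fbaipublicfiles.com/seamless/models/conformer_shaw.pt
```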

+ 21 - 0
src/seamless_communication/models/conformer_shaw/__init__.py

@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+from seamless_communication.models.conformer_shaw.builder import (
+    ConformerShawEncoderBuilder as ConformerShawEncoderBuilder,
+)
+from seamless_communication.models.conformer_shaw.builder import (
+    ConformerShawEncoderConfig as ConformerShawEncoderConfig,
+)
+from seamless_communication.models.conformer_shaw.builder import (
+    conformer_shaw_archs as conformer_shaw_archs,
+)
+from seamless_communication.models.conformer_shaw.builder import (
+    create_conformer_shaw_model as create_conformer_shaw_model,
+)
+from seamless_communication.models.conformer_shaw.loader import (
+    load_conformer_shaw_model as load_conformer_shaw_model,
+)

+ 182 - 0
src/seamless_communication/models/conformer_shaw/builder.py

@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+from dataclasses import asdict, dataclass
+from typing import Optional
+
+from fairseq2.models.conformer import ConformerConvolution
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.w2vbert import w2vbert_archs
+from fairseq2.models.wav2vec2.builder import (
+    Wav2Vec2Builder,
+    Wav2Vec2Config,
+    Wav2Vec2EncoderBuilder,
+    Wav2Vec2EncoderConfig,
+    wav2vec2_arch,
+)
+from fairseq2.models.wav2vec2.model import Wav2Vec2Model
+from fairseq2.nn.transformer import SDPA, ShawRelativePositionSDPA, create_default_sdpa
+from fairseq2.typing import DataType, Device
+
+
+@dataclass
+class ShawRelativePositionSDPAConfig:
+    """Holds the configuration of the :class:ShawRelativePositionSDPA module."""
+
+    max_left_rel_pos: int
+    """The left clipping value for relative positions."""
+
+    max_right_rel_pos: Optional[int]
+    """The right clipping value for relative positions."""
+
+    use_rel_pos_values: bool = False
+    """If True, also uses relative position values to compute relative attention."""
+
+
+@dataclass
+class ConformerShawEncoderConfig(Wav2Vec2EncoderConfig):
+    """Holds the configuration of a conformer shaw encoder."""
+
+    shaw_rel_pos_sdpa_config: Optional[ShawRelativePositionSDPAConfig]
+    """The parameters for ShawRelativePositionSDPA."""
+
+
+conformer_shaw_archs = ArchitectureRegistry[ConformerShawEncoderConfig](
+    "conformer_shaw"
+)
+
+conformer_shaw_arch = conformer_shaw_archs.decorator
+
+
+@conformer_shaw_arch("600m")
+def _conformer_shaw_600m_encoder() -> ConformerShawEncoderConfig:
+    w2vbert_config = w2vbert_archs.get_config("600m")
+    w2v2_encoder_config = w2vbert_config.w2v2_config.encoder_config
+    sdpa_config = ShawRelativePositionSDPAConfig(
+        max_left_rel_pos=64,
+        max_right_rel_pos=8,
+        use_rel_pos_values=False,
+    )
+    conformer_shaw_encoder_config = ConformerShawEncoderConfig(
+        **asdict(w2v2_encoder_config),
+        shaw_rel_pos_sdpa_config=sdpa_config,
+    )
+    conformer_shaw_encoder_config.pos_encoder_type = "shaw_relative"
+    return conformer_shaw_encoder_config
+
+
+@wav2vec2_arch("conformer_shaw_600m")
+def _conformer_shaw_600m() -> Wav2Vec2Config:
+    encoder_config = _conformer_shaw_600m_encoder()
+
+    return Wav2Vec2Config(
+        encoder_config,
+        final_dim=768,
+        final_proj_bias=True,
+        temporal_mask_span_len=10,
+        max_temporal_mask_prob=0.65,
+        spatial_mask_span_len=10,
+        max_spatial_mask_prob=0.0,
+        quantized_dim=768,
+        num_codebooks=2,
+        num_codebook_entries=320,
+        codebook_sampling_temperature=(2.0, 0.1, 0.999995),
+        num_distractors=100,
+        logit_temp=0.1,
+        diversity_loss_weight=0.2,
+    )
+
+
+class ConformerShawEncoderBuilder(Wav2Vec2EncoderBuilder):
+    """
+    Builds modules of a Conformer encoder with Shaw relative-position attention.
+
+    This is a Conformer architecture with these differences:
+    - ShawRelativePositionSDPA as the SDPA.
+    - ConformerConvolution with causal depthwise convolution
+    and norm_type "layer_norm".
+    """
+
+    config: ConformerShawEncoderConfig
+
+    def __init__(
+        self,
+        config: ConformerShawEncoderConfig,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        super().__init__(config, device=device, dtype=dtype)
+
+        assert self.config.use_conformer, "This architecture only supports a Conformer."
+        assert (
+            self.config.pos_encoder_type == "shaw_relative"
+        ), "This architecture only supports ShawRelativePositionSDPA."
+
+    def build_sdpa(self) -> SDPA:
+        if self.config.shaw_rel_pos_sdpa_config is None:
+            raise ValueError(
+                "`shaw_rel_pos_sdpa_config` must be specified when `pos_encoder_type` is 'shaw_relative'."
+            )
+
+        sdpa = create_default_sdpa(attn_dropout_p=self.config.attn_dropout_p)
+
+        sdpa_config = self.config.shaw_rel_pos_sdpa_config
+
+        return ShawRelativePositionSDPA(
+            self.config.model_dim,
+            self.config.num_encoder_attn_heads,
+            sdpa_config.max_left_rel_pos,
+            max_right_rel_pos=sdpa_config.max_right_rel_pos,
+            use_rel_pos_values=sdpa_config.use_rel_pos_values,
+            inner_sdpa=sdpa,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_conformer_conv(self) -> ConformerConvolution:
+        return ConformerConvolution(
+            self.config.model_dim,
+            self.config.depthwise_conv_kernel_size,
+            causal_depthwise_conv=True,
+            norm_type="layer_norm",
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
+def create_conformer_shaw_model(
+    config: Wav2Vec2Config,
+    *,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> Wav2Vec2Model:
+    """Create a conformer shaw model.
+
+    :param config:
+        The configuration.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+    assert isinstance(config.encoder_config, ConformerShawEncoderConfig)
+
+    encoder_builder = ConformerShawEncoderBuilder(
+        config.encoder_config, device=device, dtype=dtype
+    )
+
+    builder = Wav2Vec2Builder(config, encoder_builder, device=device, dtype=dtype)
+
+    return builder.build_model()
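
For orientation, the registry plus builder give a direct construction path for the encoder alone (randomly initialized, not the released checkpoint). A minimal sketch, assuming `build_encoder()` is inherited unchanged from fairseq2's `Wav2Vec2EncoderBuilder`:

```python
import torch

from fairseq2.typing import Device
from seamless_communication.models.conformer_shaw import (
    ConformerShawEncoderBuilder,
    conformer_shaw_archs,
)

# Fetch the "600m" preset registered above: relative positions in
# self-attention are clipped to 64 on the left and 8 on the right.
config = conformer_shaw_archs.get_config("600m")

builder = ConformerShawEncoderBuilder(config, device=Device("cpu"), dtype=torch.float32)

# build_encoder() is inherited from Wav2Vec2EncoderBuilder; it assembles the
# Conformer stack via the build_sdpa() and build_conformer_conv() overrides above.
encoder = builder.build_encoder()
```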

+ 82 - 0
src/seamless_communication/models/conformer_shaw/loader.py

@@ -0,0 +1,82 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+from typing import Any, Mapping
+
+import torch
+
+from fairseq2.assets import asset_store, download_manager
+from fairseq2.models.utils import ModelLoader
+from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
+from fairseq2.models.wav2vec2.builder import Wav2Vec2Config
+from fairseq2.models.wav2vec2.loader import load_wav2vec2_config
+from fairseq2.models.wav2vec2.model import Wav2Vec2Model
+
+from seamless_communication.models.conformer_shaw.builder import (
+    create_conformer_shaw_model,
+)
+
+
+def convert_conformer_shaw_checkpoint(
+    checkpoint: Mapping[str, Any], config: Wav2Vec2Config
+) -> Mapping[str, Any]:
+    """Convert a fairseq conformer shaw checkpoint to fairseq2."""
+    state_dict = checkpoint["model"]
+
+    # Check if we have a fairseq2 checkpoint.
+    if "final_target_proj.weight" in state_dict:
+        return checkpoint
+
+    for key in (
+        "mlm_proj.weight",
+        "mlm_proj.bias",
+        "encoder.layer_norm.weight",
+        "encoder.layer_norm.bias",
+    ):
+        if key in state_dict:
+            del state_dict[key]
+
+    state_dict["quantizer.num_updates"] = torch.zeros((), device="cpu")
+
+    key_map = {
+        # fmt: off
+        r"^encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":            r"encoder.layers.\1.self_attn.output_proj.",
+        r"^encoder\.layers\.([0-9]+)\.self_attn\.rel_k_embedding\.":     r"encoder.layers.\1.self_attn.sdpa.rel_k_embed.",
+        r"^encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.":    r"encoder.layers.\1.conv.depthwise_conv.",
+        r"^encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.":        r"encoder.layers.\1.conv_layer_norm.",
+        r"^encoder\.layers\.([0-9]+)\.conv_module\.layer_norm2\.":       r"encoder.layers.\1.conv.layer_norm.",
+        r"^encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.":   r"encoder.layers.\1.conv.pointwise_conv1.",
+        r"^encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.":   r"encoder.layers.\1.conv.pointwise_conv2.",
+        r"^encoder\.layers\.([0-9]+)\.fc1\.":                            r"encoder.layers.\1.ffn.inner_proj.",
+        r"^encoder\.layers\.([0-9]+)\.fc2\.":                            r"encoder.layers.\1.ffn.output_proj.",
+        r"^encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.":           r"encoder.layers.\1.ffn\2_layer_norm.",
+        r"^encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.":                  r"encoder.layers.\1.ffn\2.inner_proj.",
+        r"^encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.":                  r"encoder.layers.\1.ffn\2.output_proj.",
+        r"^encoder\.layers\.([0-9]+)\.final_layer_norm\.":               r"encoder.layers.\1.layer_norm.",
+        r"^encoder\.embed_tokens\.":                                     r"encoder_frontend.embed.",
+        r"^encoder\.pos_conv\.0\.":                                      r"encoder_frontend.pos_encoder.conv.",
+        r"^feature_extractor\.conv_layers\.([0-9]+)\.0\.":               r"encoder_frontend.feature_extractor.layers.\1.conv.",
+        r"^feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.":            r"encoder_frontend.feature_extractor.layers.\1.layer_norm.",
+        r"^feature_extractor\.conv_layers\.0\.2\.":                      r"encoder_frontend.feature_extractor.layers.0.group_norm.",
+        r"^layer_norm\.":                                                r"encoder_frontend.post_extract_layer_norm.",
+        r"^post_extract_proj\.":                                         r"encoder_frontend.model_dim_proj.",
+        r"^mask_emb":                                                    r"masker.temporal_mask_embed",
+        r"^quantizer\.vars":                                             r"quantizer.entries",
+        r"^quantizer\.weight_proj\.":                                    r"quantizer.entry_proj.",
+        r"^project_q\.":                                                 r"final_target_proj.",
+        # fmt: on
+    }
+
+    return convert_fairseq_checkpoint(checkpoint, key_map)
+
+
+load_conformer_shaw_model = ModelLoader[Wav2Vec2Model, Wav2Vec2Config](
+    asset_store,
+    download_manager,
+    load_wav2vec2_config,
+    create_conformer_shaw_model,
+    convert_conformer_shaw_checkpoint,
+)
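
To make the key map concrete: each entry is an ordinary regex rewrite over fairseq1 parameter names. A self-contained illustration of a single entry (the example key is hypothetical, chosen only to show the capture group at work):

```python
import re

# One entry from key_map above, applied as a plain regex rewrite;
# convert_fairseq_checkpoint applies each pattern/replacement pair
# to every parameter name in the checkpoint.
old_key = "encoder.layers.3.self_attn.out_proj.weight"

new_key = re.sub(
    r"^encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.",
    r"encoder.layers.\1.self_attn.output_proj.",
    old_key,
)

print(new_key)  # encoder.layers.3.self_attn.output_proj.weight
```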

+ 1 - 1
src/seamless_communication/models/unit_extractor/unit_extractor.py

@@ -53,7 +53,7 @@ class UnitExtractor(nn.Module):
         assert isinstance(wav2vec2_model, Wav2Vec2Model)
         self.model = Wav2Vec2LayerOutputModel(wav2vec2_model)
         self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
-        self.collate = Collater(pad_value=2, pad_to_multiple=2)
+        self.collate = Collater(pad_value=1, pad_to_multiple=2)
         self.kmeans_model = KmeansModel(kmeans_uri, device, dtype)
         self.device = device
         self.dtype = dtype

+ 10 - 10
src/seamless_communication/models/unity/builder.py

@@ -43,10 +43,10 @@ from seamless_communication.models.unity.t2u_builder import (
     UnitYT2UConfig,
     unity_t2u_archs,
 )
-from seamless_communication.models.wav2vec2_chunk import (
-    Wav2Vec2ChunkEncoderBuilder,
-    Wav2Vec2ChunkEncoderConfig,
-    wav2vec2_chunk_archs,
+from seamless_communication.models.conformer_shaw import (
+    ConformerShawEncoderBuilder,
+    ConformerShawEncoderConfig,
+    conformer_shaw_archs,
 )
 
 
@@ -317,7 +317,7 @@ def _nano() -> UnitYConfig:
 
 @unity_arch("base_v2")
 def _base_v2() -> UnitYConfig:
-    w2v2_chunk_encoder_config = wav2vec2_chunk_archs.get_config("600m")
+    conformer_shaw_encoder_config = conformer_shaw_archs.get_config("600m")
 
     mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
 
@@ -329,7 +329,7 @@ def _base_v2() -> UnitYConfig:
 
     return UnitYConfig(
         model_dim=1024,
-        w2v2_encoder_config=w2v2_chunk_encoder_config,
+        w2v2_encoder_config=conformer_shaw_encoder_config,
         mt_model_config=mt_model_config,
         t2u_config=t2u_config,
         prosody_encoder_config=None,
@@ -347,7 +347,7 @@ def _base_v2() -> UnitYConfig:
 
 @unity_arch("expressivity_v2")
 def _expressivity_v2() -> UnitYConfig:
-    w2v2_chunk_encoder_config = wav2vec2_chunk_archs.get_config("600m")
+    conformer_shaw_encoder_config = conformer_shaw_archs.get_config("600m")
 
     mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
 
@@ -361,7 +361,7 @@ def _expressivity_v2() -> UnitYConfig:
 
     return UnitYConfig(
         model_dim=1024,
-        w2v2_encoder_config=w2v2_chunk_encoder_config,
+        w2v2_encoder_config=conformer_shaw_encoder_config,
         mt_model_config=mt_model_config,
         t2u_config=t2u_config,
         prosody_encoder_config=prosody_encoder_config,
@@ -624,8 +624,8 @@ def create_unity_model(
     :param dtype:
         The data type of module parameters and buffers.
     """
-    if isinstance(config.w2v2_encoder_config, Wav2Vec2ChunkEncoderConfig):
-        w2v2_encoder_builder: Wav2Vec2EncoderBuilder = Wav2Vec2ChunkEncoderBuilder(
+    if isinstance(config.w2v2_encoder_config, ConformerShawEncoderConfig):
+        w2v2_encoder_builder: Wav2Vec2EncoderBuilder = ConformerShawEncoderBuilder(
             config.w2v2_encoder_config, device=device, dtype=dtype
         )
     else:

+ 0 - 15
src/seamless_communication/models/wav2vec2_chunk/__init__.py

@@ -1,15 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# MIT_LICENSE file in the root directory of this source tree.
-
-from seamless_communication.models.wav2vec2_chunk.builder import (
-    Wav2Vec2ChunkEncoderBuilder as Wav2Vec2ChunkEncoderBuilder,
-)
-from seamless_communication.models.wav2vec2_chunk.builder import (
-    Wav2Vec2ChunkEncoderConfig as Wav2Vec2ChunkEncoderConfig,
-)
-from seamless_communication.models.wav2vec2_chunk.builder import (
-    wav2vec2_chunk_archs as wav2vec2_chunk_archs,
-)

+ 0 - 161
src/seamless_communication/models/wav2vec2_chunk/builder.py

@@ -1,161 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# MIT_LICENSE file in the root directory of this source tree.
-
-from dataclasses import asdict, dataclass
-from typing import Literal, Optional
-
-from fairseq2.models.conformer import ConformerConvolution
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
-from fairseq2.models.w2vbert import w2vbert_archs
-from fairseq2.models.wav2vec2.builder import (
-    Wav2Vec2EncoderBuilder,
-    Wav2Vec2EncoderConfig,
-)
-from fairseq2.nn.transformer import SDPA, ShawRelativePositionSDPA, create_default_sdpa
-from fairseq2.typing import DataType, Device
-
-from seamless_communication.models.wav2vec2_chunk.encoder import ChunkTransformerEncoder
-
-
-@dataclass
-class ShawRelativePositionSDPAConfig:
-    """Holds the configuration of the :class:ShawRelativePositionSDPA module."""
-
-    max_left_rel_pos: int
-    """The left clipping value for relative positions."""
-
-    max_right_rel_pos: Optional[int]
-    """The right clipping value for relative positions."""
-
-    use_rel_pos_values: bool = False
-    """If True, also uses relative position values to compute relative attention."""
-
-
-@dataclass
-class Wav2Vec2ChunkEncoderConfig(Wav2Vec2EncoderConfig):
-    """Holds the configuration of a wav2vec2 chunk encoder."""
-
-    causal_depthwise_conv: bool
-    """If True, uses a causal depthwise convolution similar to that described in
-    Section 2.1 of :cite:t:`https://doi.org/10.48550/arxiv.1609.03499`."""
-
-    conv_norm_type: Literal["batch_norm", "layer_norm"]
-    """The type of normalization to use in the Conformer convolution module."""
-
-    shaw_rel_pos_sdpa_config: Optional[ShawRelativePositionSDPAConfig]
-    """The parameters for ShawRelativePositionSDPA."""
-
-    chunk_size: int
-    """The size of each chunk."""
-
-    left_chunk_num: int
-    """Number of chunks on the left up to which lookahead is allowed."""
-
-    right_chunk_num: int
-    """Number of chunks on the right up to which lookahead is allowed."""
-
-
-wav2vec2_chunk_archs = ArchitectureRegistry[Wav2Vec2ChunkEncoderConfig](
-    "wav2vec2_chunk"
-)
-
-wav2vec2_chunk_arch = wav2vec2_chunk_archs.decorator
-
-
-@wav2vec2_chunk_arch("600m")
-def _encoder_600m() -> Wav2Vec2ChunkEncoderConfig:
-    w2vbert_config = w2vbert_archs.get_config("600m")
-    w2v2_encoder_config = w2vbert_config.w2v2_config.encoder_config
-    sdpa_config = ShawRelativePositionSDPAConfig(
-        max_left_rel_pos=64,
-        max_right_rel_pos=8,
-        use_rel_pos_values=False,
-    )
-    w2v2_chunk_encoder_config = Wav2Vec2ChunkEncoderConfig(
-        **asdict(w2v2_encoder_config),
-        causal_depthwise_conv=True,
-        conv_norm_type="layer_norm",
-        shaw_rel_pos_sdpa_config=sdpa_config,
-        chunk_size=10000,
-        left_chunk_num=128,
-        right_chunk_num=0,
-    )
-    w2v2_chunk_encoder_config.pos_encoder_type = "shaw_relative"
-    return w2v2_chunk_encoder_config
-
-
-class Wav2Vec2ChunkEncoderBuilder(Wav2Vec2EncoderBuilder):
-    config: Wav2Vec2ChunkEncoderConfig
-
-    def __init__(
-        self,
-        config: Wav2Vec2ChunkEncoderConfig,
-        *,
-        device: Optional[Device] = None,
-        dtype: Optional[DataType] = None,
-    ) -> None:
-        """
-        :param config:
-            The configuration to use.
-        :param device:
-            The device on which to initialize modules.
-        :param dtype:
-            The data type of module parameters and buffers.
-        """
-        super().__init__(config, device=device, dtype=dtype)
-
-        assert (
-            self.config.use_conformer
-        ), "Currently we only support the ChunkConformerBlock."
-
-    def build_encoder(self) -> ChunkTransformerEncoder:
-        """Build a Transformer encoder."""
-        num_layers = self.config.num_encoder_layers
-
-        layers = [self.build_encoder_layer() for _ in range(num_layers)]
-
-        return ChunkTransformerEncoder(
-            layers,
-            self.config.chunk_size,
-            self.config.left_chunk_num,
-            self.config.right_chunk_num,
-            dropout_p=self.config.dropout_p,
-            layer_drop_p=self.config.layer_drop_p,
-        )
-
-    def build_sdpa(self) -> SDPA:
-        if self.config.pos_encoder_type == "shaw_relative":
-            if self.config.shaw_rel_pos_sdpa_config is None:
-                raise ValueError(
-                    "`shaw_rel_pos_sdpa_config` must be specified when `pos_encoder_type` is 'shaw_relative'."
-                )
-
-            sdpa = create_default_sdpa(attn_dropout_p=self.config.attn_dropout_p)
-
-            sdpa_config = self.config.shaw_rel_pos_sdpa_config
-
-            return ShawRelativePositionSDPA(
-                self.config.model_dim,
-                self.config.num_encoder_attn_heads,
-                sdpa_config.max_left_rel_pos,
-                max_right_rel_pos=sdpa_config.max_right_rel_pos,
-                use_rel_pos_values=sdpa_config.use_rel_pos_values,
-                inner_sdpa=sdpa,
-                device=self.device,
-                dtype=self.dtype,
-            )
-
-        return super().build_sdpa()
-
-    def build_conformer_conv(self) -> ConformerConvolution:
-        return ConformerConvolution(
-            self.config.model_dim,
-            self.config.depthwise_conv_kernel_size,
-            causal_depthwise_conv=self.config.causal_depthwise_conv,
-            norm_type=self.config.conv_norm_type,
-            device=self.device,
-            dtype=self.dtype,
-        )

+ 0 - 76
src/seamless_communication/models/wav2vec2_chunk/chunk_attention_mask.py

@@ -1,76 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# MIT_LICENSE file in the root directory of this source tree.
-
-from typing import Optional
-
-import torch
-from fairseq2.nn.transformer import AttentionMask, CustomAttentionMask
-from fairseq2.nn.utils.mask import to_float_mask
-from torch import Tensor
-
-
-class ChunkAttentionMaskFactory:
-    """Generates a chunk attention mask for self attention.
-
-    .. note::
-        This class follows the :class:`AttentionMaskGenerator` protocol.
-    """
-
-    def __init__(
-        self, chunk_size: int, left_chunk_num: int, right_chunk_num: int
-    ) -> None:
-        self.chunk_size = chunk_size
-        self.left_chunk_num = left_chunk_num
-        self.right_chunk_num = right_chunk_num
-
-        if self.right_chunk_num != 0:
-            raise ValueError("We currently only support `right_chunk_num` == 0.")
-
-    def __call__(self, seqs: Tensor) -> Optional[AttentionMask]:
-        """
-        :param seqs:
-            The sequences for which to generate the mask. *Shape:*
-            :math:`(N,S,M)`, where :math:`N` is the batch size, :math:`S` is the
-            sequence length, and :math:`M` is the dimensionality of the model.
-
-        :returns:
-            A chunk attention float mask for ``seqs``.
-            *Shape:* :math:`(S,S)`, where :math:`S` is the
-            sequence length.
-        """
-
-        seq_len = seqs.size(1)
-
-        chunk_indices = torch.div(
-            torch.arange(seq_len, device=seqs.device), self.chunk_size
-        ).long()
-
-        start_indices = (
-            ((chunk_indices - self.left_chunk_num).clamp_(min=0) * self.chunk_size).to(
-                seqs.device
-            )
-            if self.left_chunk_num >= 0
-            else torch.full_like(chunk_indices, 0)
-        )
-        start_indices = start_indices.unsqueeze(1).expand(-1, seq_len)
-
-        end_indices = (
-            ((chunk_indices + 1) * self.chunk_size).clamp_(max=seq_len).to(seqs.device)
-        )
-
-        end_indices = end_indices.unsqueeze(1).expand(-1, seq_len)
-
-        indices = (
-            torch.arange(seq_len, device=seqs.device).unsqueeze(0).expand(seq_len, -1)
-        )
-
-        bool_mask = (indices < start_indices) | (indices >= end_indices)
-
-        mask = to_float_mask(bool_mask, seqs.dtype)
-
-        mask = mask[:seq_len, :seq_len]
-
-        return CustomAttentionMask(mask)

+ 0 - 98
src/seamless_communication/models/wav2vec2_chunk/encoder.py

@@ -1,98 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# MIT_LICENSE file in the root directory of this source tree.
-
-from typing import Iterable, Optional, Tuple, final
-
-from fairseq2.nn.module_list import ModuleList
-from fairseq2.nn.normalization import LayerNorm
-from fairseq2.nn.padding import PaddingMask
-from fairseq2.nn.transformer import TransformerEncoder, TransformerEncoderLayer
-from fairseq2.typing import finaloverride
-from torch import Tensor
-from torch.nn import Dropout
-
-from seamless_communication.models.wav2vec2_chunk.chunk_attention_mask import (
-    ChunkAttentionMaskFactory,
-)
-
-
-@final
-class ChunkTransformerEncoder(TransformerEncoder):
-    """Represents a Chunk Transformer encoder."""
-
-    preliminary_dropout: Optional[Dropout]
-    self_attn_mask_factory: ChunkAttentionMaskFactory
-    layers: ModuleList
-    layer_norm: Optional[LayerNorm]
-
-    def __init__(
-        self,
-        layers: Iterable[TransformerEncoderLayer],
-        chunk_size: int,
-        left_chunk_num: int,
-        right_chunk_num: int,
-        *,
-        dropout_p: float = 0.0,
-        layer_drop_p: float = 0.0,
-    ) -> None:
-        """
-        :param layers:
-            The encoder layers.
-        :param chunk_size:
-            Size of each chunk.
-        :param left_chunk_num:
-            Number of chunks on the left up to which lookahead is allowed.
-        :param right_chunk_num:
-            Number of chunks on the right up to which lookahead is allowed.
-        :param dropout_p:
-            Used in the preliminary dropout.
-        :param layer_drop_p:
-            If greater than zero, applies LayerDrop to the encoder layers as
-            described in :cite:t:`https://doi.org/10.48550/arxiv.1909.11556`.
-        """
-        layer_list = ModuleList(layers, drop_p=layer_drop_p)
-        if not layer_list:
-            raise ValueError("`layers` must be non-empty.")
-
-        model_dim = layer_list[0].model_dim
-
-        super().__init__(model_dim)
-
-        if dropout_p > 0.0:
-            self.preliminary_dropout = Dropout(dropout_p)
-        else:
-            self.register_module("preliminary_dropout", None)
-
-        self.self_attn_mask_factory = ChunkAttentionMaskFactory(
-            chunk_size * 2, left_chunk_num, right_chunk_num
-        )
-
-        self.layers = layer_list
-
-    @finaloverride
-    def forward(
-        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
-    ) -> Tuple[Tensor, Optional[PaddingMask]]:
-        if self._layer_output_hooks and self.layers.drop_p > 0.0:
-            raise ValueError(
-                "The layer output hooks cannot be run when LayerDrop is enabled."
-            )
-
-        if self.preliminary_dropout is not None:
-            seqs = self.preliminary_dropout(seqs)
-
-        self_attn_mask = self.self_attn_mask_factory(seqs)
-
-        num_layers = len(self.layers)
-
-        for layer_idx, layer in enumerate(self.layers.drop_iter()):
-            seqs, padding_mask = layer(seqs, padding_mask, self_attn_mask)
-
-            for hook in self._layer_output_hooks.values():
-                if not hook(layer_idx, seqs, padding_mask, num_layers):
-                    break
-
-        return seqs, padding_mask

+ 41 - 0
tests/integration/models/test_conformer_shaw.py

@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+import torch
+
+from fairseq2.data.audio import AudioDecoderOutput
+from fairseq2.nn.padding import get_seqs_and_padding_mask
+
+from seamless_communication.models.conformer_shaw import load_conformer_shaw_model
+
+from tests.common import (
+    convert_to_collated_fbank,
+    get_default_dtype,
+    device,
+)
+
+REF_MEAN, REF_STD = -0.0001, 0.1547
+
+
+def test_conformer_shaw_600m(example_rate16k_audio: AudioDecoderOutput) -> None:
+    dtype = get_default_dtype()
+    audio_dict = example_rate16k_audio
+    src = convert_to_collated_fbank(audio_dict, dtype=dtype)
+    seqs, padding_mask = get_seqs_and_padding_mask(src)
+
+    model = load_conformer_shaw_model("conformer_shaw", device=device, dtype=dtype)
+    model.eval()
+
+    with torch.inference_mode():
+        seqs, padding_mask = model.encoder_frontend(seqs, padding_mask)
+
+        seqs, _ = model.encoder(seqs, padding_mask)
+
+    std, mean = torch.std_mean(seqs)
+
+    assert round(mean.item(), 4) == REF_MEAN
+    assert round(std.item(), 4) == REF_STD