Browse Source

Initial commit.

Kaushik Ram Sadagopan 2 years ago
parent
commit
99f66387ac

+ 10 - 0
src/seamless_communication/assets/cards/xlsr2_1b_v2.yaml

@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: xlsr2_1b_v2
+model_type: wav2vec2
+model_arch: xlsr2_1b_v2
+checkpoint: "file://private/home/changhan/data/models/wav2vec2/xlsr2_1b_v2.pt"

+ 5 - 0
src/seamless_communication/models/unit_extraction/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.

+ 5 - 0
src/seamless_communication/models/unit_extraction/kmeans.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.

+ 56 - 0
src/seamless_communication/models/unit_extraction/unit_extraction.py

@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Tuple, Union
+from pathlib import Path
+import torch
+
+from fairseq2.typing import Device
+from torch import Tensor, nn
+from fairseq2.data.audio import AudioDecoder
+from fairseq2.models.wav2vec2 import load_wav2vec2_model, Wav2Vec2Model
+from fairseq2.data.typing import StringLike
+from fairseq2.memory import MemoryBlock
+
+
+class UnitExtractor(nn.Module):
+    """Vocoder interface to run vocoder models through hub. Currently we only support unit vocoder"""
+
+    def __init__(
+        self,
+        model: Wav2Vec2Model,
+        device: Device,
+        layer: int = 35,
+    ):
+        super().__init__()
+        self.model = model
+        self.model.eval()
+        self.model.to(device=device)
+        self.device = device
+        self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
+
+    def predict(
+        self,
+        audio: Union[str, torch.Tensor],
+    ) -> Tuple[List[Tensor], int]:
+
+        if isinstance(audio, str):
+            with Path(audio).open("rb") as fb:
+                block = MemoryBlock(fb.read())
+            decoded_audio = self.decode_audio(block)
+        else:
+            decoded_audio = {
+                "waveform": audio,
+                "sample_rate": 16000.0,
+                "format": -1,
+            }
+
+
+if __name__ == "__main__":
+    audio = "/large_experiments/seamless/ust/data/TTS/vocoder_training/audio_wavs/multi_spkr/eng/eng_LJSpeech-1.1_0/LJ003-0001.wav"
+    model = load_wav2vec2_model("xlsr_1b_v2")
+    unit_extractor = UnitExtractor(model, device=Device("cuda:0"))
+    wav, sr = unit_extractor.predict(audio)

+ 215 - 0
src/seamless_communication/models/unit_extraction/wav2vec2_layer_output.py

@@ -0,0 +1,215 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from fairseq2.nn.transformer import TransformerNormOrder
+from fairseq2.models.wav2vec2 import (
+    Wav2Vec2EncoderConfig,
+    Wav2Vec2Config,
+    wav2vec2_arch,
+    Wav2Vec2Model,
+    Wav2Vec2Builder,
+    Wav2Vec2EncoderBuilder,
+)
+from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.utils.model_loader import ModelConfigLoader
+from fairseq2.typing import DataType, Device
+from fairseq2.models.sequence import SequenceBatch
+
+
+from seamless_communication.assets import asset_store, download_manager
+
+
+import torch
+from typing import Optional
+
+from torch import Tensor
+
+wav2vec2_archs = ArchitectureRegistry[Wav2Vec2Config]("wav2vec2")
+wav2vec2_arch = wav2vec2_archs.marker
+
+
+def _encoder_xlsr2_1b_v2() -> Wav2Vec2EncoderConfig:
+    layer_descs = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
+
+    return Wav2Vec2EncoderConfig(
+        model_dim=1280,
+        max_seq_len=4096,
+        feature_dim=512,
+        use_fbank=False,
+        first_pass_dropout_p=0.0,
+        layer_norm_features=False,
+        feature_extractor_layer_descs=layer_descs,
+        feature_extractor_bias=True,
+        feature_extractor_layer_norm_convs=True,
+        feature_grad_scale=1.0,
+        num_fbank_channels=0,
+        fbank_stride=0,
+        sample_fbank_every_k=0,
+        pos_encoder_type="conv",
+        pos_encoder_depth=1,
+        pos_conv_kernel_size=128,
+        num_pos_conv_groups=16,
+        use_conformer=False,
+        num_encoder_layers=48,
+        num_encoder_attn_heads=16,
+        ffn_inner_dim=5120,
+        dropout_p=0.0,
+        attn_dropout_p=0.0,
+        layer_drop_p=0.0,
+        norm_order=TransformerNormOrder.PRE,
+        depthwise_conv_kernel_size=0,
+    )
+
+
+@wav2vec2_arch("xlsr2_1b_v2")
+def _xlsr2_1b_v2() -> Wav2Vec2Config:
+    encoder_config = _encoder_xlsr2_1b_v2()
+
+    return Wav2Vec2Config(
+        encoder_config,
+        final_dim=1024,
+        final_proj_bias=True,
+        temporal_mask_span_len=10,
+        max_temporal_mask_prob=0.65,
+        spatial_mask_span_len=10,
+        max_spatial_mask_prob=0.0,
+        quantized_dim=1024,
+        num_codebooks=2,
+        num_codebook_entries=320,
+        codebook_sampling_temperature=(2, 0.1, 0.999995),
+        num_distractors=100,
+        logit_temp=0.1,
+        diversity_loss_weight=0.1,
+    )
+
+
+class Wav2Vec2LayerOutputModel(Wav2Vec2Model):
+    @torch.no_grad()
+    def forward(self, batch: SequenceBatch, out_layer_idx: int):
+        """
+        :param batch:
+            The batch of sequences to process.
+        """
+        print(f"Before run_frontend: {batch.seqs.sum()}")
+        seqs, padding_mask, _, _ = self.run_frontend(batch.seqs, batch.seq_lens)
+        print(f"After run_frontend: {seqs.sum()}")
+        w2v2_layer_output = None
+
+        def layer_output_hook(
+            layer_idx: int,
+            layer_output: Tensor,
+            layer_padding_mask: Optional[Tensor],
+            num_layers: int,
+        ) -> None:
+            nonlocal w2v2_layer_output
+
+            if layer_idx == out_layer_idx:
+                print(f"{layer_idx=}")
+                w2v2_layer_output = layer_output
+
+        # TODO: Should pad for fp16?
+        _, _ = self.encoder(seqs, padding_mask, layer_output_hook)
+
+        assert w2v2_layer_output is not None
+        return w2v2_layer_output
+
+
+class Wav2Vec2LayerOutputBuilder(Wav2Vec2Builder):
+    def build_model(self) -> Wav2Vec2LayerOutputModel:
+        """Build a model."""
+        encoder_frontend = self.encoder_builder.build_frontend()
+
+        encoder = self.encoder_builder.build_encoder()
+
+        masker = self.build_masker()
+
+        quantizer = self.build_quantizer()
+
+        return Wav2Vec2LayerOutputModel(
+            encoder_frontend,
+            encoder,
+            masker,
+            quantizer,
+            self.config.final_dim,
+            self.config.final_proj_bias,
+            self.config.num_distractors,
+            self.config.logit_temp,
+            self.config.diversity_loss_weight,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
+def create_wav2vec2_layer_output_model(
+    config: Wav2Vec2Config,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> Wav2Vec2Model:
+    """Create a wav2vec 2.0 model.
+
+    :param config:
+        The configuration to use.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+    encoder_builder = Wav2Vec2EncoderBuilder(config.encoder_config, device, dtype)
+
+    return Wav2Vec2LayerOutputBuilder(
+        config, encoder_builder, device, dtype
+    ).build_model()
+
+
+load_wav2vec2_layer_output_config = ModelConfigLoader[Wav2Vec2Config](
+    asset_store, wav2vec2_archs
+)
+
+load_wav2vec2_layer_output_model = Wav2Vec2Loader(
+    asset_store,
+    download_manager,
+    create_wav2vec2_layer_output_model,
+    wav2vec2_archs,
+    # `weight_norm` used in `Wav2Vec2PositionEncoder` does not support meta
+    # initialization.
+    use_meta=False,
+)
+
+
+if __name__ == "__main__":
+    from fairseq2.data import Collater
+    from fairseq2.memory import MemoryBlock
+    from fairseq2.data.audio import AudioDecoder
+    from pathlib import Path
+
+    audio = "/large_experiments/seamless/ust/data/TTS/vocoder_training/audio_wavs/multi_spkr/eng/eng_LJSpeech-1.1_0/LJ003-0001.wav"
+    out_layer_idx = 34
+    device = torch.device("cuda:1")
+    decode_audio = AudioDecoder(dtype=torch.float32, device=device)
+    collate = Collater(pad_idx=2, pad_to_multiple=2)
+    decoded_audio = None
+    if isinstance(audio, str):
+        with Path(audio).open("rb") as fb:
+            block = MemoryBlock(fb.read())
+        decoded_audio = decode_audio(block)
+    src = collate(decoded_audio)["waveform"]
+
+    x = torch.tensor(torch.load("/checkpoint/krs/x.pt"), device=device)
+    print(f"After read audio: {x.sum()}, {x.shape}")
+    x = x.unsqueeze(0)
+    import torch.nn.functional as F
+
+    x = F.layer_norm(x, x.shape)
+    # batch.seqs = batch.seqs.view(1, -1)
+
+    print(f"After layer norm: {x.sum()}, {x.shape}")
+    model = load_wav2vec2_layer_output_model(
+        "xlsr2_1b_v2", device=device, dtype=torch.float32
+    )
+    model.eval()
+    batch = SequenceBatch(seqs=x, seq_lens=src["seq_lens"])
+    out = model(batch, out_layer_idx)
+    print(out.sum(), out.shape)