
Merge pull request #17 from facebookresearch/unit_extractor

Unit extraction pipeline for converting raw audio into discrete units.
Kaushik Ram Sadagopan, 2 years ago
Commit 6c01bdfd7f

+ 55 - 0
scripts/m4t/audio_to_units/audio_to_units.py

@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import torch
+from seamless_communication.models.unit_extraction import UnitExtractor
+
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert raw audio to units using UnitExtractor."
+    )
+    parser.add_argument("audio", type=str, help="Audio WAV file path.")
+    parser.add_argument(
+        "--kmeans_uri",
+        type=str,
+        help="URL path to the K-Means model.",
+        default="https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        help="Feature extraction model name (`xlsr2_1b_v2`)",
+        default="xlsr2_1b_v2",
+    )
+    parser.add_argument(
+        "--out_layer_number",
+        type=int,
+        help="Layer number of the feature extraction model to pull out features from.",
+        default=35,
+    )
+
+    args = parser.parse_args()
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        logger.info("Running unit_extraction on the GPU.")
+    else:
+        device = torch.device("cpu")
+        logger.info("Running unit_extraction on the CPU.")
+
+    unit_extractor = UnitExtractor(args.model_name, args.kmeans_uri, device=device)
+    units = unit_extractor.predict(args.audio, args.out_layer_number - 1)
+    logger.info(f"Converted to units: {units}")
+
+
+if __name__ == "__main__":
+    main()
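
For a quick end-to-end check, the script can be run directly on a WAV file (16 kHz mono is what the XLS-R model expects; the path here is only illustrative, and both flags simply repeat the defaults shown above):

    python scripts/m4t/audio_to_units/audio_to_units.py /path/to/audio.wav --model_name xlsr2_1b_v2 --out_layer_number 35

On first use this downloads the xlsr2_1b_v2 checkpoint and the 10k K-Means centroids, then logs the extracted unit sequence.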

+ 10 - 0
src/seamless_communication/assets/cards/xlsr2_1b_v2.yaml

@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: xlsr2_1b_v2
+model_type: wav2vec2
+model_arch: xlsr2_1b_v2
+checkpoint: "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/xlsr2_1b_v2.pt"

+ 3 - 4
src/seamless_communication/models/inference/translator.py

@@ -54,7 +54,6 @@ class Translator(nn.Module):
         vocoder_name_or_card: Union[str, AssetCard],
         device: Device,
         dtype: DataType,
-        sample_rate: int = 16000,
     ):
         super().__init__()
         # Load the model.
@@ -80,7 +79,6 @@ class Translator(nn.Module):
         self.vocoder: Vocoder = self.load_model_for_inference(
             load_vocoder_model, vocoder_name_or_card, device, torch.float32
         )
-        self.sample_rate = sample_rate
 
     @staticmethod
     def load_model_for_inference(
@@ -157,6 +155,7 @@ class Translator(nn.Module):
         src_lang: Optional[str] = None,
         spkr: Optional[int] = -1,
         ngram_filtering: bool = False,
+        sample_rate: int = 16000,
     ) -> Tuple[StringLike, Optional[List[Tensor]], Optional[int]]:
         """
         The main method used to perform inference on all tasks.
@@ -194,7 +193,7 @@ class Translator(nn.Module):
             else:
                 decoded_audio = {
                     "waveform": audio,
-                    "sample_rate": self.sample_rate,
+                    "sample_rate": sample_rate,
                     "format": -1,
                 }
             src = self.collate(self.convert_to_fbank(decoded_audio))["fbank"]
@@ -226,4 +225,4 @@ class Translator(nn.Module):
         else:
             units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
             wav_out = self.vocoder(units, tgt_lang, spkr, dur_prediction=True)
-            return text_out.sentences[0], wav_out, self.sample_rate
+            return text_out.sentences[0], wav_out, sample_rate
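
With sample_rate moved from the constructor to predict(), callers that pass a raw waveform tensor (rather than a file path) now provide the rate per call. A rough sketch, assuming an existing Translator instance; the leading arguments of predict() are not shown in this hunk and are only assumed here:

    # translator: an existing Translator; wav: a 16 kHz waveform tensor
    text, wav_out, sr = translator.predict(
        wav,                # waveform tensor instead of a file path
        "s2st",             # task string (assumed)
        tgt_lang="eng",
        src_lang="eng",
        sample_rate=16000,  # new per-call argument introduced by this change
    )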

+ 14 - 0
src/seamless_communication/models/unit_extraction/__init__.py

@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from seamless_communication.models.unit_extraction.unit_extraction import (
+    UnitExtractor as UnitExtractor,
+)
+from seamless_communication.models.unit_extraction.kmeans import (
+    KmeansModel as KmeansModel,
+)
+from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
+    Wav2Vec2LayerOutputModel as Wav2Vec2LayerOutputModel,
+)

+ 31 - 0
src/seamless_communication/models/unit_extraction/kmeans.py

@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import Tensor, nn
+import numpy as np
+from fairseq2.typing import Device
+from seamless_communication.assets import download_manager
+
+
+class KmeansModel(nn.Module):
+    def __init__(self, kmeans_uri: str, device: Device):
+        super().__init__()
+        km_path = download_manager.download_checkpoint(kmeans_uri, kmeans_uri)
+        km_model = np.load(km_path)
+        centroids_numpy = km_model.transpose()
+        centroids = torch.from_numpy(centroids_numpy)
+
+        self.centroids = centroids.to(device)
+        self.centroid_norm = (self.centroids**2).sum(0, keepdims=True)
+
+    def forward(self, x: Tensor) -> Tensor:
+        dist: Tensor = (
+            x.pow(2).sum(1, keepdim=True)
+            - 2 * torch.matmul(x, self.centroids)
+            + self.centroid_norm
+        )
+        return dist.argmin(dim=-1)
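
The nearest-centroid assignment in forward() relies on the expansion ||x - c||^2 = ||x||^2 - 2*x.c + ||c||^2, with ||c||^2 precomputed in the constructor. A small self-contained sanity check (shapes are illustrative):

    import torch

    frames = torch.randn(5, 16)       # a few feature frames
    centroids = torch.randn(16, 50)   # centroids stored column-wise, as in KmeansModel

    # Expanded squared distance, as used in forward()
    dist = (
        frames.pow(2).sum(1, keepdim=True)
        - 2 * frames @ centroids
        + (centroids ** 2).sum(0, keepdim=True)
    )
    units = dist.argmin(dim=-1)

    # Direct pairwise distances give the same assignments
    assert torch.equal(units, torch.cdist(frames, centroids.T).argmin(dim=-1))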

+ 87 - 0
src/seamless_communication/models/unit_extraction/unit_extraction.py

@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Union
+from pathlib import Path
+import torch
+
+from itertools import groupby
+from fairseq2.typing import DataType, Device
+from torch import Tensor, nn
+from fairseq2.data.audio import AudioDecoder
+from fairseq2.data import Collater
+import torch.nn.functional as F
+from fairseq2.memory import MemoryBlock
+from fairseq2.assets.card import AssetCard
+from fairseq2.models.sequence import SequenceBatch
+from fairseq2.models.wav2vec2 import Wav2Vec2Model
+from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
+    load_wav2vec2_model,
+    Wav2Vec2LayerOutputModel,
+)
+from seamless_communication.models.unit_extraction.kmeans import KmeansModel
+from seamless_communication.models.inference import Translator
+from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
+
+
+class UnitExtractor(nn.Module):
+    """Unit Extractor which converts raw audio into units."""
+
+    def __init__(
+        self,
+        model_name_or_card: Union[str, AssetCard],
+        kmeans_uri: str,
+        device: Device,
+        dtype: DataType = torch.float32,
+    ):
+        super().__init__()
+        self.wav2vec2_model: Wav2Vec2Model = Translator.load_model_for_inference(
+            load_wav2vec2_model, model_name_or_card, device, dtype
+        )
+        self.model = Wav2Vec2LayerOutputModel(self.wav2vec2_model)
+        self.device = device
+        self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
+        self.collate = Collater(pad_idx=2, pad_to_multiple=2)
+        self.kmeans_model = KmeansModel(kmeans_uri, device)
+
+    @torch.inference_mode()
+    def predict(
+        self,
+        audio: Union[str, torch.Tensor],
+        out_layer_idx: int,
+        sample_rate: int = 16000,
+    ) -> Tensor:
+        if isinstance(audio, str):
+            with Path(audio).open("rb") as fb:
+                block = MemoryBlock(fb.read())
+            decoded_audio = self.decode_audio(block)
+        else:
+            decoded_audio = {
+                "waveform": audio,
+                "sample_rate": sample_rate,
+                "format": -1,
+            }
+        src = self.collate(decoded_audio)["waveform"]
+        x = src["seqs"]
+        x = x.view(1, -1)
+        x = F.layer_norm(x, x.shape)
+        batch = SequenceBatch(seqs=x, seq_lens=src["seq_lens"])
+        features = self.model(batch, out_layer_idx).squeeze(0)
+        units = self.kmeans_model(features)
+        return units
+
+    @staticmethod
+    def resynthesize_audio(
+        units: Tensor, src_lang: str, device: Device, vocoder_name: str = "vocoder_36langs"
+    ):
+        def reduce_list(lst):
+            return [key for key, _ in groupby(lst)]
+
+        reduced_units = reduce_list(units.cpu().tolist())
+
+        vocoder: Vocoder = Translator.load_model_for_inference(
+            load_vocoder_model, vocoder_name, device, torch.float32
+        )
+        wav = vocoder(reduced_units, src_lang, spkr=-1, dur_prediction=True)
+        return wav
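
Taken together, a minimal programmatic sketch of the new extractor, using the same model name and K-Means URL as the script above (the audio path and the "eng" language code are only illustrative):

    import torch
    from seamless_communication.models.unit_extraction import UnitExtractor

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    extractor = UnitExtractor(
        "xlsr2_1b_v2",
        "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
        device=device,
    )
    # Matches the script: layer 35 is requested as index 35 - 1.
    units = extractor.predict("/path/to/audio.wav", out_layer_idx=35 - 1)

    # Optionally map the (run-length reduced) units back to a waveform.
    wav = UnitExtractor.resynthesize_audio(units, src_lang="eng", device=device)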

+ 130 - 0
src/seamless_communication/models/unit_extraction/wav2vec2_layer_output.py

@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from fairseq2.nn.transformer import TransformerEncoder, TransformerNormOrder
+from fairseq2.models.wav2vec2 import (
+    Wav2Vec2EncoderConfig,
+    Wav2Vec2Config,
+    Wav2Vec2Model,
+    create_wav2vec2_model,
+    Wav2Vec2Frontend,
+)
+from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.models.sequence import SequenceBatch
+
+
+from seamless_communication.assets import asset_store, download_manager
+
+
+import torch
+from typing import Optional
+
+from torch import Tensor
+import torch.nn as nn
+
+
+wav2vec2_archs = ArchitectureRegistry[Wav2Vec2Config]("wav2vec2")
+wav2vec2_arch = wav2vec2_archs.marker
+
+
+def _encoder_xlsr2_1b_v2() -> Wav2Vec2EncoderConfig:
+    layer_descs = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
+
+    return Wav2Vec2EncoderConfig(
+        model_dim=1280,
+        max_seq_len=4096,
+        feature_dim=512,
+        use_fbank=False,
+        first_pass_dropout_p=0.0,
+        layer_norm_features=False,
+        feature_extractor_layer_descs=layer_descs,
+        feature_extractor_bias=True,
+        feature_extractor_layer_norm_convs=True,
+        feature_grad_scale=1.0,
+        num_fbank_channels=0,
+        fbank_stride=0,
+        sample_fbank_every_k=0,
+        pos_encoder_type="conv",
+        pos_encoder_depth=1,
+        pos_conv_kernel_size=128,
+        num_pos_conv_groups=16,
+        use_conformer=False,
+        num_encoder_layers=48,
+        num_encoder_attn_heads=16,
+        ffn_inner_dim=5120,
+        dropout_p=0.1,
+        attn_dropout_p=0.1,
+        layer_drop_p=0.0,
+        norm_order=TransformerNormOrder.PRE,
+        depthwise_conv_kernel_size=0,
+    )
+
+
+@wav2vec2_arch("xlsr2_1b_v2")
+def _xlsr2_1b_v2() -> Wav2Vec2Config:
+    encoder_config = _encoder_xlsr2_1b_v2()
+
+    return Wav2Vec2Config(
+        encoder_config,
+        final_dim=1024,
+        final_proj_bias=True,
+        temporal_mask_span_len=10,
+        max_temporal_mask_prob=0.65,
+        spatial_mask_span_len=10,
+        max_spatial_mask_prob=0.0,
+        quantized_dim=1024,
+        num_codebooks=2,
+        num_codebook_entries=320,
+        codebook_sampling_temperature=(2, 0.1, 0.999995),
+        num_distractors=100,
+        logit_temp=0.1,
+        diversity_loss_weight=0.1,
+    )
+
+
+load_wav2vec2_model = Wav2Vec2Loader(
+    asset_store,
+    download_manager,
+    create_wav2vec2_model,
+    wav2vec2_archs,
+)
+
+
+class Wav2Vec2LayerOutputModel(nn.Module):
+    encoder_frontend: Wav2Vec2Frontend
+    encoder: TransformerEncoder
+
+    def __init__(self, w2v2: Wav2Vec2Model):
+        super().__init__()
+
+        self.encoder_frontend = w2v2.encoder_frontend
+        self.encoder = w2v2.encoder
+
+    @torch.inference_mode()
+    def forward(self, batch: SequenceBatch, out_layer_idx: int) -> Tensor:
+        """
+        :param batch:
+            The batch of sequences to process.
+        :param out_layer_idx:
+            The index of the encoder layer whose output should be returned.
+        """
+        seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.seq_lens)
+        w2v2_layer_output = None
+
+        def layer_output_hook(
+            layer_idx: int,
+            layer_output: Tensor,
+            layer_padding_mask: Optional[Tensor],
+            num_layers: int,
+        ) -> None:
+            nonlocal w2v2_layer_output
+
+            if layer_idx == out_layer_idx:
+                w2v2_layer_output = layer_output
+
+        _, _ = self.encoder(seqs, padding_mask, layer_output_hook)
+
+        assert w2v2_layer_output is not None
+        return w2v2_layer_output