
Refactoring wav2vec2_layer_output and bringing unit_extraction to parity with fairseq.

Kaushik Ram Sadagopan, 2 years ago
commit e2da150258

+ 1 - 37
scripts/m4t/audio_to_units/audio_to_units.py

@@ -6,11 +6,7 @@
 import argparse
 import logging
 import torch
-import torchaudio
 from seamless_communication.models.unit_extraction import UnitExtractor
-from seamless_communication.models.inference import Translator
-from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
-from itertools import groupby
 
 
 logging.basicConfig(level=logging.INFO)
@@ -34,24 +30,12 @@ def main():
         help="Feature extraction model name (`xlsr2_1b_v2`)",
         default="xlsr2_1b_v2",
     )
-    parser.add_argument(
-        "--vocoder_name", type=str, help="Vocoder name", default="vocoder_36langs"
-    )
     parser.add_argument(
         "--out_layer_number",
         type=int,
         help="Layer number of the feature extraction model to pull out features from.",
         default=35,
     )
-    parser.add_argument(
-        "--output_path",
-        type=str,
-        help="Path to save the generated audio.",
-        default=None,
-    )
-    parser.add_argument(
-        "--src_lang", type=str, help="Source language of the audio.", default=None
-    )
 
     args = parser.parse_args()
 
@@ -64,27 +48,7 @@ def main():
 
     unit_extractor = UnitExtractor(args.model_name, args.kmeans_uri, device=device)
     units = unit_extractor.predict(args.audio, args.out_layer_number - 1)
-
-    if args.output_path is not None:
-
-        if args.src_lang is None:
-            raise ValueError("src_lang must be provided to resynthesize the audio.")
-
-        def reduce_list(lst):
-            return [key for key, _ in groupby(lst)]
-
-        reduced_units = reduce_list(units.cpu().tolist())
-
-        vocoder: Vocoder = Translator.load_model_for_inference(
-            load_vocoder_model, args.vocoder_name, device, torch.float32
-        )
-        wav = vocoder(reduced_units, args.src_lang, spkr=-1, dur_prediction=True)
-
-        torchaudio.save(
-            args.output_path,
-            wav[0].cpu(),
-            sample_rate=16000,
-        )
+    logger.info(f"Converted to units: {units}")
 
 
 if __name__ == "__main__":
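
With the vocoder path removed, the script now only extracts and logs units; resynthesis moves into `UnitExtractor` (see the unit_extraction.py diff below). A minimal sketch of the resulting flow, assuming a placeholder kmeans URI in place of the script's `--kmeans_uri` argument:

```python
# Minimal sketch of the slimmed-down script flow after this change.
# "kmeans.npy" is a placeholder; the real script takes --kmeans_uri,
# --model_name, --audio and --out_layer_number from argparse.
import torch
from seamless_communication.models.unit_extraction import UnitExtractor

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

unit_extractor = UnitExtractor("xlsr2_1b_v2", "kmeans.npy", device=device)
# The CLI layer number is 1-based, so layer 35 maps to index 34.
units = unit_extractor.predict("audio.wav", 35 - 1)
print(units)
```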

+ 3 - 15
src/seamless_communication/models/inference/translator.py

@@ -54,7 +54,6 @@ class Translator(nn.Module):
         vocoder_name_or_card: Union[str, AssetCard],
         device: Device,
         dtype: DataType,
-        sample_rate: int = 16000,
     ):
         super().__init__()
         # Load the model.
@@ -80,7 +79,6 @@ class Translator(nn.Module):
         self.vocoder: Vocoder = self.load_model_for_inference(
             load_vocoder_model, vocoder_name_or_card, device, torch.float32
         )
-        self.sample_rate = sample_rate
 
     @staticmethod
     def load_model_for_inference(
@@ -93,17 +91,6 @@ class Translator(nn.Module):
         model.eval()
         return model
 
-    @staticmethod
-    def load_model_for_inference(
-        load_model_fn: Any,
-        model_name_or_card: Union[str, AssetCard],
-        device: Device,
-        dtype: DataType,
-    ) -> nn.Module:
-        model = load_model_fn(model_name_or_card, device=device, dtype=dtype)
-        model.eval()
-        return model
-
     @classmethod
     def get_prediction(
         cls,
@@ -168,6 +155,7 @@ class Translator(nn.Module):
         src_lang: Optional[str] = None,
         spkr: Optional[int] = -1,
         ngram_filtering: bool = False,
+        sample_rate: int = 16000,
     ) -> Tuple[StringLike, Optional[List[Tensor]], Optional[int]]:
         """
         The main method used to perform inference on all tasks.
@@ -205,7 +193,7 @@ class Translator(nn.Module):
             else:
                 decoded_audio = {
                     "waveform": audio,
-                    "sample_rate": self.sample_rate,
+                    "sample_rate": sample_rate,
                     "format": -1,
                 }
             src = self.collate(self.convert_to_fbank(decoded_audio))["fbank"]
@@ -237,4 +225,4 @@ class Translator(nn.Module):
         else:
             units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
             wav_out = self.vocoder(units, tgt_lang, spkr, dur_prediction=True)
-            return text_out.sentences[0], wav_out, self.sample_rate
+            return text_out.sentences[0], wav_out, sample_rate
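
The sample rate is now supplied per call to `predict` instead of being fixed at construction. A hedged sketch of a raw-waveform call; the `Translator` construction and the leading `input`/`task`/`tgt_lang` arguments are assumptions, since this hunk does not show them:

```python
# Hedged sketch: only the sample_rate keyword is confirmed by this diff;
# the Translator construction and the first predict() arguments are assumed.
import torch
import torchaudio
from seamless_communication.models.inference import Translator

translator = Translator(
    "seamlessM4T_large", "vocoder_36langs", torch.device("cuda:0"), torch.float16
)

waveform, sr = torchaudio.load("input.wav")
text, wavs, out_sr = translator.predict(waveform, "s2st", "eng", sample_rate=sr)
```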

+ 0 - 3
src/seamless_communication/models/unit_extraction/__init__.py

@@ -9,9 +9,6 @@ from seamless_communication.models.unit_extraction.unit_extraction import (
 from seamless_communication.models.unit_extraction.kmeans import (
     KmeansModel as KmeansModel,
 )
-from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
-    load_wav2vec2_layer_output_model as load_wav2vec2_layer_output_model,
-)
 from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
     Wav2Vec2LayerOutputModel as Wav2Vec2LayerOutputModel,
 )

+ 2 - 5
src/seamless_communication/models/unit_extraction/kmeans.py

@@ -7,7 +7,6 @@
 import torch
 from torch import Tensor, nn
 import numpy as np
-from pathlib import Path
 from fairseq2.typing import Device
 from seamless_communication.assets import download_manager
 
@@ -20,10 +19,8 @@ class KmeansModel(nn.Module):
         centroids_numpy = km_model.transpose()
         centroids = torch.from_numpy(centroids_numpy)
 
-        self.centroids = nn.Parameter(centroids, requires_grad=False).to(device)
-        self.centroid_norm = nn.Parameter(
-            (centroids**2).sum(0, keepdims=True), requires_grad=False
-        ).to(device)
+        self.centroids = centroids.to(device)
+        self.centroid_norm = (self.centroids**2).sum(0, keepdims=True)
 
     def forward(self, x: Tensor) -> Tensor:
         dist: Tensor = (

+ 28 - 11
src/seamless_communication/models/unit_extraction/unit_extraction.py

@@ -4,25 +4,27 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List, Tuple, Union
+from typing import Union
 from pathlib import Path
 import torch
 
-from fairseq2.typing import Device
+from itertools import groupby
+from fairseq2.typing import DataType, Device
 from torch import Tensor, nn
 from fairseq2.data.audio import AudioDecoder
 from fairseq2.data import Collater
 import torch.nn.functional as F
-from fairseq2.data.typing import StringLike
 from fairseq2.memory import MemoryBlock
 from fairseq2.assets.card import AssetCard
 from fairseq2.models.sequence import SequenceBatch
+from fairseq2.models.wav2vec2 import Wav2Vec2Model
 from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
-    load_wav2vec2_layer_output_model,
+    load_wav2vec2_model,
     Wav2Vec2LayerOutputModel,
 )
 from seamless_communication.models.unit_extraction.kmeans import KmeansModel
 from seamless_communication.models.inference import Translator
+from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
 
 
 class UnitExtractor(nn.Module):
@@ -33,23 +35,25 @@ class UnitExtractor(nn.Module):
         model_name_or_card: Union[str, AssetCard],
         kmeans_uri: str,
         device: Device,
-        layer: int = 35,
+        dtype: DataType = torch.float32,
     ):
         super().__init__()
-        self.model: Wav2Vec2LayerOutputModel = Translator.load_model_for_inference(
-            load_wav2vec2_layer_output_model, model_name_or_card, device, torch.float32
+        self.wav2vec2_model: Wav2Vec2Model = Translator.load_model_for_inference(
+            load_wav2vec2_model, model_name_or_card, device, dtype
         )
+        self.model = Wav2Vec2LayerOutputModel(self.wav2vec2_model)
         self.device = device
         self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
         self.collate = Collater(pad_idx=2, pad_to_multiple=2)
         self.kmeans_model = KmeansModel(kmeans_uri, device)
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def predict(
         self,
         audio: Union[str, torch.Tensor],
         out_layer_idx: int,
-    ) -> Tuple[List[Tensor], int]:
+        sample_rate: int = 16000,
+    ) -> Tensor:
         if isinstance(audio, str):
             with Path(audio).open("rb") as fb:
                 block = MemoryBlock(fb.read())
@@ -57,14 +61,27 @@ class UnitExtractor(nn.Module):
         else:
             decoded_audio = {
                 "waveform": audio,
-                "sample_rate": 16000.0,
+                "sample_rate": sample_rate,
                 "format": -1,
             }
         src = self.collate(decoded_audio)["waveform"]
         x = src["seqs"]
-        x = F.layer_norm(x, x.shape)
         x = x.view(1, -1)
+        x = F.layer_norm(x, x.shape)
         batch = SequenceBatch(seqs=x, seq_lens=src["seq_lens"])
         features = self.model(batch, out_layer_idx).squeeze(0)
         units = self.kmeans_model(features)
         return units
+
+    @staticmethod
+    def resynthesize_audio(units, src_lang, device, vocoder_name="vocoder_36langs"):
+        def reduce_list(lst):
+            return [key for key, _ in groupby(lst)]
+
+        reduced_units = reduce_list(units.cpu().tolist())
+
+        vocoder: Vocoder = Translator.load_model_for_inference(
+            load_vocoder_model, vocoder_name, device, torch.float32
+        )
+        wav = vocoder(reduced_units, src_lang, spkr=-1, dur_prediction=True)
+        return wav
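
The resynthesis block deleted from audio_to_units.py is now available as `UnitExtractor.resynthesize_audio`. A minimal sketch pairing it with `predict`, where the kmeans URI and language code are placeholders rather than values from this diff:

```python
# Minimal sketch: extract units, then resynthesize audio from them.
# "kmeans.npy" and "eng" are placeholder values, not taken from this diff.
import torch
import torchaudio
from seamless_communication.models.unit_extraction import UnitExtractor

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

extractor = UnitExtractor("xlsr2_1b_v2", "kmeans.npy", device=device)
units = extractor.predict("audio.wav", out_layer_idx=34)  # 35th layer, 0-based

wav = UnitExtractor.resynthesize_audio(units, src_lang="eng", device=device)
torchaudio.save("resynthesized.wav", wav[0].cpu(), sample_rate=16000)
```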

+ 25 - 71
src/seamless_communication/models/unit_extraction/wav2vec2_layer_output.py

@@ -3,19 +3,17 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from fairseq2.nn.transformer import TransformerNormOrder
+from fairseq2.nn.transformer import TransformerEncoder, TransformerNormOrder
 from fairseq2.models.wav2vec2 import (
     Wav2Vec2EncoderConfig,
     Wav2Vec2Config,
     wav2vec2_arch,
     Wav2Vec2Model,
-    Wav2Vec2Builder,
-    Wav2Vec2EncoderBuilder,
+    create_wav2vec2_model,
+    Wav2Vec2Frontend,
 )
 from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
-from fairseq2.models.utils.model_loader import ModelConfigLoader
-from fairseq2.typing import DataType, Device
 from fairseq2.models.sequence import SequenceBatch
 
 
@@ -26,6 +24,8 @@ import torch
 from typing import Optional
 
 from torch import Tensor
+import torch.nn as nn
+
 
 wav2vec2_archs = ArchitectureRegistry[Wav2Vec2Config]("wav2vec2")
 wav2vec2_arch = wav2vec2_archs.marker
@@ -86,14 +86,31 @@ def _xlsr2_1b_v2() -> Wav2Vec2Config:
     )
 
 
-class Wav2Vec2LayerOutputModel(Wav2Vec2Model):
-    @torch.no_grad()
+load_wav2vec2_model = Wav2Vec2Loader(
+    asset_store,
+    download_manager,
+    create_wav2vec2_model,
+    wav2vec2_archs,
+)
+
+
+class Wav2Vec2LayerOutputModel(nn.Module):
+    encoder_frontend: Wav2Vec2Frontend
+    encoder: TransformerEncoder
+
+    def __init__(self, w2v2: Wav2Vec2Model):
+        super().__init__()
+
+        self.encoder_frontend = w2v2.encoder_frontend
+        self.encoder = w2v2.encoder
+
+    @torch.inference_mode()
     def forward(self, batch: SequenceBatch, out_layer_idx: int):
         """
         :param batch:
             The batch of sequences to process.
         """
-        seqs, padding_mask, _, _ = self.run_frontend(batch.seqs, batch.seq_lens)
+        seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.seq_lens)
         w2v2_layer_output = None
 
         def layer_output_hook(
             layer_idx: int,
             layer_output: Tensor,
             layer_padding_mask: Optional[Tensor],
             num_layers: int,
         ) -> None:
             nonlocal w2v2_layer_output

             if layer_idx == out_layer_idx:
                 w2v2_layer_output = layer_output
 
-        # TODO: Should pad for fp16?
         _, _ = self.encoder(seqs, padding_mask, layer_output_hook)
 
         assert w2v2_layer_output is not None
         return w2v2_layer_output
-
-
-class Wav2Vec2LayerOutputBuilder(Wav2Vec2Builder):
-    def build_model(self) -> Wav2Vec2LayerOutputModel:
-        """Build a model."""
-        encoder_frontend = self.encoder_builder.build_frontend()
-
-        encoder = self.encoder_builder.build_encoder()
-
-        masker = self.build_masker()
-
-        quantizer = self.build_quantizer()
-
-        return Wav2Vec2LayerOutputModel(
-            encoder_frontend,
-            encoder,
-            masker,
-            quantizer,
-            self.config.final_dim,
-            self.config.final_proj_bias,
-            self.config.num_distractors,
-            self.config.logit_temp,
-            self.config.diversity_loss_weight,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-
-def create_wav2vec2_layer_output_model(
-    config: Wav2Vec2Config,
-    device: Optional[Device] = None,
-    dtype: Optional[DataType] = None,
-) -> Wav2Vec2Model:
-    """Create a wav2vec 2.0 model.
-
-    :param config:
-        The configuration to use.
-    :param device:
-        The device on which to initialize modules.
-    :param dtype:
-        The data type of module parameters and buffers.
-    """
-    encoder_builder = Wav2Vec2EncoderBuilder(config.encoder_config, device, dtype)
-
-    return Wav2Vec2LayerOutputBuilder(
-        config, encoder_builder, device, dtype
-    ).build_model()
-
-
-load_wav2vec2_layer_output_config = ModelConfigLoader[Wav2Vec2Config](
-    asset_store, wav2vec2_archs
-)
-
-load_wav2vec2_layer_output_model = Wav2Vec2Loader(
-    asset_store,
-    download_manager,
-    create_wav2vec2_layer_output_model,
-    wav2vec2_archs,
-    # `weight_norm` used in `Wav2Vec2PositionEncoder` does not support meta
-    # initialization.
-    use_meta=False,
-)
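
With the builder subclass gone, the layer-output model is a thin wrapper composed around a stock `Wav2Vec2Model` loaded via the standard fairseq2 loader. A minimal sketch of the new composition with a dummy input batch (this mirrors what `UnitExtractor` does internally; the waveform and layer index are illustrative):

```python
# Minimal sketch: load a stock wav2vec2 model, wrap it, and pull features
# from a chosen encoder layer. The dummy waveform is only illustrative.
import torch
from fairseq2.models.sequence import SequenceBatch
from seamless_communication.models.unit_extraction import Wav2Vec2LayerOutputModel
from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
    load_wav2vec2_model,
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

w2v2 = load_wav2vec2_model("xlsr2_1b_v2", device=device, dtype=torch.float32)
w2v2.eval()
model = Wav2Vec2LayerOutputModel(w2v2)

seqs = torch.randn(1, 16000, device=device)  # ~1 s of 16 kHz audio
batch = SequenceBatch(seqs=seqs, seq_lens=torch.tensor([16000], device=device))
features = model(batch, out_layer_idx=34)  # output of the 35th encoder layer
```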