Refactoring wav2vec2_layer_output and getting unit_extraction parity with fairseq.

Kaushik Ram Sadagopan · 2 years ago · parent commit e2da150258

+ 1 - 37
scripts/m4t/audio_to_units/audio_to_units.py

@@ -6,11 +6,7 @@
 import argparse
 import logging
 import torch
-import torchaudio
 from seamless_communication.models.unit_extraction import UnitExtractor
-from seamless_communication.models.inference import Translator
-from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
-from itertools import groupby
 
 
 logging.basicConfig(level=logging.INFO)
@@ -34,24 +30,12 @@ def main():
         help="Feature extraction model name (`xlsr2_1b_v2`)",
         help="Feature extraction model name (`xlsr2_1b_v2`)",
         default="xlsr2_1b_v2",
         default="xlsr2_1b_v2",
     )
     )
-    parser.add_argument(
-        "--vocoder_name", type=str, help="Vocoder name", default="vocoder_36langs"
-    )
     parser.add_argument(
     parser.add_argument(
         "--out_layer_number",
         "--out_layer_number",
         type=int,
         type=int,
         help="Layer number of the feature extraction model to pull out features from.",
         help="Layer number of the feature extraction model to pull out features from.",
         default=35,
         default=35,
     )
     )
-    parser.add_argument(
-        "--output_path",
-        type=str,
-        help="Path to save the generated audio.",
-        default=None,
-    )
-    parser.add_argument(
-        "--src_lang", type=str, help="Source language of the audio.", default=None
-    )
 
 
     args = parser.parse_args()
     args = parser.parse_args()
 
 
@@ -64,27 +48,7 @@ def main():
 
     unit_extractor = UnitExtractor(args.model_name, args.kmeans_uri, device=device)
     units = unit_extractor.predict(args.audio, args.out_layer_number - 1)
-
-    if args.output_path is not None:
-
-        if args.src_lang is None:
-            raise ValueError("src_lang must be provided to resynthesize the audio.")
-
-        def reduce_list(lst):
-            return [key for key, _ in groupby(lst)]
-
-        reduced_units = reduce_list(units.cpu().tolist())
-
-        vocoder: Vocoder = Translator.load_model_for_inference(
-            load_vocoder_model, args.vocoder_name, device, torch.float32
-        )
-        wav = vocoder(reduced_units, args.src_lang, spkr=-1, dur_prediction=True)
-
-        torchaudio.save(
-            args.output_path,
-            wav[0].cpu(),
-            sample_rate=16000,
-        )
+    logger.info(f"Converted to units: {units}")
 
 
 if __name__ == "__main__":
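
With the resynthesis path removed, the script is now a thin wrapper around UnitExtractor. A minimal sketch of the remaining flow, assuming the --audio and --kmeans_uri arguments defined in the unchanged part of the script; the k-means URI and the input path below are placeholders, not real values:

    import logging
    import torch
    from seamless_communication.models.unit_extraction import UnitExtractor

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    kmeans_uri = "<url-or-path-to-kmeans-centroids>"  # placeholder for the real k-means artifact

    unit_extractor = UnitExtractor("xlsr2_1b_v2", kmeans_uri, device=device)
    # --out_layer_number defaults to 35; predict() takes the zero-indexed layer.
    units = unit_extractor.predict("input_16khz.wav", out_layer_idx=34)
    logger.info(f"Converted to units: {units}")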

+ 3 - 15
src/seamless_communication/models/inference/translator.py

@@ -54,7 +54,6 @@ class Translator(nn.Module):
         vocoder_name_or_card: Union[str, AssetCard],
         device: Device,
         dtype: DataType,
-        sample_rate: int = 16000,
     ):
         super().__init__()
         # Load the model.
@@ -80,7 +79,6 @@ class Translator(nn.Module):
         self.vocoder: Vocoder = self.load_model_for_inference(
             load_vocoder_model, vocoder_name_or_card, device, torch.float32
         )
-        self.sample_rate = sample_rate
 
     @staticmethod
     def load_model_for_inference(
@@ -93,17 +91,6 @@ class Translator(nn.Module):
         model.eval()
         return model
 
-    @staticmethod
-    def load_model_for_inference(
-        load_model_fn: Any,
-        model_name_or_card: Union[str, AssetCard],
-        device: Device,
-        dtype: DataType,
-    ) -> nn.Module:
-        model = load_model_fn(model_name_or_card, device=device, dtype=dtype)
-        model.eval()
-        return model
-
     @classmethod
     def get_prediction(
         cls,
@@ -168,6 +155,7 @@ class Translator(nn.Module):
         src_lang: Optional[str] = None,
         spkr: Optional[int] = -1,
         ngram_filtering: bool = False,
+        sample_rate: int = 16000,
     ) -> Tuple[StringLike, Optional[List[Tensor]], Optional[int]]:
         """
         The main method used to perform inference on all tasks.
@@ -205,7 +193,7 @@ class Translator(nn.Module):
             else:
                 decoded_audio = {
                     "waveform": audio,
-                    "sample_rate": self.sample_rate,
+                    "sample_rate": sample_rate,
                     "format": -1,
                 }
             src = self.collate(self.convert_to_fbank(decoded_audio))["fbank"]
@@ -237,4 +225,4 @@ class Translator(nn.Module):
         else:
             units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
             wav_out = self.vocoder(units, tgt_lang, spkr, dur_prediction=True)
-            return text_out.sentences[0], wav_out, self.sample_rate
+            return text_out.sentences[0], wav_out, sample_rate
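
Since sample_rate is now a per-call argument rather than constructor state, a caller that passes a raw waveform declares its rate directly in predict(). A hedged sketch: the model and vocoder card names and the leading positional arguments (input, task string, tgt_lang) follow the public API as commonly used and are illustrative; only src_lang, spkr, ngram_filtering, and sample_rate are confirmed by the hunk above.

    import torch
    from seamless_communication.models.inference import Translator

    device, dtype = torch.device("cpu"), torch.float32
    translator = Translator("seamlessM4T_large", "vocoder_36langs", device, dtype)

    waveform = torch.randn(1, 32000)  # stand-in for 2 s of 16 kHz audio
    text, wavs, sr = translator.predict(
        waveform, "s2st", tgt_lang="eng", src_lang="fra", sample_rate=16000
    )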

+ 0 - 3
src/seamless_communication/models/unit_extraction/__init__.py

@@ -9,9 +9,6 @@ from seamless_communication.models.unit_extraction.unit_extraction import (
 from seamless_communication.models.unit_extraction.kmeans import (
     KmeansModel as KmeansModel,
 )
-from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
-    load_wav2vec2_layer_output_model as load_wav2vec2_layer_output_model,
-)
 from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
     Wav2Vec2LayerOutputModel as Wav2Vec2LayerOutputModel,
 )

+ 2 - 5
src/seamless_communication/models/unit_extraction/kmeans.py

@@ -7,7 +7,6 @@
 import torch
 from torch import Tensor, nn
 import numpy as np
-from pathlib import Path
 from fairseq2.typing import Device
 from seamless_communication.assets import download_manager
 
@@ -20,10 +19,8 @@ class KmeansModel(nn.Module):
         centroids_numpy = km_model.transpose()
         centroids = torch.from_numpy(centroids_numpy)
 
-        self.centroids = nn.Parameter(centroids, requires_grad=False).to(device)
-        self.centroid_norm = nn.Parameter(
-            (centroids**2).sum(0, keepdims=True), requires_grad=False
-        ).to(device)
+        self.centroids = centroids.to(device)
+        self.centroid_norm = (self.centroids**2).sum(0, keepdims=True)
 
     def forward(self, x: Tensor) -> Tensor:
         dist: Tensor = (
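
The forward body is cut off by the hunk above. For context, the buffers it uses (centroids of shape (dim, n_units) and centroid_norm of shape (1, n_units)) support the standard squared-distance expansion against the transposed centroid matrix. A hedged sketch of an equivalent assignment, not necessarily the exact expression in the file:

    # x: (T, dim) frame features; nearest-centroid index per frame.
    dist = (
        x.pow(2).sum(-1, keepdim=True)   # ||x||^2, shape (T, 1)
        - 2 * x.matmul(self.centroids)   # -2 x·c, shape (T, n_units)
        + self.centroid_norm             # ||c||^2, shape (1, n_units)
    )
    units = dist.argmin(dim=-1)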

+ 28 - 11
src/seamless_communication/models/unit_extraction/unit_extraction.py

@@ -4,25 +4,27 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List, Tuple, Union
+from typing import Union
 from pathlib import Path
 import torch
 
-from fairseq2.typing import Device
+from itertools import groupby
+from fairseq2.typing import DataType, Device
 from torch import Tensor, nn
 from fairseq2.data.audio import AudioDecoder
 from fairseq2.data import Collater
 import torch.nn.functional as F
-from fairseq2.data.typing import StringLike
 from fairseq2.memory import MemoryBlock
 from fairseq2.assets.card import AssetCard
 from fairseq2.models.sequence import SequenceBatch
+from fairseq2.models.wav2vec2 import Wav2Vec2Model
 from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
-    load_wav2vec2_layer_output_model,
+    load_wav2vec2_model,
     Wav2Vec2LayerOutputModel,
 )
 from seamless_communication.models.unit_extraction.kmeans import KmeansModel
 from seamless_communication.models.inference import Translator
+from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
 
 
 class UnitExtractor(nn.Module):
@@ -33,23 +35,25 @@ class UnitExtractor(nn.Module):
         model_name_or_card: Union[str, AssetCard],
         kmeans_uri: str,
         device: Device,
-        layer: int = 35,
+        dtype: DataType = torch.float32,
     ):
         super().__init__()
-        self.model: Wav2Vec2LayerOutputModel = Translator.load_model_for_inference(
-            load_wav2vec2_layer_output_model, model_name_or_card, device, torch.float32
+        self.wav2vec2_model: Wav2Vec2Model = Translator.load_model_for_inference(
+            load_wav2vec2_model, model_name_or_card, device, dtype
         )
+        self.model = Wav2Vec2LayerOutputModel(self.wav2vec2_model)
         self.device = device
         self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
         self.collate = Collater(pad_idx=2, pad_to_multiple=2)
         self.kmeans_model = KmeansModel(kmeans_uri, device)
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def predict(
         self,
         audio: Union[str, torch.Tensor],
         out_layer_idx: int,
-    ) -> Tuple[List[Tensor], int]:
+        sample_rate: int = 16000,
+    ) -> Tensor:
         if isinstance(audio, str):
             with Path(audio).open("rb") as fb:
                 block = MemoryBlock(fb.read())
@@ -57,14 +61,27 @@ class UnitExtractor(nn.Module):
         else:
             decoded_audio = {
                 "waveform": audio,
-                "sample_rate": 16000.0,
+                "sample_rate": sample_rate,
                 "format": -1,
             }
         src = self.collate(decoded_audio)["waveform"]
         x = src["seqs"]
-        x = F.layer_norm(x, x.shape)
         x = x.view(1, -1)
+        x = F.layer_norm(x, x.shape)
         batch = SequenceBatch(seqs=x, seq_lens=src["seq_lens"])
         features = self.model(batch, out_layer_idx).squeeze(0)
         units = self.kmeans_model(features)
         return units
+
+    @staticmethod
+    def resynthesize_audio(units, src_lang, device, vocoder_name="vocoder_36langs"):
+        def reduce_list(lst):
+            return [key for key, _ in groupby(lst)]
+
+        reduced_units = reduce_list(units.cpu().tolist())
+
+        vocoder: Vocoder = Translator.load_model_for_inference(
+            load_vocoder_model, vocoder_name, device, torch.float32
+        )
+        wav = vocoder(reduced_units, src_lang, spkr=-1, dur_prediction=True)
+        return wav
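
A hedged end-to-end sketch of the refactored API: extract units from a waveform, then optionally vocode them back with the new static helper. The k-means URI is a placeholder, and the waveform and language are illustrative; only the method names and signatures shown in the diff above are assumed.

    import torch
    from seamless_communication.models.unit_extraction import UnitExtractor

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    kmeans_uri = "<url-or-path-to-kmeans-centroids>"  # placeholder

    extractor = UnitExtractor("xlsr2_1b_v2", kmeans_uri, device=device)
    waveform = torch.randn(16000)  # stand-in for 1 s of 16 kHz audio
    units = extractor.predict(waveform, out_layer_idx=34, sample_rate=16000)

    # Collapses runs of duplicate units, loads the vocoder, and synthesizes audio.
    wav = UnitExtractor.resynthesize_audio(units, src_lang="eng", device=device)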

+ 25 - 71
src/seamless_communication/models/unit_extraction/wav2vec2_layer_output.py

@@ -3,19 +3,17 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from fairseq2.nn.transformer import TransformerNormOrder
+from fairseq2.nn.transformer import TransformerEncoder, TransformerNormOrder
 from fairseq2.models.wav2vec2 import (
     Wav2Vec2EncoderConfig,
     Wav2Vec2Config,
     wav2vec2_arch,
     Wav2Vec2Model,
-    Wav2Vec2Builder,
-    Wav2Vec2EncoderBuilder,
+    create_wav2vec2_model,
+    Wav2Vec2Frontend,
 )
 from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
-from fairseq2.models.utils.model_loader import ModelConfigLoader
-from fairseq2.typing import DataType, Device
 from fairseq2.models.sequence import SequenceBatch
 
 
@@ -26,6 +24,8 @@ import torch
 from typing import Optional
 
 from torch import Tensor
+import torch.nn as nn
+
 
 wav2vec2_archs = ArchitectureRegistry[Wav2Vec2Config]("wav2vec2")
 wav2vec2_arch = wav2vec2_archs.marker
@@ -86,14 +86,31 @@ def _xlsr2_1b_v2() -> Wav2Vec2Config:
     )
 
 
-class Wav2Vec2LayerOutputModel(Wav2Vec2Model):
-    @torch.no_grad()
+load_wav2vec2_model = Wav2Vec2Loader(
+    asset_store,
+    download_manager,
+    create_wav2vec2_model,
+    wav2vec2_archs,
+)
+
+
+class Wav2Vec2LayerOutputModel(nn.Module):
+    encoder_frontend: Wav2Vec2Frontend
+    encoder: TransformerEncoder
+
+    def __init__(self, w2v2: Wav2Vec2Model):
+        super().__init__()
+
+        self.encoder_frontend = w2v2.encoder_frontend
+        self.encoder = w2v2.encoder
+
+    @torch.inference_mode()
     def forward(self, batch: SequenceBatch, out_layer_idx: int):
         """
         :param batch:
             The batch of sequences to process.
         """
-        seqs, padding_mask, _, _ = self.run_frontend(batch.seqs, batch.seq_lens)
+        seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.seq_lens)
         w2v2_layer_output = None
 
         def layer_output_hook(
@@ -107,70 +124,7 @@ class Wav2Vec2LayerOutputModel(Wav2Vec2Model):
             if layer_idx == out_layer_idx:
                 w2v2_layer_output = layer_output
 
-        # TODO: Should pad for fp16?
         _, _ = self.encoder(seqs, padding_mask, layer_output_hook)
 
         assert w2v2_layer_output is not None
         return w2v2_layer_output
-
-
-class Wav2Vec2LayerOutputBuilder(Wav2Vec2Builder):
-    def build_model(self) -> Wav2Vec2LayerOutputModel:
-        """Build a model."""
-        encoder_frontend = self.encoder_builder.build_frontend()
-
-        encoder = self.encoder_builder.build_encoder()
-
-        masker = self.build_masker()
-
-        quantizer = self.build_quantizer()
-
-        return Wav2Vec2LayerOutputModel(
-            encoder_frontend,
-            encoder,
-            masker,
-            quantizer,
-            self.config.final_dim,
-            self.config.final_proj_bias,
-            self.config.num_distractors,
-            self.config.logit_temp,
-            self.config.diversity_loss_weight,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-
-def create_wav2vec2_layer_output_model(
-    config: Wav2Vec2Config,
-    device: Optional[Device] = None,
-    dtype: Optional[DataType] = None,
-) -> Wav2Vec2Model:
-    """Create a wav2vec 2.0 model.
-
-    :param config:
-        The configuration to use.
-    :param device:
-        The device on which to initialize modules.
-    :param dtype:
-        The data type of module parameters and buffers.
-    """
-    encoder_builder = Wav2Vec2EncoderBuilder(config.encoder_config, device, dtype)
-
-    return Wav2Vec2LayerOutputBuilder(
-        config, encoder_builder, device, dtype
-    ).build_model()
-
-
-load_wav2vec2_layer_output_config = ModelConfigLoader[Wav2Vec2Config](
-    asset_store, wav2vec2_archs
-)
-
-load_wav2vec2_layer_output_model = Wav2Vec2Loader(
-    asset_store,
-    download_manager,
-    create_wav2vec2_layer_output_model,
-    wav2vec2_archs,
-    # `weight_norm` used in `Wav2Vec2PositionEncoder` does not support meta
-    # initialization.
-    use_meta=False,
-)
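
With the custom builder and loader gone, the extractor wraps a stock Wav2Vec2Model's frontend and encoder and records one layer's output through the encoder's per-layer hook. A hedged usage sketch, assuming load_wav2vec2_model as defined above and the loader call pattern shown in Translator.load_model_for_inference; input shapes are illustrative.

    import torch
    from fairseq2.models.sequence import SequenceBatch
    from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
        Wav2Vec2LayerOutputModel,
        load_wav2vec2_model,
    )

    device, dtype = torch.device("cpu"), torch.float32
    w2v2 = load_wav2vec2_model("xlsr2_1b_v2", device=device, dtype=dtype)
    w2v2.eval()
    layer_model = Wav2Vec2LayerOutputModel(w2v2)

    x = torch.randn(1, 16000)  # 1 s of 16 kHz audio
    batch = SequenceBatch(seqs=x, seq_lens=torch.tensor([x.shape[1]]))

    # out_layer_idx is zero-indexed, so 34 selects the 35th transformer layer.
    features = layer_model(batch, out_layer_idx=34)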