Implementing the UnitExtractor module.

Kaushik Ram Sadagopan 2 years ago
parent
commit 3db2896fa1

+ 1 - 1
src/seamless_communication/assets/cards/xlsr2_1b_v2.yaml

@@ -7,4 +7,4 @@
 name: xlsr2_1b_v2
 model_type: wav2vec2
 model_arch: xlsr2_1b_v2
-checkpoint: "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/xlsr2_1b_v2.pt"
+checkpoint: "file://private/home/changhan/data/models/wav2vec2/xlsr2_1b_v2.pt"
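
Since the card is resolved by name at load time, pointing the checkpoint field at a file:// URI redirects every loader that consumes this card. A minimal sketch of the effect, assuming the loader introduced later in this commit (load_wav2vec2_layer_output_model) resolves "xlsr2_1b_v2" through the card:

    import torch

    from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
        load_wav2vec2_layer_output_model,
    )

    # The name is looked up in the asset cards; with the change above, the
    # checkpoint is fetched from the local file:// path instead of the CDN.
    model = load_wav2vec2_layer_output_model(
        "xlsr2_1b_v2", device=torch.device("cpu"), dtype=torch.float32
    )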

+ 18 - 7
src/seamless_communication/models/inference/translator.py

@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -15,7 +15,7 @@ from fairseq2.data.text.text_tokenizer import TextTokenizer
 from fairseq2.data.typing import StringLike
 from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
 from fairseq2.memory import MemoryBlock
-from fairseq2.typing import Device
+from fairseq2.typing import DataType, Device
 from torch import Tensor
 from enum import Enum, auto
 
@@ -54,10 +54,9 @@ class Translator(nn.Module):
     ):
         super().__init__()
         # Load the model.
-        self.model: UnitYModel = load_unity_model(
-            model_name_or_card, device=device, dtype=torch.float16
+        self.model: UnitYModel = self.load_model_for_inference(
+            load_unity_model, model_name_or_card, device, torch.float16
         )
-        self.model.eval()
         self.text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
         self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
         self.device = device
@@ -74,10 +73,22 @@ class Translator(nn.Module):
             pad_idx=self.text_tokenizer.vocab_info.pad_idx, pad_to_multiple=2
         )
         # Load the vocoder.
-        self.vocoder: Vocoder = load_vocoder_model(vocoder_name_or_card, device=device)
-        self.vocoder.eval()
+        self.vocoder = self.load_model_for_inference(
+            load_vocoder_model, vocoder_name_or_card, device, torch.float32
+        )
         self.sr = sample_rate
 
+    @staticmethod
+    def load_model_for_inference(
+        load_model_fn: Any,
+        model_name_or_card: Union[str, AssetCard],
+        device: Device,
+        dtype: DataType,
+    ) -> nn.Module:
+        model = load_model_fn(model_name_or_card, device=device, dtype=dtype)
+        model.eval()
+        return model
+
     @classmethod
     def get_prediction(
         cls,
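
The new static helper factors out the load-then-eval pattern used for both the UnitY model and the vocoder. A minimal usage sketch; any load_*_model callable with a (name_or_card, device=..., dtype=...) signature should work, and "seamlessM4T_large" is assumed here as an example card name:

    import torch

    from fairseq2.typing import Device
    from seamless_communication.models.inference import Translator
    from seamless_communication.models.unity import load_unity_model

    # Loads the model on the requested device/dtype and switches it to eval mode.
    model = Translator.load_model_for_inference(
        load_unity_model, "seamlessM4T_large", Device("cuda:0"), torch.float16
    )
    assert not model.training  # eval() was applied inside the helper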

+ 4 - 16
src/seamless_communication/models/unit_extraction/kmeans.py

@@ -13,14 +13,12 @@ from seamless_communication.assets import download_manager
 
 
 class KmeansModel(nn.Module):
-    @staticmethod
-    def load_model(km_path: Path, device: Device) -> "KmeansModel":
+    def __init__(self, kmeans_uri: str, device: Device):
+        super().__init__()
+        km_path = download_manager.download_checkpoint(kmeans_uri, kmeans_uri)
         km_model = np.load(km_path)
         centroids_numpy = km_model.transpose()
-        return KmeansModel(torch.from_numpy(centroids_numpy), device)
-
-    def __init__(self, centroids: Tensor, device: Device):
-        super().__init__()
+        centroids = torch.from_numpy(centroids_numpy)
 
         self.centroids = nn.Parameter(centroids, requires_grad=False).to(device)
         self.centroid_norm = nn.Parameter(
@@ -34,13 +32,3 @@ class KmeansModel(nn.Module):
             + self.centroid_norm
         )
         return dist.argmin(dim=-1)
-
-
-if __name__ == "__main__":
-    kmeans_uri = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
-    km_path = download_manager.download_checkpoint(kmeans_uri, "kmeans_10k")
-    device = torch.device("cuda:1")
-    model = KmeansModel.load_model(km_path, device)
-    t = torch.randn((1000, 1280), device=device, dtype=torch.float32)
-    units = model(t)
-    print(units)
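
With the constructor change, the smoke test removed above maps onto the new API as follows; a minimal sketch reusing the same URI and shapes as the deleted __main__ block:

    import torch

    from seamless_communication.models.unit_extraction.kmeans import KmeansModel

    kmeans_uri = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
    device = torch.device("cuda:1")

    # The model now downloads the centroids itself and keeps them as frozen
    # parameters on the target device.
    model = KmeansModel(kmeans_uri, device)
    features = torch.randn((1000, 1280), device=device, dtype=torch.float32)
    units = model(features)  # nearest-centroid index per input row
    print(units.shape)       # torch.Size([1000])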

+ 34 - 10
src/seamless_communication/models/unit_extraction/unit_extraction.py

@@ -11,32 +11,45 @@ import torch
 from fairseq2.typing import Device
 from torch import Tensor, nn
 from fairseq2.data.audio import AudioDecoder
-from fairseq2.models.wav2vec2 import load_wav2vec2_model, Wav2Vec2Model
+from fairseq2.data import Collater
+import torch.nn.functional as F
 from fairseq2.data.typing import StringLike
 from fairseq2.memory import MemoryBlock
+from fairseq2.assets.card import AssetCard
+from fairseq2.models.sequence import SequenceBatch
+from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
+    load_wav2vec2_layer_output_model,
+    Wav2Vec2LayerOutputModel,
+)
+from seamless_communication.models.unit_extraction.kmeans import KmeansModel
+from seamless_communication.models.inference import Translator
 
 
 class UnitExtractor(nn.Module):
-    """Vocoder interface to run vocoder models through hub. Currently we only support unit vocoder"""
+    """Unit Extractor which converts raw audio into units."""
 
     def __init__(
         self,
-        model: Wav2Vec2Model,
+        model_name_or_card: Union[str, AssetCard],
+        kmeans_uri: str,
         device: Device,
         layer: int = 35,
     ):
         super().__init__()
-        self.model = model
-        self.model.eval()
-        self.model.to(device=device)
+        self.model: Wav2Vec2LayerOutputModel = Translator.load_model_for_inference(
+            load_wav2vec2_layer_output_model, model_name_or_card, device, torch.float32
+        )
         self.device = device
         self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
+        self.collate = Collater(pad_idx=2, pad_to_multiple=2)
+        self.kmeans_model = KmeansModel(kmeans_uri, device)
 
+    @torch.no_grad()
     def predict(
         self,
         audio: Union[str, torch.Tensor],
+        out_layer_idx: int,
     ) -> Tuple[List[Tensor], int]:
-
         if isinstance(audio, str):
             with Path(audio).open("rb") as fb:
                 block = MemoryBlock(fb.read())
@@ -47,10 +60,21 @@ class UnitExtractor(nn.Module):
                 "sample_rate": 16000.0,
                 "format": -1,
             }
+        src = self.collate(decoded_audio)["waveform"]
+        x = src["seqs"]
+        x = F.layer_norm(x, x.shape)
+        x = x.view(1, -1)
+        batch = SequenceBatch(seqs=x, seq_lens=src["seq_lens"])
+        features = self.model(batch, out_layer_idx).squeeze(0)
+        units = self.kmeans_model(features)
+        return units
 
 
 if __name__ == "__main__":
+    kmeans_uri = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
     audio = "/large_experiments/seamless/ust/data/TTS/vocoder_training/audio_wavs/multi_spkr/eng/eng_LJSpeech-1.1_0/LJ003-0001.wav"
-    model = load_wav2vec2_model("xlsr_1b_v2")
-    unit_extractor = UnitExtractor(model, device=Device("cuda:0"))
-    wav, sr = unit_extractor.predict(audio)
+    device = torch.device("cuda:1")
+    unit_extractor = UnitExtractor("xlsr2_1b_v2", kmeans_uri, device=device)
+    out_layer_number = 35
+    units = unit_extractor.predict(audio, out_layer_number - 1)
+    print(units.shape, units.dtype, units.device)
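
Note that predict() takes a zero-based encoder layer index, which is why layer 35 is requested as 35 - 1 above. A minimal sketch with a consistent device, assuming UnitExtractor is re-exported from the package and using an illustrative audio path:

    import torch

    from fairseq2.typing import Device
    from seamless_communication.models.unit_extraction import UnitExtractor

    kmeans_uri = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
    device = Device("cuda:0")

    extractor = UnitExtractor("xlsr2_1b_v2", kmeans_uri, device=device)
    # Accepts a file path or a waveform tensor; the layer index is zero-based.
    units = extractor.predict("/path/to/audio.wav", out_layer_idx=35 - 1)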

+ 0 - 39
src/seamless_communication/models/unit_extraction/wav2vec2_layer_output.py

@@ -93,9 +93,7 @@ class Wav2Vec2LayerOutputModel(Wav2Vec2Model):
         :param batch:
             The batch of sequences to process.
         """
-        print(f"Before run_frontend: {batch.seqs.sum()}")
         seqs, padding_mask, _, _ = self.run_frontend(batch.seqs, batch.seq_lens)
-        print(f"After run_frontend: {seqs.sum()}")
         w2v2_layer_output = None
 
         def layer_output_hook(
@@ -107,7 +105,6 @@ class Wav2Vec2LayerOutputModel(Wav2Vec2Model):
             nonlocal w2v2_layer_output
 
             if layer_idx == out_layer_idx:
-                print(f"{layer_idx=}")
                 w2v2_layer_output = layer_output
 
         # TODO: Should pad for fp16?
@@ -177,39 +174,3 @@ load_wav2vec2_layer_output_model = Wav2Vec2Loader(
     # initialization.
     use_meta=False,
 )
-
-
-if __name__ == "__main__":
-    from fairseq2.data import Collater
-    from fairseq2.memory import MemoryBlock
-    from fairseq2.data.audio import AudioDecoder
-    from pathlib import Path
-
-    audio = "/large_experiments/seamless/ust/data/TTS/vocoder_training/audio_wavs/multi_spkr/eng/eng_LJSpeech-1.1_0/LJ003-0001.wav"
-    out_layer_idx = 34
-    device = torch.device("cuda:1")
-    decode_audio = AudioDecoder(dtype=torch.float32, device=device)
-    collate = Collater(pad_idx=2, pad_to_multiple=2)
-    decoded_audio = None
-    if isinstance(audio, str):
-        with Path(audio).open("rb") as fb:
-            block = MemoryBlock(fb.read())
-        decoded_audio = decode_audio(block)
-    src = collate(decoded_audio)["waveform"]
-
-    x = torch.tensor(torch.load("/checkpoint/krs/x.pt"), device=device)
-    print(f"After read audio: {x.sum()}, {x.shape}")
-    x = x.unsqueeze(0)
-    import torch.nn.functional as F
-
-    x = F.layer_norm(x, x.shape)
-    # batch.seqs = batch.seqs.view(1, -1)
-
-    print(f"After layer norm: {x.sum()}, {x.shape}")
-    model = load_wav2vec2_layer_output_model(
-        "xlsr2_1b_v2", device=device, dtype=torch.float32
-    )
-    model.eval()
-    batch = SequenceBatch(seqs=x, seq_lens=src["seq_lens"])
-    out = model(batch, out_layer_idx)
-    print(out.sum(), out.shape)
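
For completeness, a minimal sketch of driving Wav2Vec2LayerOutputModel directly, mirroring the preprocessing that UnitExtractor.predict performs (the random waveform and layer index are illustrative only):

    import torch
    import torch.nn.functional as F

    from fairseq2.models.sequence import SequenceBatch
    from seamless_communication.models.unit_extraction.wav2vec2_layer_output import (
        load_wav2vec2_layer_output_model,
    )

    device = torch.device("cuda:0")
    model = load_wav2vec2_layer_output_model(
        "xlsr2_1b_v2", device=device, dtype=torch.float32
    )
    model.eval()

    wav = torch.randn(1, 16000, device=device)  # one second of fake 16 kHz audio
    wav = F.layer_norm(wav, wav.shape)          # same normalization as UnitExtractor
    batch = SequenceBatch(seqs=wav, seq_lens=torch.tensor([wav.shape[1]], device=device))

    with torch.no_grad():
        features = model(batch, 34)  # hidden states after encoder layer 35 (index 34)
    print(features.shape)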