Просмотр исходного кода

Add PRETSSEL HiFiGAN Vocoder (#78)

* add pretssel hifigan

* update

* update

* update

* update

* update

* --amend

* Update src/seamless_communication/models/vocoder/builder.py

Co-authored-by: Kaushik Ram Sadagopan <krs@fb.com>

* update

---------

Co-authored-by: Kaushik Ram Sadagopan <krs@fb.com>
Changhan Wang 1 год назад
Родитель
Commit
1cead95b37

+ 58 - 0
scripts/convert_pretssel_hifigan_chkpt.py

@@ -0,0 +1,58 @@
+import numpy as np
+import torch
+
+"""
+upsample_scales -> upsample_rates
+resblock_dilations -> resblock_dilation_sizes
+in_channels -> model_in_dim
+out_channels -> upsample_initial_channel
+"""
+
+
+def main():
+    chkpt_root = "/checkpoint/mjhwang/experiments/231007-mel_vocoder-mls_multilingual_6lang/train_mls_multilingual_6lang_subset_hifigan.v1_8gpu_adapt"
+    cfg = f"{chkpt_root}/config.yml"
+    # TODO: display cfg
+    chkpt = torch.load(f"{chkpt_root}/checkpoint-400000steps.pkl")
+    del chkpt["model"]["discriminator"]
+    conv_seq_map = {
+        ".1.bias": ".bias",
+        ".1.weight_g": ".weight_g",
+        ".1.weight_v": ".weight_v",
+    }
+
+    def update_key(k):
+        if k.startswith("input_conv"):
+            k = k.replace("input_conv", "conv_pre")
+        elif k.startswith("upsamples"):
+            k = k.replace("upsamples", "ups")
+            for _k, _v in conv_seq_map.items():
+                k = k.replace(_k, _v)
+        elif k.startswith("blocks"):
+            k = k.replace("blocks", "resblocks")
+            for _k, _v in conv_seq_map.items():
+                k = k.replace(_k, _v)
+        elif k.startswith("output_conv"):
+            k = k.replace("output_conv", "conv_post")
+            for _k, _v in conv_seq_map.items():
+                k = k.replace(_k, _v)
+        return k
+
+    chkpt["model"] = {update_key(k): v for k, v in chkpt["model"]["generator"].items()}
+
+    stats_path = f"{chkpt_root}/stats.npy"
+    stats = np.load(stats_path)
+    mean = torch.from_numpy(stats[0].reshape(-1)).float()
+    scale = torch.from_numpy(stats[1].reshape(-1)).float()
+    chkpt["model"]["mean"] = mean
+    chkpt["model"]["scale"] = scale
+
+    for k in ["optimizer", "scheduler", "steps", "epochs"]:
+        del chkpt[k]
+
+    out_path = "/large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt"
+    torch.save(chkpt, out_path)
+
+
+if __name__ == "__main__":
+    main()

+ 10 - 0
src/seamless_communication/cards/vocoder_pretssel.yaml

@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: vocoder_pretssel
+model_type: vocoder_mel_hifigan
+model_arch: base_mel
+checkpoint: "file://large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt"

+ 12 - 0
src/seamless_communication/models/vocoder/__init__.py

@@ -4,6 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from seamless_communication.models.vocoder.builder import (
+    MelVocoderBuilder as MelVocoderBuilder,
+)
 from seamless_communication.models.vocoder.builder import (
     VocoderBuilder as VocoderBuilder,
 )
@@ -12,8 +15,17 @@ from seamless_communication.models.vocoder.codehifigan import (
     CodeGenerator as CodeGenerator,
 )
 from seamless_communication.models.vocoder.hifigan import Generator as Generator
+from seamless_communication.models.vocoder.loader import (
+    MelVocoderLoader as MelVocoderLoader,
+)
 from seamless_communication.models.vocoder.loader import VocoderLoader as VocoderLoader
+from seamless_communication.models.vocoder.loader import (
+    load_mel_vocoder_model as load_mel_vocoder_model,
+)
 from seamless_communication.models.vocoder.loader import (
     load_vocoder_model as load_vocoder_model,
 )
+from seamless_communication.models.vocoder.melhifigan import (
+    MelGenerator as MelGenerator,
+)
 from seamless_communication.models.vocoder.vocoder import Vocoder as Vocoder

+ 61 - 0
src/seamless_communication/models/vocoder/builder.py

@@ -11,6 +11,7 @@ from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.typing import DataType, Device
 
 from seamless_communication.models.vocoder.codehifigan import CodeGenerator
+from seamless_communication.models.vocoder.melhifigan import MelGenerator
 from seamless_communication.models.vocoder.vocoder import Vocoder
 
 
@@ -135,3 +136,63 @@ def create_vocoder_model(
     """
 
     return VocoderBuilder(config, device=device, dtype=dtype).build_model()
+
+
+mel_vocoder_archs = ArchitectureRegistry[VocoderConfig]("vocoder_mel_hifigan")
+mel_vocoder_arch = mel_vocoder_archs.marker
+
+
+@mel_vocoder_arch("base_mel")
+def _base_mel_vocoder() -> VocoderConfig:
+    return VocoderConfig(
+        upsample_rates=[5, 4, 4, 2],
+        upsample_kernel_sizes=[10, 8, 8, 4],
+        upsample_initial_channel=512,
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        model_in_dim=80,
+        num_embeddings=0,
+        embedding_dim=0,
+        dur_predictor_params={},
+        lang_embedding_dim=0,
+        num_langs=0,
+        spkr_embedding_dim=0,
+        num_spkrs=0,
+        lang_spkr_idx_map={},
+    )
+
+
+class MelVocoderBuilder:
+    config: VocoderConfig
+    device: Optional[Device]
+    dtype: Optional[DataType]
+
+    def __init__(
+        self,
+        config: VocoderConfig,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        self.config = config
+        self.device, self.dtype = device, dtype
+
+    def build_model(self) -> MelGenerator:
+        generator = MelGenerator(
+            self.config.upsample_rates,
+            self.config.upsample_kernel_sizes,
+            self.config.upsample_initial_channel,
+            self.config.resblock_kernel_sizes,
+            self.config.resblock_dilation_sizes,
+            self.config.model_in_dim,
+        )
+        generator.to(dtype=self.dtype, device=self.device)
+        return generator
+
+
+def create_mel_vocoder_model(
+    config: VocoderConfig,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> MelGenerator:
+    return MelVocoderBuilder(config, device=device, dtype=dtype).build_model()

+ 4 - 1
src/seamless_communication/models/vocoder/hifigan.py

@@ -128,6 +128,7 @@ class Generator(torch.nn.Module):
         resblock_kernel_sizes: List[int],
         resblock_dilation_sizes: List[List[int]],
         model_in_dim: Optional[int],
+        add_ups_out_pad: bool = False,
     ):
         super(Generator, self).__init__()
         self.num_kernels = len(resblock_kernel_sizes)
@@ -144,6 +145,7 @@ class Generator(torch.nn.Module):
 
         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            out_pad = u % 2 if add_ups_out_pad else 0
             self.ups.append(
                 weight_norm(
                     ConvTranspose1d(
@@ -151,7 +153,8 @@ class Generator(torch.nn.Module):
                         upsample_initial_channel // (2 ** (i + 1)),
                         k,
                         u,
-                        padding=(k - u) // 2,
+                        padding=(k - u) // 2 + out_pad,
+                        output_padding=out_pad,
                     )
                 )
             )

+ 17 - 0
src/seamless_communication/models/vocoder/loader.py

@@ -12,9 +12,12 @@ from overrides import override as finaloverride
 
 from seamless_communication.models.vocoder.builder import (
     VocoderConfig,
+    create_mel_vocoder_model,
     create_vocoder_model,
+    mel_vocoder_archs,
     vocoder_archs,
 )
+from seamless_communication.models.vocoder.melhifigan import MelGenerator
 from seamless_communication.models.vocoder.vocoder import Vocoder
 
 
@@ -39,3 +42,17 @@ class VocoderLoader(ModelLoader[Vocoder, VocoderConfig]):
 load_vocoder_model = VocoderLoader(
     asset_store, download_manager, create_vocoder_model, vocoder_archs
 )
+
+
+@final
+class MelVocoderLoader(ModelLoader[MelGenerator, VocoderConfig]):
+    @finaloverride
+    def _convert_checkpoint(
+        self, checkpoint: Mapping[str, Any], config: VocoderConfig
+    ) -> Mapping[str, Any]:
+        return checkpoint
+
+
+load_mel_vocoder_model = MelVocoderLoader(
+    asset_store, download_manager, create_mel_vocoder_model, mel_vocoder_archs
+)

+ 46 - 0
src/seamless_communication/models/vocoder/melhifigan.py

@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List
+
+import torch
+
+from seamless_communication.models.vocoder.hifigan import Generator
+
+
+class MelGenerator(Generator):
+    def __init__(
+        self,
+        upsample_rates: List[int],
+        upsample_kernel_sizes: List[int],
+        upsample_initial_channel: int,
+        resblock_kernel_sizes: List[int],
+        resblock_dilation_sizes: List[List[int]],
+        model_in_dim: int = 80,
+    ):
+        super().__init__(
+            upsample_rates,
+            upsample_kernel_sizes,
+            upsample_initial_channel,
+            resblock_kernel_sizes,
+            resblock_dilation_sizes,
+            model_in_dim,
+            add_ups_out_pad=True,
+        )
+
+        for u, k in zip(upsample_rates, upsample_kernel_sizes):
+            assert k == 2 * u, (k, u)
+
+        mean = torch.zeros((model_in_dim,), dtype=torch.float)
+        scale = torch.zeros((model_in_dim,), dtype=torch.float)
+        self.register_buffer("mean", mean)
+        self.register_buffer("scale", scale)
+
+    def forward(self, x: torch.Tensor, normalize_before: bool = True) -> torch.Tensor:
+        if normalize_before:
+            x = (x - self.mean) / self.scale
+        x = super().forward(x.transpose(1, 0).unsqueeze(0))
+        return x.squeeze(0).transpose(1, 0)

+ 8 - 3
tests/common.py

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from contextlib import contextmanager
-from typing import Any, Generator, List, Union
+from typing import Any, Generator, List, Optional, Union
 
 import torch
 from fairseq2.typing import Device
@@ -16,12 +16,17 @@ from torch import Tensor
 device = Device("cpu")
 
 
-def assert_close(a: Tensor, b: Union[Tensor, List[Any]]) -> None:
+def assert_close(
+    a: Tensor,
+    b: Union[Tensor, List[Any]],
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
+) -> None:
     """Assert that ``a`` and ``b`` are element-wise equal within a tolerance."""
     if not isinstance(b, Tensor):
         b = torch.tensor(b, device=device, dtype=a.dtype)
 
-    torch.testing.assert_close(a, b)  # type: ignore[attr-defined]
+    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)  # type: ignore[attr-defined]
 
 
 def assert_equal(a: Tensor, b: Union[Tensor, List[Any]]) -> None:

+ 49 - 0
tests/integration/models/test_pretssel_vocoder.py

@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import tempfile
+from urllib.request import urlretrieve
+
+import torch
+import torchaudio
+
+from seamless_communication.models.vocoder.loader import load_mel_vocoder_model
+from tests.common import assert_close, device
+
+
+def test_pretssel_vocoder() -> None:
+    n_mel_bins = 80
+    sample_rate = 16_000
+
+    vocoder = load_mel_vocoder_model(
+        "vocoder_pretssel", device=device, dtype=torch.float32
+    )
+
+    url = "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"
+
+    with tempfile.NamedTemporaryFile() as f:
+        urlretrieve(url, f.name)
+        _wav, _sr = torchaudio.load(f.name)
+
+    wav = torchaudio.sox_effects.apply_effects_tensor(
+        _wav, _sr, [["rate", f"{sample_rate}"], ["channels", "1"]]
+    )[0].to(device=device)
+    feat = torchaudio.compliance.kaldi.fbank(
+        wav * (2**15), num_mel_bins=n_mel_bins, sample_frequency=sample_rate
+    )
+
+    with torch.no_grad():
+        wav_hat = vocoder(feat).t()
+
+    feat_hat = torchaudio.compliance.kaldi.fbank(
+        wav_hat * (2**15), num_mel_bins=n_mel_bins, sample_frequency=sample_rate
+    )
+
+    assert_close(feat_hat, feat[: feat_hat.shape[0], :], atol=0.0, rtol=5.0)
+
+
+if __name__ == "__main__":
+    test_pretssel_vocoder()