瀏覽代碼

Enabling 24khz vocoder for demo/OSS (#132)

* enabling 24khz vocoder for demo/OSS

* move the sample_rate to model card
Yilin Yang 1 年之前
父節點
當前提交
00118c21cc

+ 2 - 1
src/seamless_communication/cards/vocoder_mel.yaml

@@ -7,4 +7,5 @@
 name: vocoder_mel
 name: vocoder_mel
 model_type: vocoder_mel_hifigan
 model_type: vocoder_mel_hifigan
 model_arch: base_mel
 model_arch: base_mel
-checkpoint: "file:///large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt"
+checkpoint: "file:///large_experiments/seamless/workstream/expressivity/oss/checkpoints/16khz_pretssel_hifigan.pt"
+sample_rate: 16000

+ 11 - 0
src/seamless_communication/cards/vocoder_mel_24khz.yaml

@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: vocoder_mel_24khz
+model_type: vocoder_mel_hifigan
+model_arch: 24khz_mel
+checkpoint: "file:///large_experiments/seamless/workstream/expressivity/oss/checkpoints/24khz_pretssel_hifigan.pt"
+sample_rate: 24000

+ 5 - 2
src/seamless_communication/inference/pretssel_generator.py

@@ -7,6 +7,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
 
 
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
+from fairseq2.assets import asset_store
 from fairseq2.assets.card import AssetCard
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater, SequenceData
 from fairseq2.data import Collater, SequenceData
 from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
 from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
@@ -53,6 +54,9 @@ class PretsselGenerator(nn.Module):
         )
         )
         self.pretssel_model.eval()
         self.pretssel_model.eval()
 
 
+        vocoder_model_card = asset_store.retrieve_card(vocoder_name_or_card)
+        self.output_sample_rate = vocoder_model_card.field("sample_rate").as_(int)
+
         self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
         self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
         self.unit_collate = Collater(pad_value=self.unit_tokenizer.vocab_info.pad_idx)
         self.unit_collate = Collater(pad_value=self.unit_tokenizer.vocab_info.pad_idx)
         self.duration_collate = Collater(pad_value=0)
         self.duration_collate = Collater(pad_value=0)
@@ -78,7 +82,6 @@ class PretsselGenerator(nn.Module):
         units: List[List[int]],
         units: List[List[int]],
         tgt_lang: str,
         tgt_lang: str,
         prosody_encoder_input: SequenceData,
         prosody_encoder_input: SequenceData,
-        sample_rate: int = 16000,
     ) -> BatchedSpeechOutput:
     ) -> BatchedSpeechOutput:
         list_units, durations = [], []
         list_units, durations = [], []
         unit_eos_token = torch.tensor(
         unit_eos_token = torch.tensor(
@@ -130,5 +133,5 @@ class PretsselGenerator(nn.Module):
         return BatchedSpeechOutput(
         return BatchedSpeechOutput(
             units=units,
             units=units,
             audio_wavs=audio_wavs,
             audio_wavs=audio_wavs,
-            sample_rate=sample_rate,
+            sample_rate=self.output_sample_rate,
         )
         )

+ 3 - 0
src/seamless_communication/models/vocoder/__init__.py

@@ -15,6 +15,9 @@ from seamless_communication.models.vocoder.codehifigan import (
     CodeGenerator as CodeGenerator,
     CodeGenerator as CodeGenerator,
 )
 )
 from seamless_communication.models.vocoder.hifigan import Generator as Generator
 from seamless_communication.models.vocoder.hifigan import Generator as Generator
+from seamless_communication.models.vocoder.loader import (
+    load_mel_vocoder_config as load_mel_vocoder_config,
+)
 from seamless_communication.models.vocoder.loader import (
 from seamless_communication.models.vocoder.loader import (
     load_mel_vocoder_model as load_mel_vocoder_model,
     load_mel_vocoder_model as load_mel_vocoder_model,
 )
 )

+ 20 - 0
src/seamless_communication/models/vocoder/builder.py

@@ -162,6 +162,26 @@ def _base_mel_vocoder() -> VocoderConfig:
     )
     )
 
 
 
 
+@mel_vocoder_arch("24khz_mel")
+def _base_mel_vocoder() -> VocoderConfig:
+    return VocoderConfig(
+        upsample_rates=[5, 4, 4, 3],
+        upsample_kernel_sizes=[10, 8, 8, 6],
+        upsample_initial_channel=512,
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        model_in_dim=80,
+        num_embeddings=0,
+        embedding_dim=0,
+        dur_predictor_params={},
+        lang_embedding_dim=0,
+        num_langs=0,
+        spkr_embedding_dim=0,
+        num_spkrs=0,
+        lang_spkr_idx_map={},
+    )
+
+
 class MelVocoderBuilder:
 class MelVocoderBuilder:
     config: VocoderConfig
     config: VocoderConfig
     device: Optional[Device]
     device: Optional[Device]