浏览代码

Working script with vocoder resynthesis.

Kaushik Ram Sadagopan 2 年之前
父节点
当前提交
f3e1c591ad

+ 91 - 0
scripts/m4t/audio_to_units/audio_to_units.py

@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import torch
+import torchaudio
+from seamless_communication.models.unit_extraction import UnitExtractor
+from seamless_communication.models.inference import Translator
+from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
+from itertools import groupby
+
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert raw audio to units (and optionally audio) using UnitExtractor."
+    )
+    parser.add_argument("audio", type=str, help="Audio WAV file path.")
+    parser.add_argument(
+        "--kmeans_uri",
+        type=str,
+        help="URL path to the K-Means model.",
+        default="https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        help="Feature extraction model name (`xlsr2_1b_v2`)",
+        default="xlsr2_1b_v2",
+    )
+    parser.add_argument(
+        "--vocoder_name", type=str, help="Vocoder name", default="vocoder_36langs"
+    )
+    parser.add_argument(
+        "--out_layer_number",
+        type=int,
+        help="Layer number of the feature extraction model to pull out features from.",
+        default=35,
+    )
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        help="Path to save the generated audio.",
+        default=None,
+    )
+    parser.add_argument(
+        "--src_lang", type=str, help="Source language of the audio.", default=None
+    )
+
+    args = parser.parse_args()
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        logger.info("Running unit_extraction on the GPU.")
+    else:
+        device = torch.device("cpu")
+        logger.info("Running unit_extraction on the CPU.")
+
+    unit_extractor = UnitExtractor(args.model_name, args.kmeans_uri, device=device)
+    units = unit_extractor.predict(args.audio, args.out_layer_number - 1)
+
+    if args.output_path is not None:
+
+        if args.src_lang is None:
+            raise ValueError("src_lang must be provided to resynthesize the audio.")
+
+        def reduce_list(lst):
+            return [key for key, _ in groupby(lst)]
+
+        reduced_units = reduce_list(units.cpu().tolist())
+
+        vocoder: Vocoder = Translator.load_model_for_inference(
+            load_vocoder_model, args.vocoder_name, device, torch.float32
+        )
+        wav = vocoder(reduced_units, args.src_lang, spkr=-1, dur_prediction=True)
+
+        torchaudio.save(
+            args.output_path,
+            wav[0].cpu(),
+            sample_rate=16000,
+        )
+
+
+if __name__ == "__main__":
+    main()

+ 3 - 13
src/seamless_communication/models/inference/translator.py

@@ -73,7 +73,7 @@ class Translator(nn.Module):
             pad_idx=self.text_tokenizer.vocab_info.pad_idx, pad_to_multiple=2
         )
         # Load the vocoder.
-        self.vocoder = self.load_model_for_inference(
+        self.vocoder: Vocoder = self.load_model_for_inference(
             load_vocoder_model, vocoder_name_or_card, device, torch.float32
         )
         self.sr = sample_rate
@@ -131,16 +131,6 @@ class Translator(nn.Module):
         else:
             return Modality.TEXT, Modality.SPEECH
 
-    @torch.no_grad()
-    def synthesize_speech(
-        self,
-        code: List[int],
-        lang: str,
-        speaker: Optional[int] = None,
-        dur_prediction: Optional[bool] = True,
-    ) -> Tuple[List[Tensor], int]:
-        return self.vocoder(code, lang, speaker, dur_prediction), self.sr
-
     @torch.no_grad()
     def predict(
         self,
@@ -216,5 +206,5 @@ class Translator(nn.Module):
             return text_out.sentences[0], None, None
         else:
             units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
-            wav_out, sr_out = self.synthesize_speech(units, tgt_lang, spkr)
-            return text_out.sentences[0], wav_out, sr_out
+            wav_out = self.vocoder(units, tgt_lang, spkr, dur_prediction=True)
+            return text_out.sentences[0], wav_out, self.sr

+ 3 - 0
src/seamless_communication/models/unit_extraction/__init__.py

@@ -3,3 +3,6 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from seamless_communication.models.unit_extraction.unit_extraction import (
+    UnitExtractor as UnitExtractor,
+)

+ 1 - 11
src/seamless_communication/models/unit_extraction/unit_extraction.py

@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union
 from pathlib import Path
 import torch
 
@@ -68,13 +68,3 @@ class UnitExtractor(nn.Module):
         features = self.model(batch, out_layer_idx).squeeze(0)
         units = self.kmeans_model(features)
         return units
-
-
-if __name__ == "__main__":
-    kmeans_uri = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
-    audio = "/large_experiments/seamless/ust/data/TTS/vocoder_training/audio_wavs/multi_spkr/eng/eng_LJSpeech-1.1_0/LJ003-0001.wav"
-    device = torch.device("cuda:1")
-    unit_extractor = UnitExtractor("xlsr2_1b_v2", kmeans_uri, device=Device("cuda:0"))
-    out_layer_number = 35
-    units = unit_extractor.predict(audio, out_layer_number - 1)
-    print(units.shape, units.dtype, units.device)