Просмотр исходного кода

Merge pull request #61 from facebookresearch/audio_to_units_README

Adding m4t_audio_to_units, README for audio_to_units, vocoder.to(dtype).
Kaushik Ram Sadagopan 2 лет назад
Родитель
Сommit
b650cd430d

+ 3 - 0
README.md

@@ -79,6 +79,9 @@ To reproduce our results, or to evaluate using the same metrics over your own te
 ## Finetuning SeamlessM4T models
 Please check out the [README here](scripts/m4t/finetune/README.md).
 
+## Converting raw audio to units
+Please check out the [README here](scripts/m4t/audio_to_units/README.md).
+
 ## On-device models
 Apart from Seamless-M4T large (2.3B) and medium (1.2B) models, we are also releasing a small model (281M) targeted for on-device inference. To learn more about the usage and model details check out the [README here](docs/m4t/on_device_README.md).
 

+ 19 - 0
scripts/m4t/audio_to_units/README.md

@@ -0,0 +1,19 @@
+# Convert raw audio into units (unit_extraction)
+
+Raw audio needs to be converted to units to train UnitY models and vocoders. Units act as supervision for UnitY models, and are the input to the vocoders which synthesize speech from these units.
+
+The unit extraction pipeline comprises the following steps:
+- Compute features from layer 35 (determined empirically) of the pretrained XLSR v2 model, which is a wav2vec2 model at the core.
+- Assign features for each timestep to a collection of precomputed K-Means centroids to produce a sequence of units.
+
+
+## Quick start:
+`audio_to_units` is run with the CLI, from the root directory of the repository.
+
+```bash
+m4t_audio_to_units <path_to_input_audio>
+```
+
+`audio_to_units` calls for `UnitExtractor` which provides a `predict` method to convert an audio to units.
+
+The convenience method `resynthesize_audio` of `UnitExtractor`, can be used to resynthesize audio waveforms from units.

+ 5 - 0
scripts/m4t/audio_to_units/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.

+ 5 - 0
scripts/m4t/finetune/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.

+ 5 - 0
scripts/m4t/predict/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.

+ 1 - 1
scripts/m4t/predict/predict.py

@@ -81,7 +81,7 @@ def main():
         logger.info(f"Saving translated audio in {args.tgt_lang}")
         torchaudio.save(
             args.output_path,
-            wav[0].cpu(),
+            wav[0].to(torch.float32).cpu(),
             sample_rate=sr,
         )
     logger.info(f"Translated text in {args.tgt_lang}: {translated_text}")

+ 1 - 0
setup.py

@@ -66,6 +66,7 @@ setup(
             "m4t_predict=m4t_scripts.predict.predict:main",
             "m4t_finetune=m4t_scripts.finetune.finetune:main",
             "m4t_prepare_dataset=m4t_scripts.finetune.dataset:main",
+            "m4t_audio_to_units=m4t_scripts.audio_to_units.audio_to_units:main",
         ],
     },
     cmdclass={"develop": cmd_for_editable_mode},

+ 2 - 2
src/seamless_communication/models/inference/translator.py

@@ -149,14 +149,14 @@ class Translator(nn.Module):
     @torch.inference_mode()
     def predict(
         self,
-        input: Union[str, torch.Tensor],
+        input: Union[str, Tensor],
         task_str: str,
         tgt_lang: str,
         src_lang: Optional[str] = None,
         spkr: Optional[int] = -1,
         ngram_filtering: bool = False,
         sample_rate: int = 16000,
-    ) -> Tuple[StringLike, Optional[List[Tensor]], Optional[int]]:
+    ) -> Tuple[StringLike, Optional[Tensor], Optional[int]]:
         """
         The main method used to perform inference on all tasks.
 

+ 1 - 0
src/seamless_communication/models/unit_extraction/__init__.py

@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
 from seamless_communication.models.unit_extraction.unit_extraction import (
     UnitExtractor as UnitExtractor,
 )

+ 7 - 2
src/seamless_communication/models/unit_extraction/unit_extraction.py

@@ -50,7 +50,7 @@ class UnitExtractor(nn.Module):
     @torch.inference_mode()
     def predict(
         self,
-        audio: Union[str, torch.Tensor],
+        audio: Union[str, Tensor],
         out_layer_idx: int,
         sample_rate: int = 16000,
     ) -> Tensor:
@@ -74,7 +74,12 @@ class UnitExtractor(nn.Module):
         return units
 
     @staticmethod
-    def resynthesize_audio(units, src_lang, device, vocoder_name="vocoder_36langs"):
+    def resynthesize_audio(
+        units: Tensor,
+        src_lang: str,
+        device: Device,
+        vocoder_name: str = "vocoder_36langs",
+    ) -> Tensor:
         def reduce_list(lst):
             return [key for key, _ in groupby(lst)]
 

+ 3 - 1
src/seamless_communication/models/vocoder/builder.py

@@ -113,7 +113,9 @@ class VocoderBuilder:
             self.config.spkr_embedding_dim,
             self.config.num_spkrs,
         )
-        return Vocoder(code_generator, self.config.lang_spkr_idx_map)
+        vocoder = Vocoder(code_generator, self.config.lang_spkr_idx_map)
+        vocoder.to(dtype=self.dtype)
+        return vocoder
 
 
 def create_vocoder_model(

+ 2 - 2
src/seamless_communication/models/vocoder/vocoder.py

@@ -8,7 +8,7 @@ from typing import List, Optional
 
 import torch
 import torch.nn as nn
-from fairseq2.typing import Device
+from torch import Tensor
 
 from seamless_communication.models.vocoder.codehifigan import CodeGenerator
 
@@ -25,7 +25,7 @@ class Vocoder(nn.Module):
         lang: str,
         spkr: Optional[int] = -1,
         dur_prediction: bool = True,
-    ):
+    ) -> Tensor:
         x = {
             "code": torch.LongTensor(code).view(1, -1),
         }