Ver Fonte

Refactoring expressivity/predict into ExpressiveTranslator.

Kaushik Ram Sadagopan há 1 ano atrás
pai
commit
ceada0caff

+ 1 - 1
README.md

@@ -161,7 +161,7 @@ Please check out above [section](#seamlessexpressive-models) on how to acquire `
 ### W2v-BERT 2.0 speech encoder
 | Model Name        | #params | checkpoint                                                                                                                                                                                                                                                                                                                                                                 |
 | ----------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| W2v-BERT 2.0 | 600M    | [🤗 Model card](https://huggingface.co/facebook/conformer-shaw) - [checkpoint](https://huggingface.co/facebook/conformer-shaw/resolve/main/conformer_shaw.pt)
+| W2v-BERT 2.0 | 600M    | [🤗 Model card](https://huggingface.co/facebook/w2v-bert-2.0) - [checkpoint](https://huggingface.co/facebook/w2v-bert-2.0/resolve/main/conformer_shaw.pt)
 
 Here's how you should do a foward pass through the speech encoder:
 

+ 1 - 1
demo/expressive/app.py

@@ -29,7 +29,7 @@ from seamless_communication.models.unity import (
     load_gcmvn_stats,
     load_unity_unit_tokenizer,
 )
-from seamless_communication.cli.expressivity.predict.pretssel_generator import PretsselGenerator
+from seamless_communication.inference.pretssel_generator import PretsselGenerator
 
 from typing import Tuple
 from utils import LANGUAGE_CODE_TO_NAME

+ 4 - 3
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -25,9 +25,7 @@ from fairseq2.typing import DataType, Device
 from torch import Tensor
 from tqdm import tqdm
 
-from seamless_communication.cli.expressivity.predict.pretssel_generator import (
-    PretsselGenerator,
-)
+
 from seamless_communication.cli.m4t.evaluate.evaluate import (
     adjust_output_for_corrupted_inputs,
     count_lines,
@@ -36,6 +34,9 @@ from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
     set_generation_opts,
 )
+from seamless_communication.inference.pretssel_generator import (
+    PretsselGenerator,
+)
 from seamless_communication.inference import BatchedSpeechOutput, Translator
 from seamless_communication.models.unity import (
     load_gcmvn_stats,

+ 11 - 90
src/seamless_communication/cli/expressivity/predict/predict.py

@@ -10,27 +10,14 @@ import torch
 import torchaudio
 from pathlib import Path
 
-from fairseq2.data import SequenceData
-from fairseq2.data.audio import WaveformToFbankConverter
-
-from seamless_communication.cli.expressivity.predict.pretssel_generator import (
-    PretsselGenerator,
-)
 from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
     set_generation_opts,
 )
-from seamless_communication.inference import Translator
-from seamless_communication.models.unity import (
-    load_gcmvn_stats,
-    load_unity_unit_tokenizer,
-)
+from seamless_communication.inference import ExpressiveTranslator
 from seamless_communication.store import add_gated_assets
 
 
-AUDIO_SAMPLE_RATE = 16000
-
-
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
@@ -39,13 +26,6 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
-def remove_prosody_tokens_from_text(text: str) -> str:
-    # filter out prosody tokens, there is only emphasis '*', and pause '='
-    text = text.replace("*", "").replace("=", "")
-    text = " ".join(text.split())
-    return text
-
-
 def main() -> None:
     parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference.")
     parser.add_argument("input", type=str, help="Audio WAV file path.")
@@ -82,59 +62,11 @@ def main() -> None:
 
     logger.info(f"Running inference on {device=} with {dtype=}.")
 
-    unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
-    
-    translator = Translator(
+    expressive_translator = ExpressiveTranslator(
         args.model_name,
-        vocoder_name_or_card=None,
-        device=device,
-        dtype=dtype,
-    )
-
-    pretssel_generator = PretsselGenerator(
         args.vocoder_name,
-        vocab_info=unit_tokenizer.vocab_info,
-        device=device,
-        dtype=dtype,
-    )
-
-    fbank_extractor = WaveformToFbankConverter(
-        num_mel_bins=80,
-        waveform_scale=2**15,
-        channel_last=True,
-        standardize=False,
-        device=device,
-        dtype=dtype,
-    )
-
-    _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
-    gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
-    gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
-
-    wav, sample_rate = torchaudio.load(args.input)
-    wav = torchaudio.functional.resample(wav, orig_freq=sample_rate, new_freq=16_000)
-    wav = wav.transpose(0, 1)
-
-    data = fbank_extractor(
-        {
-            "waveform": wav,
-            "sample_rate": 16000,
-        }
-    )
-    fbank = data["fbank"]
-    gcmvn_fbank = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
-    std, mean = torch.std_mean(fbank, dim=0)
-    fbank = fbank.subtract(mean).divide(std)
-
-    src = SequenceData(
-        seqs=fbank.unsqueeze(0),
-        seq_lens=torch.LongTensor([fbank.shape[0]]),
-        is_ragged=False,
-    )
-    src_gcmvn = SequenceData(
-        seqs=gcmvn_fbank.unsqueeze(0),
-        seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
-        is_ragged=False,
+        device,
+        dtype
     )
 
     text_generation_opts, unit_generation_opts = set_generation_opts(args)
@@ -145,22 +77,13 @@ def main() -> None:
         f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
     )
 
-    text_output, unit_output = translator.predict(
-        src,
-        "s2st",
+    speech_output, text_output = expressive_translator.predict(
+        args.input,
         args.tgt_lang,
-        text_generation_opts=text_generation_opts,
-        unit_generation_opts=unit_generation_opts,
-        unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
-        duration_factor=args.duration_factor,
-        prosody_encoder_input=src_gcmvn,
-    )
-
-    assert unit_output is not None
-    speech_output = pretssel_generator.predict(
-        unit_output.units,
-        tgt_lang=args.tgt_lang,
-        prosody_encoder_input=src_gcmvn,
+        text_generation_opts,
+        unit_generation_opts,
+        args.unit_generation_ngram_filtering,
+        args.duration_factor,
     )
 
     logger.info(f"Saving expressive translated audio in {args.tgt_lang}")
@@ -170,9 +93,7 @@ def main() -> None:
         sample_rate=speech_output.sample_rate,
     )
 
-    text_out = remove_prosody_tokens_from_text(str(text_output[0]))
-
-    logger.info(f"Translated text in {args.tgt_lang}: {text_out}")
+    logger.info(f"Translated text in {args.tgt_lang}: {text_output[0]}")
 
 
 if __name__ == "__main__":

+ 3 - 0
src/seamless_communication/inference/__init__.py

@@ -8,6 +8,9 @@ from seamless_communication.inference.generator import (
     SequenceGeneratorOptions as SequenceGeneratorOptions,
 )
 from seamless_communication.inference.generator import UnitYGenerator as UnitYGenerator
+from seamless_communication.inference.expressive_translator import (
+    ExpressiveTranslator as ExpressiveTranslator,
+)
 from seamless_communication.inference.translator import (
     BatchedSpeechOutput as BatchedSpeechOutput,
 )

+ 153 - 0
src/seamless_communication/inference/expressive_translator.py

@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+import torch
+import torchaudio
+
+from torch.nn import Module
+from typing import List, Optional, Tuple, Union
+
+from fairseq2.assets.card import AssetCard
+from fairseq2.data import SequenceData, StringLike
+from fairseq2.data.audio import WaveformToFbankConverter
+from fairseq2.typing import DataType, Device
+
+from seamless_communication.inference import BatchedSpeechOutput, Translator
+from seamless_communication.inference.generator import SequenceGeneratorOptions
+from seamless_communication.inference.pretssel_generator import (
+    PretsselGenerator,
+)
+from seamless_communication.models.unity import (
+    load_gcmvn_stats,
+    load_unity_unit_tokenizer,
+)
+
+AUDIO_SAMPLE_RATE = 16000
+
+
+class ExpressiveTranslator(Module):
+    def __init__(
+        self,
+        model_name_or_card: Union[str, AssetCard],
+        vocoder_name_or_card: Union[str, AssetCard, None],
+        device: Device,
+        dtype: DataType,
+    ):
+        super().__init__()
+
+        unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
+    
+        self.translator = Translator(
+            model_name_or_card,
+            vocoder_name_or_card=None,
+            device=device,
+            dtype=dtype,
+        )
+
+        self.pretssel_generator = PretsselGenerator(
+            vocoder_name_or_card,
+            vocab_info=unit_tokenizer.vocab_info,
+            device=device,
+            dtype=dtype,
+        )
+
+        self.fbank_extractor = WaveformToFbankConverter(
+            num_mel_bins=80,
+            waveform_scale=2**15,
+            channel_last=True,
+            standardize=False,
+            device=device,
+            dtype=dtype,
+        )
+
+        _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(vocoder_name_or_card)
+        self.gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
+        self.gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
+        
+    @staticmethod
+    def remove_prosody_tokens_from_text(text_output: List[str]) -> List[str]:
+        modified_text_output = []
+        for text in text_output:
+            # filter out prosody tokens, there is only emphasis '*', and pause '='
+            text = text.replace("*", "").replace("=", "")
+            text = " ".join(text.split())
+            modified_text_output.append(text)
+        return modified_text_output
+
+    @torch.inference_mode()
+    def predict(
+        self,
+        audio_path: str,
+        tgt_lang: str,
+        text_generation_opts: Optional[SequenceGeneratorOptions] = None,
+        unit_generation_opts: Optional[SequenceGeneratorOptions] = None,
+        unit_generation_ngram_filtering: bool = False,
+        duration_factor: float = 1.0,
+    ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
+        """
+        The main method used to perform inference on all tasks.
+
+        :param audio_path:
+            Path to audio waveform.
+        :param tgt_lang:
+            Target language to decode into.
+        :param text_generation_opts:
+            Text generation hyperparameters for incremental decoding.
+        :param unit_generation_opts:
+            Unit generation hyperparameters for incremental decoding.
+        :param unit_generation_ngram_filtering:
+            If True, removes consecutive repeated ngrams
+            from the decoded unit output.
+
+        :returns:
+            - Batched list of Translated text.
+            - Translated BatchedSpeechOutput.
+        """
+        # TODO: Replace with fairseq2.data once re-sampling is implemented.
+        wav, sample_rate = torchaudio.load(audio_path)
+        wav = torchaudio.functional.resample(wav, orig_freq=sample_rate, new_freq=16_000)
+        wav = wav.transpose(0, 1)
+
+        data = self.fbank_extractor(
+            {
+                "waveform": wav,
+                "sample_rate": AUDIO_SAMPLE_RATE,
+            }
+        )
+        fbank = data["fbank"]
+        gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+        std, mean = torch.std_mean(fbank, dim=0)
+        fbank = fbank.subtract(mean).divide(std)
+
+        src = SequenceData(
+            seqs=fbank.unsqueeze(0),
+            seq_lens=torch.LongTensor([fbank.shape[0]]),
+            is_ragged=False,
+        )
+        src_gcmvn = SequenceData(
+            seqs=gcmvn_fbank.unsqueeze(0),
+            seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
+            is_ragged=False,
+        )
+
+        text_output, unit_output = self.translator.predict(
+            src,
+            "s2st",
+            tgt_lang,
+            text_generation_opts=text_generation_opts,
+            unit_generation_opts=unit_generation_opts,
+            unit_generation_ngram_filtering=unit_generation_ngram_filtering,
+            duration_factor=duration_factor,
+            prosody_encoder_input=src_gcmvn,
+        )
+        text_output = self.remove_prosody_tokens_from_text(text_output)
+
+        assert unit_output is not None
+        speech_output = self.pretssel_generator.predict(
+            unit_output.units,
+            tgt_lang=tgt_lang,
+            prosody_encoder_input=src_gcmvn,
+        )
+        return text_output, speech_output

+ 0 - 0
src/seamless_communication/cli/expressivity/predict/pretssel_generator.py → src/seamless_communication/inference/pretssel_generator.py


+ 2 - 2
src/seamless_communication/inference/translator.py

@@ -10,7 +10,7 @@ from pathlib import Path
 from typing import List, Optional, Tuple, Union, cast
 
 import torch
-import torch.nn as nn
+from torch.nn import Module
 from fairseq2.assets import asset_store
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater, SequenceData, StringLike
@@ -75,7 +75,7 @@ class BatchedSpeechOutput:
     """Sample rate of the audio waveforms."""
 
 
-class Translator(nn.Module):
+class Translator(Module):
     def __init__(
         self,
         model_name_or_card: Union[str, AssetCard],