Ver código fonte

batched expressive_translator

Yilin Yang 1 ano atrás
pai
commit
778e11e8c9

+ 14 - 55
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -25,7 +25,6 @@ from fairseq2.typing import DataType, Device
 from torch import Tensor
 from tqdm import tqdm
 
-
 from seamless_communication.cli.m4t.evaluate.evaluate import (
     adjust_output_for_corrupted_inputs,
     count_lines,
@@ -34,14 +33,9 @@ from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
     set_generation_opts,
 )
-from seamless_communication.inference.pretssel_generator import (
-    PretsselGenerator,
-)
-from seamless_communication.inference import BatchedSpeechOutput, Translator
-from seamless_communication.models.unity import (
-    load_gcmvn_stats,
-    load_unity_unit_tokenizer,
-)
+from seamless_communication.inference import BatchedSpeechOutput, ExpressiveTranslator
+from seamless_communication.models.unity import load_unity_unit_tokenizer
+
 from seamless_communication.store import add_gated_assets
 
 logging.basicConfig(
@@ -56,8 +50,6 @@ def build_data_pipeline(
     args: Namespace,
     device: Device,
     dtype: DataType,
-    gcmvn_mean: Tensor,
-    gcmvn_std: Tensor,
 ) -> DataPipeline:
     with open(args.data_file, "r") as f:
         header = f.readline().strip("\n").split("\t")
@@ -90,15 +82,8 @@ def build_data_pipeline(
         dtype=dtype,
     )
 
-    def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
-        fbank = data["fbank"]
-        std, mean = torch.std_mean(fbank, dim=0)
-        data["fbank"] = fbank.subtract(mean).divide(std)
-        data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
-        return data
-
     pipeline_builder.map(
-        [decode_audio, convert_to_fbank, normalize_fbank],
+        [decode_audio, convert_to_fbank],
         selector=f"{args.audio_field}.data",
         num_parallel_calls=n_parallel,
     )
@@ -177,17 +162,10 @@ def main() -> None:
 
     unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
 
-    _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
-    gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
-    gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
-
-    pipeline = build_data_pipeline(args, device, dtype, gcmvn_mean, gcmvn_std)
+    pipeline = build_data_pipeline(args, device, dtype)
 
-    translator = Translator(
-        args.model_name,
-        vocoder_name_or_card=None,
-        device=device,
-        dtype=dtype,
+    expressive_translator = ExpressiveTranslator(
+        args.model_name, args.vocoder_name, device, dtype
     )
 
     text_generation_opts, unit_generation_opts = set_generation_opts(args)
@@ -198,13 +176,6 @@ def main() -> None:
         f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
     )
 
-    pretssel_generator = PretsselGenerator(
-        args.vocoder_name,
-        vocab_info=unit_tokenizer.vocab_info,
-        device=device,
-        dtype=dtype,
-    )
-
     total_steps = count_lines(args.data_file) - 1
     progress_bar = tqdm(total=total_steps)
 
@@ -241,28 +212,16 @@ def main() -> None:
                 src["seqs"] = src["seqs"][valid_sequences]
                 src["seq_lens"] = src["seq_lens"][valid_sequences]
 
-            # Skip performing inference when the input is entirely corrupted.
+            # Skip inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                prosody_encoder_input = example[args.audio_field]["data"]["gcmvn_fbank"]
-                text_output, unit_output = translator.predict(
+                text_output, speech_output = expressive_translator.predict(
                     src,
-                    "s2st",
                     args.tgt_lang,
-                    src_lang=args.src_lang,
-                    text_generation_opts=text_generation_opts,
-                    unit_generation_opts=unit_generation_opts,
-                    unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
-                    duration_factor=args.duration_factor,
-                    prosody_encoder_input=prosody_encoder_input,
+                    text_generation_opts,
+                    unit_generation_opts,
+                    args.unit_generation_ngram_filtering,
+                    args.duration_factor,
                 )
-
-                assert unit_output is not None
-                speech_output = pretssel_generator.predict(
-                    unit_output.units,
-                    tgt_lang=args.tgt_lang,
-                    prosody_encoder_input=prosody_encoder_input,
-                )
-
             else:
                 text_output = []
                 speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
@@ -274,7 +233,7 @@ def main() -> None:
                     speech_output,
                 )
 
-            hyps += [str(s) for s in text_output]
+            hyps += [s for s in text_output]
             if args.ref_field is not None and args.ref_field in example:
                 refs += [str(s) for s in example[args.ref_field]]
 

+ 10 - 11
src/seamless_communication/cli/expressivity/predict/predict.py

@@ -6,9 +6,10 @@
 
 import argparse
 import logging
+from pathlib import Path
+
 import torch
 import torchaudio
-from pathlib import Path
 
 from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
@@ -17,7 +18,6 @@ from seamless_communication.cli.m4t.predict import (
 from seamless_communication.inference import ExpressiveTranslator
 from seamless_communication.store import add_gated_assets
 
-
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
@@ -27,8 +27,10 @@ logger = logging.getLogger(__name__)
 
 
 def main() -> None:
-    parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference.")
-    parser.add_argument("input", type=str, help="Audio WAV file path.")
+    parser = argparse.ArgumentParser(
+        description="Running SeamlessExpressive inference."
+    )
+    parser.add_argument("input", type=Path, help="Audio WAV file path.")
 
     parser = add_inference_arguments(parser)
     parser.add_argument(
@@ -49,10 +51,10 @@ def main() -> None:
         raise Exception(
             "--tgt_lang, --output_path must be provided for SeamlessExpressive inference."
         )
-        
+
     if args.gated_model_dir:
         add_gated_assets(args.gated_model_dir)
-    
+
     if torch.cuda.is_available():
         device = torch.device("cuda:0")
         dtype = torch.float16
@@ -63,10 +65,7 @@ def main() -> None:
     logger.info(f"Running inference on {device=} with {dtype=}.")
 
     expressive_translator = ExpressiveTranslator(
-        args.model_name,
-        args.vocoder_name,
-        device,
-        dtype
+        args.model_name, args.vocoder_name, device, dtype
     )
 
     text_generation_opts, unit_generation_opts = set_generation_opts(args)
@@ -77,7 +76,7 @@ def main() -> None:
         f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
     )
 
-    speech_output, text_output = expressive_translator.predict(
+    text_output, speech_output = expressive_translator.predict(
         args.input,
         args.tgt_lang,
         text_generation_opts,

+ 3 - 3
src/seamless_communication/inference/__init__.py

@@ -4,13 +4,13 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
+from seamless_communication.inference.expressive_translator import (
+    ExpressiveTranslator as ExpressiveTranslator,
+)
 from seamless_communication.inference.generator import (
     SequenceGeneratorOptions as SequenceGeneratorOptions,
 )
 from seamless_communication.inference.generator import UnitYGenerator as UnitYGenerator
-from seamless_communication.inference.expressive_translator import (
-    ExpressiveTranslator as ExpressiveTranslator,
-)
 from seamless_communication.inference.translator import (
     BatchedSpeechOutput as BatchedSpeechOutput,
 )

+ 56 - 40
src/seamless_communication/inference/expressive_translator.py

@@ -3,22 +3,22 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-import torch
-import torchaudio
-
-from torch.nn import Module
+from typing import cast
+from copy import deepcopy
+from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
+import torch
+import torchaudio
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import SequenceData, StringLike
 from fairseq2.data.audio import WaveformToFbankConverter
 from fairseq2.typing import DataType, Device
+from torch.nn import Module
 
-from seamless_communication.inference import BatchedSpeechOutput, Translator
 from seamless_communication.inference.generator import SequenceGeneratorOptions
-from seamless_communication.inference.pretssel_generator import (
-    PretsselGenerator,
-)
+from seamless_communication.inference.pretssel_generator import PretsselGenerator
+from seamless_communication.inference.translator import BatchedSpeechOutput, Translator
 from seamless_communication.models.unity import (
     load_gcmvn_stats,
     load_unity_unit_tokenizer,
@@ -38,7 +38,7 @@ class ExpressiveTranslator(Module):
         super().__init__()
 
         unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
-    
+
         self.translator = Translator(
             model_name_or_card,
             vocoder_name_or_card=None,
@@ -65,13 +65,13 @@ class ExpressiveTranslator(Module):
         _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(vocoder_name_or_card)
         self.gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
         self.gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
-        
+
     @staticmethod
     def remove_prosody_tokens_from_text(text_output: List[str]) -> List[str]:
         modified_text_output = []
         for text in text_output:
             # filter out prosody tokens, there is only emphasis '*', and pause '='
-            text = text.replace("*", "").replace("=", "")
+            text = str(text).replace("*", "").replace("=", "")
             text = " ".join(text.split())
             modified_text_output.append(text)
         return modified_text_output
@@ -79,7 +79,7 @@ class ExpressiveTranslator(Module):
     @torch.inference_mode()
     def predict(
         self,
-        audio_path: str,
+        input: Union[Path, SequenceData],
         tgt_lang: str,
         text_generation_opts: Optional[SequenceGeneratorOptions] = None,
         unit_generation_opts: Optional[SequenceGeneratorOptions] = None,
@@ -89,8 +89,8 @@ class ExpressiveTranslator(Module):
         """
         The main method used to perform inference on all tasks.
 
-        :param audio_path:
-            Path to audio waveform.
+        :param input:
+            Either path to audio or audio Tensor.
         :param tgt_lang:
             Target language to decode into.
         :param text_generation_opts:
@@ -105,32 +105,48 @@ class ExpressiveTranslator(Module):
             - Batched list of Translated text.
             - Translated BatchedSpeechOutput.
         """
-        # TODO: Replace with fairseq2.data once re-sampling is implemented.
-        wav, sample_rate = torchaudio.load(audio_path)
-        wav = torchaudio.functional.resample(wav, orig_freq=sample_rate, new_freq=16_000)
-        wav = wav.transpose(0, 1)
-
-        data = self.fbank_extractor(
-            {
-                "waveform": wav,
-                "sample_rate": AUDIO_SAMPLE_RATE,
-            }
-        )
-        fbank = data["fbank"]
-        gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
-        std, mean = torch.std_mean(fbank, dim=0)
-        fbank = fbank.subtract(mean).divide(std)
-
-        src = SequenceData(
-            seqs=fbank.unsqueeze(0),
-            seq_lens=torch.LongTensor([fbank.shape[0]]),
-            is_ragged=False,
-        )
-        src_gcmvn = SequenceData(
-            seqs=gcmvn_fbank.unsqueeze(0),
-            seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
-            is_ragged=False,
-        )
+        if isinstance(input, dict):
+            src = cast(SequenceData, input)
+            src_gcmvn = deepcopy(src)
+
+            fbank = src["seqs"]
+            gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+            std, mean = torch.std_mean(fbank, dim=[0, 1]) # B x T x mel
+            ucmvn_fbank = fbank.subtract(mean).divide(std)
+
+            src["seqs"] = ucmvn_fbank
+            src_gcmvn["seqs"] = gcmvn_fbank
+
+        elif isinstance(input, Path):
+            # TODO: Replace with fairseq2.data once re-sampling is implemented.
+            wav, sample_rate = torchaudio.load(path)
+            wav = torchaudio.functional.resample(
+                wav, orig_freq=sample_rate, new_freq=AUDIO_SAMPLE_RATE,
+            )
+            wav = wav.transpose(0, 1)
+
+            data = self.fbank_extractor(
+                {
+                    "waveform": wav,
+                    "sample_rate": AUDIO_SAMPLE_RATE,
+                }
+            )
+
+            fbank = data["fbank"]
+            gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+            std, mean = torch.std_mean(fbank, dim=0)
+            fbank = fbank.subtract(mean).divide(std)
+
+            src = SequenceData(
+                seqs=fbank.unsqueeze(0),
+                seq_lens=torch.LongTensor([fbank.shape[0]]),
+                is_ragged=False,
+            )
+            src_gcmvn = SequenceData(
+                seqs=gcmvn_fbank.unsqueeze(0),
+                seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
+                is_ragged=False,
+            )
 
         text_output, unit_output = self.translator.predict(
             src,

+ 1 - 1
src/seamless_communication/inference/pretssel_generator.py

@@ -18,7 +18,7 @@ from fairseq2.data import (
 )
 from fairseq2.nn.padding import get_seqs_and_padding_mask
 
-from seamless_communication.inference import BatchedSpeechOutput
+from seamless_communication.inference.translator import BatchedSpeechOutput
 from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
 
 

+ 1 - 1
src/seamless_communication/toxicity/mintox.py

@@ -12,7 +12,7 @@ import torch
 from torch.nn import functional as F
 
 
-from seamless_communication.inference import SequenceGeneratorOptions
+from seamless_communication.inference.generator import SequenceGeneratorOptions
 from seamless_communication.toxicity.etox_bad_word_checker import (
     ETOXBadWordChecker,
 )