
batched expressive_translator

Yilin Yang 1 year ago
parent
commit
778e11e8c9
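
This commit folds the PretsselGenerator and gcmvn handling into a single ExpressiveTranslator whose predict() accepts either an audio path or a pre-collated batch. A minimal usage sketch of the new API; the asset names seamless_expressivity and vocoder_pretssel are assumptions taken from the gated model cards, not from this diff:

from pathlib import Path

import torch

from seamless_communication.inference import ExpressiveTranslator

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Constructor args mirror the call sites in evaluate.py and predict.py below.
translator = ExpressiveTranslator(
    "seamless_expressivity", "vocoder_pretssel", device, dtype
)

# Path input (single utterance), as in predict.py; text comes back first.
text_output, speech_output = translator.predict(Path("input.wav"), "spa")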

+ 14 - 55
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -25,7 +25,6 @@ from fairseq2.typing import DataType, Device
 from torch import Tensor
 from tqdm import tqdm
 
-
 from seamless_communication.cli.m4t.evaluate.evaluate import (
     adjust_output_for_corrupted_inputs,
     count_lines,
@@ -34,14 +33,9 @@ from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
     set_generation_opts,
 )
-from seamless_communication.inference.pretssel_generator import (
-    PretsselGenerator,
-)
-from seamless_communication.inference import BatchedSpeechOutput, Translator
-from seamless_communication.models.unity import (
-    load_gcmvn_stats,
-    load_unity_unit_tokenizer,
-)
+from seamless_communication.inference import BatchedSpeechOutput, ExpressiveTranslator
+from seamless_communication.models.unity import load_unity_unit_tokenizer
+
 from seamless_communication.store import add_gated_assets
 
 logging.basicConfig(
@@ -56,8 +50,6 @@ def build_data_pipeline(
     args: Namespace,
     device: Device,
     dtype: DataType,
-    gcmvn_mean: Tensor,
-    gcmvn_std: Tensor,
 ) -> DataPipeline:
     with open(args.data_file, "r") as f:
         header = f.readline().strip("\n").split("\t")
@@ -90,15 +82,8 @@ def build_data_pipeline(
         dtype=dtype,
     )
 
-    def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
-        fbank = data["fbank"]
-        std, mean = torch.std_mean(fbank, dim=0)
-        data["fbank"] = fbank.subtract(mean).divide(std)
-        data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
-        return data
-
     pipeline_builder.map(
-        [decode_audio, convert_to_fbank, normalize_fbank],
+        [decode_audio, convert_to_fbank],
         selector=f"{args.audio_field}.data",
         num_parallel_calls=n_parallel,
     )
@@ -177,17 +162,10 @@ def main() -> None:
 
     unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
 
-    _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
-    gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
-    gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
-
-    pipeline = build_data_pipeline(args, device, dtype, gcmvn_mean, gcmvn_std)
+    pipeline = build_data_pipeline(args, device, dtype)
 
-    translator = Translator(
-        args.model_name,
-        vocoder_name_or_card=None,
-        device=device,
-        dtype=dtype,
+    expressive_translator = ExpressiveTranslator(
+        args.model_name, args.vocoder_name, device, dtype
     )
 
     text_generation_opts, unit_generation_opts = set_generation_opts(args)
@@ -198,13 +176,6 @@ def main() -> None:
         f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
         f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
     )
     )
 
 
-    pretssel_generator = PretsselGenerator(
-        args.vocoder_name,
-        vocab_info=unit_tokenizer.vocab_info,
-        device=device,
-        dtype=dtype,
-    )
-
     total_steps = count_lines(args.data_file) - 1
     progress_bar = tqdm(total=total_steps)
 
@@ -241,28 +212,16 @@ def main() -> None:
                 src["seqs"] = src["seqs"][valid_sequences]
                 src["seq_lens"] = src["seq_lens"][valid_sequences]
 
-            # Skip performing inference when the input is entirely corrupted.
+            # Skip inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                prosody_encoder_input = example[args.audio_field]["data"]["gcmvn_fbank"]
-                text_output, unit_output = translator.predict(
+                text_output, speech_output = expressive_translator.predict(
                     src,
-                    "s2st",
                     args.tgt_lang,
-                    src_lang=args.src_lang,
-                    text_generation_opts=text_generation_opts,
-                    unit_generation_opts=unit_generation_opts,
-                    unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
-                    duration_factor=args.duration_factor,
-                    prosody_encoder_input=prosody_encoder_input,
+                    text_generation_opts,
+                    unit_generation_opts,
+                    args.unit_generation_ngram_filtering,
+                    args.duration_factor,
                 )
-
-                assert unit_output is not None
-                speech_output = pretssel_generator.predict(
-                    unit_output.units,
-                    tgt_lang=args.tgt_lang,
-                    prosody_encoder_input=prosody_encoder_input,
-                )
-
             else:
                 text_output = []
                 speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
@@ -274,7 +233,7 @@ def main() -> None:
                     speech_output,
                 )
 
-            hyps += [str(s) for s in text_output]
+            hyps += [s for s in text_output]
             if args.ref_field is not None and args.ref_field in example:
                 refs += [str(s) for s in example[args.ref_field]]
 

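The normalize_fbank step removed above moved into ExpressiveTranslator.predict, which now applies both normalizations itself (see expressive_translator.py below). A sketch of the two CMVN variants with placeholder stats:

import torch

fbank = torch.randn(2, 100, 80)  # B x T x mel, placeholder batch
gcmvn_mean, gcmvn_std = torch.zeros(80), torch.ones(80)  # placeholder gcmvn stats

# Global CMVN with dataset-level stats feeds the prosody encoder.
gcmvn_fbank = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
# Utterance-level CMVN with stats computed over the batch feeds the translator.
std, mean = torch.std_mean(fbank, dim=[0, 1])
ucmvn_fbank = fbank.subtract(mean).divide(std)
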
+ 10 - 11
src/seamless_communication/cli/expressivity/predict/predict.py

@@ -6,9 +6,10 @@
 
 import argparse
 import logging
+from pathlib import Path
+
 import torch
 import torchaudio
-from pathlib import Path
 
 from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
@@ -17,7 +18,6 @@ from seamless_communication.cli.m4t.predict import (
 from seamless_communication.inference import ExpressiveTranslator
 from seamless_communication.store import add_gated_assets
 
-
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
@@ -27,8 +27,10 @@ logger = logging.getLogger(__name__)
 
 
 def main() -> None:
-    parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference.")
-    parser.add_argument("input", type=str, help="Audio WAV file path.")
+    parser = argparse.ArgumentParser(
+        description="Running SeamlessExpressive inference."
+    )
+    parser.add_argument("input", type=Path, help="Audio WAV file path.")
 
     parser = add_inference_arguments(parser)
     parser.add_argument(
@@ -49,10 +51,10 @@ def main() -> None:
         raise Exception(
             "--tgt_lang, --output_path must be provided for SeamlessExpressive inference."
         )
-        
+
     if args.gated_model_dir:
         add_gated_assets(args.gated_model_dir)
-    
+
     if torch.cuda.is_available():
         device = torch.device("cuda:0")
         dtype = torch.float16
@@ -63,10 +65,7 @@ def main() -> None:
     logger.info(f"Running inference on {device=} with {dtype=}.")
 
     expressive_translator = ExpressiveTranslator(
-        args.model_name,
-        args.vocoder_name,
-        device,
-        dtype
+        args.model_name, args.vocoder_name, device, dtype
     )
 
     text_generation_opts, unit_generation_opts = set_generation_opts(args)
@@ -77,7 +76,7 @@ def main() -> None:
         f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
         f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
     )
     )
 
 
-    speech_output, text_output = expressive_translator.predict(
+    text_output, speech_output = expressive_translator.predict(
         args.input,
         args.tgt_lang,
         text_generation_opts,

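The last hunk fixes the unpacking order: ExpressiveTranslator.predict returns (text_output, speech_output), not the reverse. A hedged sketch of consuming the result inside main(), assuming BatchedSpeechOutput exposes audio_wavs (one list of waveforms per sequence) and sample_rate as in translator.py:

import torchaudio

text_output, speech_output = expressive_translator.predict(
    args.input, args.tgt_lang, text_generation_opts, unit_generation_opts
)
logger.info(f"Translated text: {text_output[0]}")
torchaudio.save(
    str(args.output_path),
    speech_output.audio_wavs[0][0].to(torch.float32).cpu(),
    sample_rate=speech_output.sample_rate,
)
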
+ 3 - 3
src/seamless_communication/inference/__init__.py

@@ -4,13 +4,13 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
+from seamless_communication.inference.expressive_translator import (
+    ExpressiveTranslator as ExpressiveTranslator,
+)
 from seamless_communication.inference.generator import (
     SequenceGeneratorOptions as SequenceGeneratorOptions,
 )
 from seamless_communication.inference.generator import UnitYGenerator as UnitYGenerator
-from seamless_communication.inference.expressive_translator import (
-    ExpressiveTranslator as ExpressiveTranslator,
-)
 from seamless_communication.inference.translator import (
     BatchedSpeechOutput as BatchedSpeechOutput,
 )

+ 56 - 40
src/seamless_communication/inference/expressive_translator.py

@@ -3,22 +3,22 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-import torch
-import torchaudio
-
-from torch.nn import Module
+from typing import cast
+from copy import deepcopy
+from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
+import torch
+import torchaudio
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import SequenceData, StringLike
 from fairseq2.data.audio import WaveformToFbankConverter
 from fairseq2.typing import DataType, Device
+from torch.nn import Module
 
-from seamless_communication.inference import BatchedSpeechOutput, Translator
 from seamless_communication.inference.generator import SequenceGeneratorOptions
-from seamless_communication.inference.pretssel_generator import (
-    PretsselGenerator,
-)
+from seamless_communication.inference.pretssel_generator import PretsselGenerator
+from seamless_communication.inference.translator import BatchedSpeechOutput, Translator
 from seamless_communication.models.unity import (
     load_gcmvn_stats,
     load_unity_unit_tokenizer,
@@ -38,7 +38,7 @@ class ExpressiveTranslator(Module):
         super().__init__()
 
         unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
-    
+
         self.translator = Translator(
             model_name_or_card,
             vocoder_name_or_card=None,
@@ -65,13 +65,13 @@ class ExpressiveTranslator(Module):
         _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(vocoder_name_or_card)
         self.gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
         self.gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
-        
+
     @staticmethod
     def remove_prosody_tokens_from_text(text_output: List[str]) -> List[str]:
         modified_text_output = []
         for text in text_output:
             # filter out prosody tokens, there is only emphasis '*', and pause '='
-            text = text.replace("*", "").replace("=", "")
+            text = str(text).replace("*", "").replace("=", "")
             text = " ".join(text.split())
             modified_text_output.append(text)
         return modified_text_output
@@ -79,7 +79,7 @@ class ExpressiveTranslator(Module):
     @torch.inference_mode()
     def predict(
         self,
-        audio_path: str,
+        input: Union[Path, SequenceData],
         tgt_lang: str,
         text_generation_opts: Optional[SequenceGeneratorOptions] = None,
         unit_generation_opts: Optional[SequenceGeneratorOptions] = None,
@@ -89,8 +89,8 @@ class ExpressiveTranslator(Module):
         """
         """
         The main method used to perform inference on all tasks.
         The main method used to perform inference on all tasks.
 
 
-        :param audio_path:
-            Path to audio waveform.
+        :param input:
+            Either path to audio or audio Tensor.
         :param tgt_lang:
             Target language to decode into.
         :param text_generation_opts:
@@ -105,32 +105,48 @@ class ExpressiveTranslator(Module):
             - Batched list of Translated text.
             - Translated BatchedSpeechOutput.
         """
-        # TODO: Replace with fairseq2.data once re-sampling is implemented.
-        wav, sample_rate = torchaudio.load(audio_path)
-        wav = torchaudio.functional.resample(wav, orig_freq=sample_rate, new_freq=16_000)
-        wav = wav.transpose(0, 1)
-
-        data = self.fbank_extractor(
-            {
-                "waveform": wav,
-                "sample_rate": AUDIO_SAMPLE_RATE,
-            }
-        )
-        fbank = data["fbank"]
-        gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
-        std, mean = torch.std_mean(fbank, dim=0)
-        fbank = fbank.subtract(mean).divide(std)
-
-        src = SequenceData(
-            seqs=fbank.unsqueeze(0),
-            seq_lens=torch.LongTensor([fbank.shape[0]]),
-            is_ragged=False,
-        )
-        src_gcmvn = SequenceData(
-            seqs=gcmvn_fbank.unsqueeze(0),
-            seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
-            is_ragged=False,
-        )
+        if isinstance(input, dict):
+            src = cast(SequenceData, input)
+            src_gcmvn = deepcopy(src)
+
+            fbank = src["seqs"]
+            gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+            std, mean = torch.std_mean(fbank, dim=[0, 1]) # B x T x mel
+            ucmvn_fbank = fbank.subtract(mean).divide(std)
+
+            src["seqs"] = ucmvn_fbank
+            src_gcmvn["seqs"] = gcmvn_fbank
+
+        elif isinstance(input, Path):
+            # TODO: Replace with fairseq2.data once re-sampling is implemented.
+            wav, sample_rate = torchaudio.load(input)
+            wav = torchaudio.functional.resample(
+                wav, orig_freq=sample_rate, new_freq=AUDIO_SAMPLE_RATE,
+            )
+            wav = wav.transpose(0, 1)
+
+            data = self.fbank_extractor(
+                {
+                    "waveform": wav,
+                    "sample_rate": AUDIO_SAMPLE_RATE,
+                }
+            )
+
+            fbank = data["fbank"]
+            gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+            std, mean = torch.std_mean(fbank, dim=0)
+            fbank = fbank.subtract(mean).divide(std)
+
+            src = SequenceData(
+                seqs=fbank.unsqueeze(0),
+                seq_lens=torch.LongTensor([fbank.shape[0]]),
+                is_ragged=False,
+            )
+            src_gcmvn = SequenceData(
+                seqs=gcmvn_fbank.unsqueeze(0),
+                seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
+                is_ragged=False,
+            )
 
         text_output, unit_output = self.translator.predict(
             src,

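For the batched branch, predict() expects the fairseq2 SequenceData mapping that the evaluate.py collater produces. Since SequenceData is a TypedDict, a plain dict is a valid instance; a hand-built example with illustrative shapes and lengths:

import torch

fbank = torch.randn(2, 100, 80)  # two padded utterances, 100 frames, 80 mel bins
src = {
    "seqs": fbank,
    "seq_lens": torch.LongTensor([100, 87]),  # true lengths before padding
    "is_ragged": True,  # lengths differ across the batch
}
# text_output, speech_output = expressive_translator.predict(src, "fra")
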
+ 1 - 1
src/seamless_communication/inference/pretssel_generator.py

@@ -18,7 +18,7 @@ from fairseq2.data import (
 )
 from fairseq2.nn.padding import get_seqs_and_padding_mask
 
-from seamless_communication.inference import BatchedSpeechOutput
+from seamless_communication.inference.translator import BatchedSpeechOutput
 from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
 
 

+ 1 - 1
src/seamless_communication/toxicity/mintox.py

@@ -12,7 +12,7 @@ import torch
 from torch.nn import functional as F
 
 
-from seamless_communication.inference import SequenceGeneratorOptions
+from seamless_communication.inference.generator import SequenceGeneratorOptions
 from seamless_communication.toxicity.etox_bad_word_checker import (
     ETOXBadWordChecker,
 )