#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

import os
import pathlib
import tempfile

import gradio as gr
import torch
import torchaudio
from fairseq2.assets import InProcAssetMetadataProvider, asset_store
from fairseq2.data import Collater
from fairseq2.data.audio import (
    AudioDecoder,
    WaveformToFbankConverter,
    WaveformToFbankOutput,
)
from fairseq2.generation import NGramRepeatBlockProcessor
from fairseq2.memory import MemoryBlock
from huggingface_hub import snapshot_download

from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
    PretsselGenerator,
)
from seamless_communication.inference import SequenceGeneratorOptions, Translator
from seamless_communication.models.unity import (
    load_gcmvn_stats,
    load_unity_unit_tokenizer,
)

from utils import LANGUAGE_CODE_TO_NAME
 
DESCRIPTION = """\
# Seamless Expressive

[SeamlessExpressive](https://github.com/facebookresearch/seamless_communication) is a speech-to-speech translation model that captures certain underexplored aspects of prosody, such as speech rate and pauses, while preserving the style of one's voice and maintaining high content translation quality.
"""
 
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()
 
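# Fetch both checkpoints on first run: SeamlessExpressive for expressive S2ST,
# and SeamlessM4T v2, which is used below for the source-side ASR pass that
# feeds the toxicity (mintox) check.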
CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))
if not CHECKPOINTS_PATH.exists():
    snapshot_download(repo_id="facebook/seamless-expressive", repo_type="model", local_dir=CHECKPOINTS_PATH)
    snapshot_download(repo_id="facebook/seamless-m4t-v2-large", repo_type="model", local_dir=CHECKPOINTS_PATH)
 
# Ensure that we do not have any other environment resolvers and always return
# "demo" for demo purposes.
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")

# Construct an `InProcAssetMetadataProvider` with environment-specific metadata
# that just overrides the regular metadata for the "demo" environment. Note the
# "@demo" suffix.
demo_metadata = [
    {
        "name": "seamless_expressivity@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/m2m_expressive_unity.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
    {
        "name": "vocoder_pretssel@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/pretssel_melhifigan_wm-final.pt",
    },
    {
        "name": "seamlessM4T_v2_large@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
]

asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))
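# With the resolver above, a lookup of e.g. "seamless_expressivity" resolves to
# the "seamless_expressivity@demo" card, so these local checkpoint paths take
# precedence over the stock asset cards.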
 
LANGUAGE_NAME_TO_CODE = {v: k for k, v in LANGUAGE_CODE_TO_NAME.items()}

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32

MODEL_NAME = "seamless_expressivity"
VOCODER_NAME = "vocoder_pretssel"
 
# SeamlessM4T v2 is used here only for ASR on the source audio, as part of the
# toxicity (mintox) check.
m4t_translator = Translator(
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card=None,
    device=device,
    dtype=dtype,
)
 
unit_tokenizer = load_unity_unit_tokenizer(MODEL_NAME)
 
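# gcmvn here is, as I understand it, the global (corpus-level) cepstral
# mean/variance normalization statistics shipped alongside the PRETSSEL
# vocoder; they normalize the fbank features fed to the prosody encoder
# (see normalize_fbank below).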
_gcmvn_mean, _gcmvn_std = load_gcmvn_stats(VOCODER_NAME)
gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
 
translator = Translator(
    MODEL_NAME,
    vocoder_name_or_card=None,
    device=device,
    dtype=dtype,
    apply_mintox=False,
)
 
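# Beam-search settings for both models: an infinite <unk> penalty so <unk> is
# never emitted, and a step processor that blocks repeated 10-grams to guard
# against degenerate looping during decoding.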
text_generation_opts = SequenceGeneratorOptions(
    beam_size=5,
    unk_penalty=torch.inf,
    soft_max_seq_len=(0, 200),
    step_processor=NGramRepeatBlockProcessor(
        ngram_size=10,
    ),
)

m4t_text_generation_opts = SequenceGeneratorOptions(
    beam_size=5,
    unk_penalty=torch.inf,
    soft_max_seq_len=(1, 200),
    step_processor=NGramRepeatBlockProcessor(
        ngram_size=10,
    ),
)
 
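# PretsselGenerator renders the predicted discrete units into a waveform,
# conditioned on the prosody encoder input (the gcmvn-normalized fbank).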
pretssel_generator = PretsselGenerator(
    VOCODER_NAME,
    vocab_info=unit_tokenizer.vocab_info,
    device=device,
    dtype=dtype,
)
 
decode_audio = AudioDecoder(dtype=torch.float32, device=device)
 
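# waveform_scale=2**15 rescales float waveforms to int16 range before fbank
# extraction; standardize=False because normalization is applied manually in
# normalize_fbank below.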
convert_to_fbank = WaveformToFbankConverter(
    num_mel_bins=80,
    waveform_scale=2**15,
    channel_last=True,
    standardize=False,
    device=device,
    dtype=dtype,
)
 
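# Produce two normalizations of the same features: per-utterance
# standardization for the translator input ("fbank") and corpus-level gcmvn
# normalization for the prosody encoder input ("gcmvn_fbank").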
def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
    fbank = data["fbank"]
    std, mean = torch.std_mean(fbank, dim=0)
    data["fbank"] = fbank.subtract(mean).divide(std)
    data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
    return data
 
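# Collater wraps the single example into the batched layout the translator
# expects; with a batch of one and pad_to_multiple=1, no padding is added.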
collate = Collater(pad_value=0, pad_to_multiple=1)
 
AUDIO_SAMPLE_RATE = 16000
MAX_INPUT_AUDIO_LENGTH = 10  # in seconds
 
def remove_prosody_tokens_from_text(text: str) -> str:
    # Filter out the prosody tokens: emphasis "*" and pause "=".
    text = text.replace("*", "").replace("=", "")
    text = " ".join(text.split())
    return text
 
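# Resample to 16 kHz and truncate to the first MAX_INPUT_AUDIO_LENGTH seconds,
# rewriting the file at input_audio_path in place.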
def preprocess_audio(input_audio_path: str) -> None:
    arr, org_sr = torchaudio.load(input_audio_path)
    new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
    max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
    if new_arr.shape[1] > max_length:
        new_arr = new_arr[:, :max_length]
        gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds are used.")
    torchaudio.save(input_audio_path, new_arr, sample_rate=AUDIO_SAMPLE_RATE)
 
def run(
    input_audio_path: str,
    source_language: str,
    target_language: str,
) -> tuple[str, str]:
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]

    preprocess_audio(input_audio_path)

    with pathlib.Path(input_audio_path).open("rb") as fb:
        block = MemoryBlock(fb.read())
        example = decode_audio(block)

    example = convert_to_fbank(example)
    example = normalize_fbank(example)
    example = collate(example)

    # Transcribe the source audio; the transcription feeds the mintox check.
    source_sentences, _ = m4t_translator.predict(
        input=example["fbank"],
        task_str="S2TT",
        tgt_lang=source_language_code,
        text_generation_opts=m4t_text_generation_opts,
    )
    source_text = str(source_sentences[0])

    prosody_encoder_input = example["gcmvn_fbank"]
    text_output, unit_output = translator.predict(
        example["fbank"],
        "S2ST",
        tgt_lang=target_language_code,
        src_lang=source_language_code,
        text_generation_opts=text_generation_opts,
        unit_generation_ngram_filtering=False,
        duration_factor=1.0,
        prosody_encoder_input=prosody_encoder_input,
        src_text=source_text,  # used by the mintox check
    )

    speech_output = pretssel_generator.predict(
        unit_output.units,
        tgt_lang=target_language_code,
        prosody_encoder_input=prosody_encoder_input,
    )

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        torchaudio.save(
            f.name,
            speech_output.audio_wavs[0][0].to(torch.float32).cpu(),
            sample_rate=speech_output.sample_rate,
        )

    text_out = remove_prosody_tokens_from_text(str(text_output[0]))

    return f.name, text_out
 
TARGET_LANGUAGE_NAMES = [
    "English",
    "French",
    "German",
    "Spanish",
]
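# The same four-language list backs both the source and target dropdowns below.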
 
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                source_language = gr.Dropdown(
                    label="Source language",
                    choices=TARGET_LANGUAGE_NAMES,
                    value="English",
                )
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=TARGET_LANGUAGE_NAMES,
                    value="French",
                )
            btn = gr.Button()
        with gr.Column():
            with gr.Group():
                output_audio = gr.Audio(label="Translated speech")
                output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[],
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        fn=run,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    btn.click(
        fn=run,
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        api_name="run",
    )
 
if __name__ == "__main__":
    demo.queue(max_size=50).launch()
 
 