12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714 |
from __future__ import annotations

import os
import tempfile

import gradio as gr
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from seamless_communication.models.inference.translator import Translator
# Markdown shown at the top of the demo page.
DESCRIPTION = (
    "# SeamlessM4T\n"
    "[SeamlessM4T](https://github.com/facebookresearch/seamless_communication)"
    " is designed to provide high-quality\n"
    "translation, allowing people from different linguistic communities to"
    " communicate effortlessly through speech and text.\n"
    "This unified model enables multiple tasks like Speech-to-Speech (S2ST),"
    " Speech-to-Text (S2TT), Text-to-Speech (T2ST)\n"
    "translation and more, without relying on multiple separate models.\n"
)

# Entries for the "Task" dropdown. Code elsewhere keeps only the first word
# of an entry (e.g. "S2ST") via ``.split()[0]``.
TASK_NAMES = [
    f"{code} ({summary})"
    for code, summary in (
        ("S2ST", "Speech to Speech translation"),
        ("S2TT", "Speech to Text translation"),
        ("T2ST", "Text to Speech translation"),
        ("T2TT", "Text to Text translation"),
        ("ASR", "Automatic Speech Recognition"),
    )
]
# Display name for every language the demo knows about, keyed by the
# three-letter code handed to the translator as src_lang / tgt_lang
# (the codes look like ISO 639-3 — confirm against the model card).
# Entries are in alphabetical code order, one initial letter per group.
language_code_to_name = {
    "afr": "Afrikaans", "amh": "Amharic", "arb": "Modern Standard Arabic",
    "ary": "Moroccan Arabic", "arz": "Egyptian Arabic", "asm": "Assamese",
    "ast": "Asturian", "azj": "North Azerbaijani",
    "bel": "Belarusian", "ben": "Bengali", "bos": "Bosnian", "bul": "Bulgarian",
    "cat": "Catalan", "ceb": "Cebuano", "ces": "Czech", "ckb": "Central Kurdish",
    "cmn": "Mandarin Chinese", "cym": "Welsh",
    "dan": "Danish", "deu": "German",
    "ell": "Greek", "eng": "English", "est": "Estonian", "eus": "Basque",
    "fin": "Finnish", "fra": "French",
    "gaz": "West Central Oromo", "gle": "Irish", "glg": "Galician",
    "guj": "Gujarati",
    "heb": "Hebrew", "hin": "Hindi", "hrv": "Croatian", "hun": "Hungarian",
    "hye": "Armenian",
    "ibo": "Igbo", "ind": "Indonesian", "isl": "Icelandic", "ita": "Italian",
    "jav": "Javanese", "jpn": "Japanese",
    "kam": "Kamba", "kan": "Kannada", "kat": "Georgian", "kaz": "Kazakh",
    "kea": "Kabuverdianu", "khk": "Halh Mongolian", "khm": "Khmer",
    "kir": "Kyrgyz", "kor": "Korean",
    "lao": "Lao", "lit": "Lithuanian", "ltz": "Luxembourgish", "lug": "Ganda",
    "luo": "Luo", "lvs": "Standard Latvian",
    "mai": "Maithili", "mal": "Malayalam", "mar": "Marathi", "mkd": "Macedonian",
    "mlt": "Maltese", "mni": "Meitei", "mya": "Burmese",
    "nld": "Dutch", "nno": "Norwegian Nynorsk", "nob": "Norwegian Bokm\u00e5l",
    "npi": "Nepali", "nya": "Nyanja",
    "oci": "Occitan", "ory": "Odia",
    "pan": "Punjabi", "pbt": "Southern Pashto", "pes": "Western Persian",
    "pol": "Polish", "por": "Portuguese",
    "ron": "Romanian", "rus": "Russian",
    "slk": "Slovak", "slv": "Slovenian", "sna": "Shona", "snd": "Sindhi",
    "som": "Somali", "spa": "Spanish", "srp": "Serbian", "swe": "Swedish",
    "swh": "Swahili",
    "tam": "Tamil", "tel": "Telugu", "tgk": "Tajik", "tgl": "Tagalog",
    "tha": "Thai", "tur": "Turkish",
    "ukr": "Ukrainian", "urd": "Urdu", "uzn": "Northern Uzbek",
    "vie": "Vietnamese",
    "xho": "Xhosa",
    "yor": "Yoruba", "yue": "Cantonese",
    "zlm": "Colloquial Malay", "zsm": "Standard Malay", "zul": "Zulu",
}
# Reverse lookup used to turn the dropdown's display name back into a code.
LANGUAGE_NAME_TO_CODE = dict(
    zip(language_code_to_name.values(), language_code_to_name.keys())
)
# Languages accepted as *text* input (T2ST / T2TT). The UI hides the
# source-language picker for the speech-input tasks (S2ST / S2TT / ASR).
# This list is also reused as the target list for S2TT / ASR / T2TT.
text_source_language_codes = [
    "afr", "amh", "arb", "ary", "arz", "asm", "azj",
    "bel", "ben", "bos", "bul",
    "cat", "ceb", "ces", "ckb", "cmn", "cym",
    "dan", "deu",
    "ell", "eng", "est", "eus",
    "fin", "fra",
    "gaz", "gle", "glg", "guj",
    "heb", "hin", "hrv", "hun", "hye",
    "ibo", "ind", "isl", "ita",
    "jav", "jpn",
    "kan", "kat", "kaz", "khk", "khm", "kir", "kor",
    "lao", "lit", "lug", "luo", "lvs",
    "mai", "mal", "mar", "mkd", "mlt", "mni", "mya",
    "nld", "nno", "nob", "npi", "nya",
    "ory",
    "pan", "pbt", "pes", "pol", "por",
    "ron", "rus",
    "slk", "slv", "sna", "snd", "som", "spa", "srp", "swe", "swh",
    "tam", "tel", "tgk", "tgl", "tha", "tur",
    "ukr", "urd", "uzn",
    "vie",
    "yor", "yue",
    "zsm", "zul",
]
# Dropdown choices: the display names for the codes above, alphabetised.
TEXT_SOURCE_LANGUAGE_NAMES = sorted(
    language_code_to_name[code] for code in text_source_language_codes
)
# Target languages for the speech-output tasks (S2ST / T2ST). "eng" is listed
# first; the order does not reach the UI because the names are sorted below.
s2st_target_language_codes = [
    "eng",
    "arb", "ben", "cat", "ces", "cmn", "cym", "dan", "deu", "est", "fin", "fra",
    "hin", "ind", "ita", "jpn", "kor", "mlt", "nld", "pes", "pol", "por", "ron",
    "rus", "slk", "spa", "swe", "swh", "tel", "tgl", "tha", "tur", "ukr", "urd",
    "uzn", "vie",
]
S2ST_TARGET_LANGUAGE_NAMES = sorted(
    language_code_to_name[code] for code in s2st_target_language_codes
)
# Text-output tasks (S2TT / ASR and T2TT) offer the same choices as the text
# source picker; both names alias that single list object.
S2TT_TARGET_LANGUAGE_NAMES = T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
# Fetch the example clips used by the gr.Examples rows from the
# facebook/seamless_m4t Space at import time, saving them relative to the
# current working directory (local_dir=".").
_SAMPLE_AUDIO_SOURCE = {
    "repo_id": "facebook/seamless_m4t",
    "repo_type": "space",
    "local_dir": ".",
}
filenames = ["assets/sample_input.mp3", "assets/sample_input_2.mp3"]
for filename in filenames:
    hf_hub_download(filename=filename, **_SAMPLE_AUDIO_SOURCE)
# Rate (Hz) that input speech is resampled to before inference.
AUDIO_SAMPLE_RATE: float = 16_000.0
# Longest input clip kept, in seconds; anything beyond is cut off.
MAX_INPUT_AUDIO_LENGTH: int = 60
# Target-language dropdown value selected initially and after a task switch.
DEFAULT_TARGET_LANGUAGE: str = "French"
# Use the first CUDA GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Half precision on GPU, full precision on CPU.
_use_half_precision = "cuda" in device.type

# One translator instance is built at import time and used by every request.
translator = Translator(
    model_name_or_card="seamlessM4T_large",
    vocoder_name_or_card="vocoder_36langs",
    device=device,
    dtype=torch.float16 if _use_half_precision else torch.float32,
)
def _prepare_input_audio(input_path: str, output_dir: str) -> str:
    """Resample and trim an input clip, writing the result as a new WAV file.

    The clip at ``input_path`` is only read, never modified. It is resampled
    to ``AUDIO_SAMPLE_RATE``, cut to the first ``MAX_INPUT_AUDIO_LENGTH``
    seconds (showing a Gradio warning when that happens), and saved as
    ``input.wav`` inside ``output_dir``.

    Args:
        input_path: Path of the recorded or uploaded audio file.
        output_dir: Existing directory that receives the prepared clip.

    Returns:
        Path of the written WAV file.
    """
    target_sr = int(AUDIO_SAMPLE_RATE)
    waveform, orig_sr = torchaudio.load(input_path)
    waveform = torchaudio.functional.resample(
        waveform, orig_freq=orig_sr, new_freq=target_sr
    )
    # Axis 1 is time (torchaudio's default channels-first layout).
    max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
    if waveform.shape[1] > max_length:
        waveform = waveform[:, :max_length]
        gr.Warning(
            f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used."
        )
    output_path = os.path.join(output_dir, "input.wav")
    torchaudio.save(output_path, waveform, sample_rate=target_sr)
    return output_path


def predict(
    task_name: str,
    audio_source: str,
    input_audio_mic: str | None,
    input_audio_file: str | None,
    input_text: str | None,
    source_language: str | None,
    target_language: str,
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Run one SeamlessM4T task and shape the result for the Gradio outputs.

    Args:
        task_name: A ``TASK_NAMES`` entry or its bare code (e.g. ``"S2ST"``);
            only the first whitespace-separated word is used.
        audio_source: ``"microphone"`` selects ``input_audio_mic``; any other
            value selects ``input_audio_file``. Only read for S2ST/S2TT/ASR.
        input_audio_mic: Path of the microphone recording, if any.
        input_audio_file: Path of the uploaded audio file, if any.
        input_text: Text to translate; only read for the text-input tasks.
        source_language: Display name of the source language, or a falsy
            value, in which case ``src_lang=None`` is passed to the translator.
        target_language: Display name of the target language.

    Returns:
        ``((sample_rate, waveform), text)`` for S2ST/T2ST, and
        ``(None, text)`` for the text-output tasks.

    Raises:
        gr.Error: If the audio or text input the task needs is missing.
        KeyError: If a language name is not in ``LANGUAGE_NAME_TO_CODE``.
    """
    task_name = task_name.split()[0]
    source_language_code = (
        LANGUAGE_NAME_TO_CODE[source_language] if source_language else None
    )
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    speech_input = task_name in ["S2ST", "S2TT", "ASR"]

    # Validate first so a missing input gives a readable UI error instead of
    # an exception raised somewhere inside torchaudio or the model.
    if speech_input:
        input_path = (
            input_audio_mic if audio_source == "microphone" else input_audio_file
        )
        if not input_path:
            raise gr.Error("Please provide input audio (upload a file or record one).")
    else:
        input_path = None
        if not input_text or not input_text.strip():
            raise gr.Error("Please enter some input text.")

    # The prepared clip goes into a throwaway directory instead of being
    # written back over input_path, so the caller's file (an upload or a
    # bundled example) is never modified. The directory is deleted once the
    # translator has finished.
    with tempfile.TemporaryDirectory() as tmp_dir:
        input_data = (
            _prepare_input_audio(input_path, tmp_dir) if speech_input else input_text
        )
        text_out, wav, sr = translator.predict(
            input=input_data,
            task_str=task_name,
            tgt_lang=target_language_code,
            src_lang=source_language_code,
            ngram_filtering=True,
        )

    if task_name in ["S2ST", "T2ST"]:
        return (sr, wav.cpu().detach().numpy()), text_out
    return None, text_out
def _run_speech_example(
    task: str, input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Run an audio-input task on an example clip, as if it came from the upload widget."""
    return predict(
        task_name=task,
        audio_source="file",
        input_audio_mic=None,
        input_audio_file=input_audio_file,
        input_text=None,
        source_language=None,
        target_language=target_language,
    )


def _run_text_example(
    task: str, input_text: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Run a text-input task on an example sentence; no audio source is set."""
    return predict(
        task_name=task,
        audio_source="",
        input_audio_mic=None,
        input_audio_file=None,
        input_text=input_text,
        source_language=source_language,
        target_language=target_language,
    )


def process_s2st_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """gr.Examples callback: speech-to-speech translation of an example clip."""
    return _run_speech_example("S2ST", input_audio_file, target_language)


def process_s2tt_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """gr.Examples callback: speech-to-text translation of an example clip."""
    return _run_speech_example("S2TT", input_audio_file, target_language)


def process_t2st_example(
    input_text: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """gr.Examples callback: text-to-speech translation of an example sentence."""
    return _run_text_example("T2ST", input_text, source_language, target_language)


def process_t2tt_example(
    input_text: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """gr.Examples callback: text-to-text translation of an example sentence."""
    return _run_text_example("T2TT", input_text, source_language, target_language)


def process_asr_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """gr.Examples callback: speech recognition on an example clip."""
    return _run_speech_example("ASR", input_audio_file, target_language)
def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
    """Show exactly one of the two audio inputs and clear both.

    Returns Gradio updates for (input_audio_mic, input_audio_file).
    """
    use_mic = audio_source == "microphone"
    mic_update = gr.update(visible=use_mic, value=None)
    file_update = gr.update(visible=not use_mic, value=None)
    return mic_update, file_update
def update_input_ui(task_name: str) -> tuple[dict, dict, dict, dict]:
    """Reconfigure the input widgets for the selected task.

    Speech-input tasks show the audio row. Text-input tasks show the text box
    and the source-language picker instead. The target-language dropdown is
    repopulated with the task's supported languages and reset to
    ``DEFAULT_TARGET_LANGUAGE``.

    Returns:
        Gradio updates for (audio_box, input_text, source_language,
        target_language), in that order.

    Raises:
        ValueError: If the task code is not one of the five known tasks.
    """
    code = task_name.split()[0]
    # task code -> (takes speech input?, target-language choices)
    layouts = {
        "S2ST": (True, S2ST_TARGET_LANGUAGE_NAMES),
        "S2TT": (True, S2TT_TARGET_LANGUAGE_NAMES),
        "T2ST": (False, S2ST_TARGET_LANGUAGE_NAMES),
        "T2TT": (False, T2TT_TARGET_LANGUAGE_NAMES),
        "ASR": (True, S2TT_TARGET_LANGUAGE_NAMES),
    }
    if code not in layouts:
        raise ValueError(f"Unknown task: {code}")
    speech_input, target_choices = layouts[code]
    return (
        gr.update(visible=speech_input),  # audio_box
        gr.update(visible=not speech_input),  # input_text
        gr.update(visible=not speech_input),  # source_language
        gr.update(
            visible=True,
            choices=target_choices,
            value=DEFAULT_TARGET_LANGUAGE,
        ),  # target_language
    )
def update_output_ui(task_name: str) -> tuple[dict, dict]:
    """Clear both outputs and show the audio player only for speech output.

    Returns Gradio updates for (output_audio, output_text).

    Raises:
        ValueError: If the task code is not one of the five known tasks.
    """
    code = task_name.split()[0]
    if code in ("S2ST", "T2ST"):
        speech_output = True
    elif code in ("S2TT", "T2TT", "ASR"):
        speech_output = False
    else:
        raise ValueError(f"Unknown task: {code}")
    return (
        gr.update(visible=speech_output, value=None),  # output_audio
        gr.update(value=None),  # output_text
    )
def update_example_ui(task_name: str) -> tuple[dict, dict, dict, dict, dict]:
    """Make only the example row belonging to the selected task visible.

    Returns Gradio updates for the S2ST, S2TT, T2ST, T2TT and ASR example
    rows, in that order.
    """
    selected = task_name.split()[0]
    row_tasks = ("S2ST", "S2TT", "T2ST", "T2TT", "ASR")
    return tuple(gr.update(visible=selected == code) for code in row_tasks)
# Page stylesheet: centre the title and cap the content width at 730px.
css = (
    "\n"
    "h1 {\n"
    "text-align: center;\n"
    "}\n"
    ".contain {\n"
    "max-width: 730px;\n"
    "margin: auto;\n"
    "padding-top: 1.5rem;\n"
    "}\n"
)
# Example inputs for the gr.Examples widgets. They are kept as tuples, and
# each widget receives its own freshly built list of lists.
_SPEECH_EXAMPLES = (
    ("assets/sample_input.mp3", "French"),
    ("assets/sample_input.mp3", "Mandarin Chinese"),
    ("assets/sample_input_2.mp3", "Hindi"),
    ("assets/sample_input_2.mp3", "Spanish"),
)
_M4T_SENTENCE = (
    "Meta AI's Seamless M4T model is democratising spoken communication"
    " across language barriers"
)
_TEXT_EXAMPLES = (
    ("My favorite animal is the elephant.", "English", "French"),
    ("My favorite animal is the elephant.", "English", "Mandarin Chinese"),
    (_M4T_SENTENCE, "English", "Hindi"),
    (_M4T_SENTENCE, "English", "Spanish"),
)
_ASR_EXAMPLES = (
    ("assets/sample_input.mp3", "English"),
    ("assets/sample_input_2.mp3", "English"),
)

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)

    # NOTE(review): the container nesting below was reconstructed from a copy
    # of this file whose indentation had been stripped; confirm the layout.

    # ---- Input form and outputs ---------------------------------------------
    with gr.Group():
        task_name = gr.Dropdown(label="Task", choices=TASK_NAMES, value=TASK_NAMES[0])
        with gr.Row():
            # Hidden until a text-input task (T2ST / T2TT) is selected.
            source_language = gr.Dropdown(
                label="Source language",
                choices=TEXT_SOURCE_LANGUAGE_NAMES,
                value="English",
                visible=False,
            )
            target_language = gr.Dropdown(
                label="Target language",
                choices=S2ST_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            )
        with gr.Row() as audio_box:
            audio_source = gr.Radio(
                label="Audio source", choices=["file", "microphone"], value="file"
            )
            # Only one of these two widgets is visible at a time
            # (toggled by update_audio_ui).
            input_audio_mic = gr.Audio(
                label="Input speech", type="filepath", source="microphone", visible=False
            )
            input_audio_file = gr.Audio(
                label="Input speech", type="filepath", source="upload", visible=True
            )
        input_text = gr.Textbox(label="Input text", visible=False)
        btn = gr.Button("Translate")
        with gr.Column():
            output_audio = gr.Audio(
                label="Translated speech", autoplay=False, streaming=False, type="numpy"
            )
            output_text = gr.Textbox(label="Translated text")

    # ---- Per-task example rows (only the selected task's row is visible) ----
    with gr.Row(visible=True) as s2st_example_row:
        s2st_examples = gr.Examples(
            examples=[list(row) for row in _SPEECH_EXAMPLES],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_s2st_example,
        )
    with gr.Row(visible=False) as s2tt_example_row:
        s2tt_examples = gr.Examples(
            examples=[list(row) for row in _SPEECH_EXAMPLES],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_s2tt_example,
        )
    with gr.Row(visible=False) as t2st_example_row:
        t2st_examples = gr.Examples(
            examples=[list(row) for row in _TEXT_EXAMPLES],
            inputs=[input_text, source_language, target_language],
            outputs=[output_audio, output_text],
            fn=process_t2st_example,
        )
    with gr.Row(visible=False) as t2tt_example_row:
        t2tt_examples = gr.Examples(
            examples=[list(row) for row in _TEXT_EXAMPLES],
            inputs=[input_text, source_language, target_language],
            outputs=[output_audio, output_text],
            fn=process_t2tt_example,
        )
    with gr.Row(visible=False) as asr_example_row:
        asr_examples = gr.Examples(
            examples=[list(row) for row in _ASR_EXAMPLES],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_asr_example,
        )

    # ---- Event wiring --------------------------------------------------------
    audio_source.change(
        fn=update_audio_ui,
        inputs=audio_source,
        outputs=[input_audio_mic, input_audio_file],
        queue=False,
        api_name=False,
    )

    # A task switch runs three steps in order: reconfigure the inputs, then
    # reset the outputs, then swap the visible example row.
    _inputs_reconfigured = task_name.change(
        fn=update_input_ui,
        inputs=task_name,
        outputs=[audio_box, input_text, source_language, target_language],
        queue=False,
        api_name=False,
    )
    _outputs_reset = _inputs_reconfigured.then(
        fn=update_output_ui,
        inputs=task_name,
        outputs=[output_audio, output_text],
        queue=False,
        api_name=False,
    )
    _outputs_reset.then(
        fn=update_example_ui,
        inputs=task_name,
        outputs=[
            s2st_example_row,
            s2tt_example_row,
            t2st_example_row,
            t2tt_example_row,
            asr_example_row,
        ],
        queue=False,
        api_name=False,
    )

    # The Translate button is also exposed as the "run" API endpoint.
    btn.click(
        fn=predict,
        inputs=[
            task_name,
            audio_source,
            input_audio_mic,
            input_audio_file,
            input_text,
            source_language,
            target_language,
        ],
        outputs=[output_audio, output_text],
        api_name="run",
    )
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start serving it.
    queued_app = demo.queue()
    queued_app.launch()
|