app.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

from __future__ import annotations

import gradio as gr
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from seamless_communication.models.inference.translator import Translator

DESCRIPTION = """# SeamlessM4T
[SeamlessM4T](https://github.com/facebookresearch/seamless_communication) is designed to provide high-quality
translation, allowing people from different linguistic communities to communicate effortlessly through speech and text.
This unified model enables multiple tasks like Speech-to-Speech (S2ST), Speech-to-Text (S2TT), Text-to-Speech (T2ST)
translation and more, without relying on multiple separate models.
"""

TASK_NAMES = [
    "S2ST (Speech to Speech translation)",
    "S2TT (Speech to Text translation)",
    "T2ST (Text to Speech translation)",
    "T2TT (Text to Text translation)",
    "ASR (Automatic Speech Recognition)",
]
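
# predict() recovers the short task code from these labels via
# task_name.split()[0], so the first token of each entry must stay the code.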

# Language dict
language_code_to_name = {
    "afr": "Afrikaans",
    "amh": "Amharic",
    "arb": "Modern Standard Arabic",
    "ary": "Moroccan Arabic",
    "arz": "Egyptian Arabic",
    "asm": "Assamese",
    "ast": "Asturian",
    "azj": "North Azerbaijani",
    "bel": "Belarusian",
    "ben": "Bengali",
    "bos": "Bosnian",
    "bul": "Bulgarian",
    "cat": "Catalan",
    "ceb": "Cebuano",
    "ces": "Czech",
    "ckb": "Central Kurdish",
    "cmn": "Mandarin Chinese",
    "cym": "Welsh",
    "dan": "Danish",
    "deu": "German",
    "ell": "Greek",
    "eng": "English",
    "est": "Estonian",
    "eus": "Basque",
    "fin": "Finnish",
    "fra": "French",
    "gaz": "West Central Oromo",
    "gle": "Irish",
    "glg": "Galician",
    "guj": "Gujarati",
    "heb": "Hebrew",
    "hin": "Hindi",
    "hrv": "Croatian",
    "hun": "Hungarian",
    "hye": "Armenian",
    "ibo": "Igbo",
    "ind": "Indonesian",
    "isl": "Icelandic",
    "ita": "Italian",
    "jav": "Javanese",
    "jpn": "Japanese",
    "kam": "Kamba",
    "kan": "Kannada",
    "kat": "Georgian",
    "kaz": "Kazakh",
    "kea": "Kabuverdianu",
    "khk": "Halh Mongolian",
    "khm": "Khmer",
    "kir": "Kyrgyz",
    "kor": "Korean",
    "lao": "Lao",
    "lit": "Lithuanian",
    "ltz": "Luxembourgish",
    "lug": "Ganda",
    "luo": "Luo",
    "lvs": "Standard Latvian",
    "mai": "Maithili",
    "mal": "Malayalam",
    "mar": "Marathi",
    "mkd": "Macedonian",
    "mlt": "Maltese",
    "mni": "Meitei",
    "mya": "Burmese",
    "nld": "Dutch",
    "nno": "Norwegian Nynorsk",
    "nob": "Norwegian Bokmål",
    "npi": "Nepali",
    "nya": "Nyanja",
    "oci": "Occitan",
    "ory": "Odia",
    "pan": "Punjabi",
    "pbt": "Southern Pashto",
    "pes": "Western Persian",
    "pol": "Polish",
    "por": "Portuguese",
    "ron": "Romanian",
    "rus": "Russian",
    "slk": "Slovak",
    "slv": "Slovenian",
    "sna": "Shona",
    "snd": "Sindhi",
    "som": "Somali",
    "spa": "Spanish",
    "srp": "Serbian",
    "swe": "Swedish",
    "swh": "Swahili",
    "tam": "Tamil",
    "tel": "Telugu",
    "tgk": "Tajik",
    "tgl": "Tagalog",
    "tha": "Thai",
    "tur": "Turkish",
    "ukr": "Ukrainian",
    "urd": "Urdu",
    "uzn": "Northern Uzbek",
    "vie": "Vietnamese",
    "xho": "Xhosa",
    "yor": "Yoruba",
    "yue": "Cantonese",
    "zlm": "Colloquial Malay",
    "zsm": "Standard Malay",
    "zul": "Zulu",
}
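
# Note: the reverse mapping below assumes every display name is unique;
# a duplicate name would silently shadow another language code.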

LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}

# Source langs: S2ST / S2TT / ASR don't need source lang
# T2TT / T2ST use this
text_source_language_codes = [
    "afr",
    "amh",
    "arb",
    "ary",
    "arz",
    "asm",
    "azj",
    "bel",
    "ben",
    "bos",
    "bul",
    "cat",
    "ceb",
    "ces",
    "ckb",
    "cmn",
    "cym",
    "dan",
    "deu",
    "ell",
    "eng",
    "est",
    "eus",
    "fin",
    "fra",
    "gaz",
    "gle",
    "glg",
    "guj",
    "heb",
    "hin",
    "hrv",
    "hun",
    "hye",
    "ibo",
    "ind",
    "isl",
    "ita",
    "jav",
    "jpn",
    "kan",
    "kat",
    "kaz",
    "khk",
    "khm",
    "kir",
    "kor",
    "lao",
    "lit",
    "lug",
    "luo",
    "lvs",
    "mai",
    "mal",
    "mar",
    "mkd",
    "mlt",
    "mni",
    "mya",
    "nld",
    "nno",
    "nob",
    "npi",
    "nya",
    "ory",
    "pan",
    "pbt",
    "pes",
    "pol",
    "por",
    "ron",
    "rus",
    "slk",
    "slv",
    "sna",
    "snd",
    "som",
    "spa",
    "srp",
    "swe",
    "swh",
    "tam",
    "tel",
    "tgk",
    "tgl",
    "tha",
    "tur",
    "ukr",
    "urd",
    "uzn",
    "vie",
    "yor",
    "yue",
    "zsm",
    "zul",
]
TEXT_SOURCE_LANGUAGE_NAMES = sorted(
    [language_code_to_name[code] for code in text_source_language_codes]
)

# Target langs:
# S2ST / T2ST
s2st_target_language_codes = [
    "eng",
    "arb",
    "ben",
    "cat",
    "ces",
    "cmn",
    "cym",
    "dan",
    "deu",
    "est",
    "fin",
    "fra",
    "hin",
    "ind",
    "ita",
    "jpn",
    "kor",
    "mlt",
    "nld",
    "pes",
    "pol",
    "por",
    "ron",
    "rus",
    "slk",
    "spa",
    "swe",
    "swh",
    "tel",
    "tgl",
    "tha",
    "tur",
    "ukr",
    "urd",
    "uzn",
    "vie",
]
S2ST_TARGET_LANGUAGE_NAMES = sorted(
    [language_code_to_name[code] for code in s2st_target_language_codes]
)

# S2TT / ASR
S2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
# T2TT
T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES

# Download sample input audio files
filenames = ["assets/sample_input.mp3", "assets/sample_input_2.mp3"]
for filename in filenames:
    hf_hub_download(
        repo_id="facebook/seamless_m4t",
        repo_type="space",
        filename=filename,
        local_dir=".",
    )
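
# With local_dir=".", hf_hub_download mirrors the in-repo path, so both
# samples land under ./assets/ and the example rows below can reference
# them by the same relative paths.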

AUDIO_SAMPLE_RATE = 16000.0
MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
DEFAULT_TARGET_LANGUAGE = "French"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
translator = Translator(
    model_name_or_card="seamlessM4T_large",
    vocoder_name_or_card="vocoder_36langs",
    device=device,
    dtype=torch.float16 if "cuda" in device.type else torch.float32,
)
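
# A minimal sketch of calling this Translator directly, outside the UI
# (same predict() signature as used below; the concrete inputs here are
# illustrative assumptions, not part of the app):
#
#   text_out, speech_out = translator.predict(
#       "Hello, world!", "T2TT", "fra", src_lang="eng"
#   )
#   print(str(text_out[0]))  # translated text; speech_out is None for T2TT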

def predict(
    task_name: str,
    audio_source: str,
    input_audio_mic: str | None,
    input_audio_file: str | None,
    input_text: str | None,
    source_language: str | None,
    target_language: str,
) -> tuple[tuple[int, np.ndarray] | None, str]:
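    """Run a single request and return (audio, text).

    `audio` is a (sample_rate, waveform) pair for the speech-output tasks
    (S2ST / T2ST) and None otherwise, so one callback can feed both Gradio
    outputs.
    """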
    task_name = task_name.split()[0]
    source_language_code = (
        LANGUAGE_NAME_TO_CODE[source_language] if source_language else None
    )
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]

    if task_name in ["S2ST", "S2TT", "ASR"]:
        if audio_source == "microphone":
            input_data = input_audio_mic
        else:
            input_data = input_audio_file
        # Resample to the 16 kHz the model expects and truncate overly long
        # clips in place before handing the file path to the translator.
        arr, org_sr = torchaudio.load(input_data)
        new_arr = torchaudio.functional.resample(
            arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE
        )
        max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
        if new_arr.shape[1] > max_length:
            new_arr = new_arr[:, :max_length]
            gr.Warning(
                f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds are used."
            )
        torchaudio.save(input_data, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))
    else:
        input_data = input_text
    assert input_data is not None

    text_output, speech_output = translator.predict(
        input_data,
        task_name,
        target_language_code,
        src_lang=source_language_code,
        unit_generation_ngram_filtering=True,
    )
    if task_name in ["S2ST", "T2ST"]:
        assert speech_output is not None
        return (
            speech_output.sample_rate,
            speech_output.audio_wavs[0].cpu().detach().numpy(),
        ), str(text_output[0])
    else:
        return None, str(text_output[0])
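

# gr.Examples only supplies the columns of the clicked example row, so each
# task gets a small wrapper that pins the remaining predict() arguments.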
def process_s2st_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    return predict(
        task_name="S2ST",
        audio_source="file",
        input_audio_mic=None,
        input_audio_file=input_audio_file,
        input_text=None,
        source_language=None,
        target_language=target_language,
    )


def process_s2tt_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    return predict(
        task_name="S2TT",
        audio_source="file",
        input_audio_mic=None,
        input_audio_file=input_audio_file,
        input_text=None,
        source_language=None,
        target_language=target_language,
    )


def process_t2st_example(
    input_text: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    return predict(
        task_name="T2ST",
        audio_source="",
        input_audio_mic=None,
        input_audio_file=None,
        input_text=input_text,
        source_language=source_language,
        target_language=target_language,
    )


def process_t2tt_example(
    input_text: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    return predict(
        task_name="T2TT",
        audio_source="",
        input_audio_mic=None,
        input_audio_file=None,
        input_text=input_text,
        source_language=source_language,
        target_language=target_language,
    )


def process_asr_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    return predict(
        task_name="ASR",
        audio_source="file",
        input_audio_mic=None,
        input_audio_file=input_audio_file,
        input_text=None,
        source_language=None,
        target_language=target_language,
    )
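

# The update_* callbacks below return gr.update(...) dicts; in the Gradio 3.x
# API this file targets, that patches properties (visibility, choices, value)
# of the existing components instead of recreating them.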
def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
    mic = audio_source == "microphone"
    return (
        gr.update(visible=mic, value=None),  # input_audio_mic
        gr.update(visible=not mic, value=None),  # input_audio_file
    )


def update_input_ui(task_name: str) -> tuple[dict, dict, dict, dict]:
    task_name = task_name.split()[0]
    if task_name == "S2ST":
        return (
            gr.update(visible=True),  # audio_box
            gr.update(visible=False),  # input_text
            gr.update(visible=False),  # source_language
            gr.update(
                visible=True,
                choices=S2ST_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            ),  # target_language
        )
    elif task_name == "S2TT":
        return (
            gr.update(visible=True),  # audio_box
            gr.update(visible=False),  # input_text
            gr.update(visible=False),  # source_language
            gr.update(
                visible=True,
                choices=S2TT_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            ),  # target_language
        )
    elif task_name == "T2ST":
        return (
            gr.update(visible=False),  # audio_box
            gr.update(visible=True),  # input_text
            gr.update(visible=True),  # source_language
            gr.update(
                visible=True,
                choices=S2ST_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            ),  # target_language
        )
    elif task_name == "T2TT":
        return (
            gr.update(visible=False),  # audio_box
            gr.update(visible=True),  # input_text
            gr.update(visible=True),  # source_language
            gr.update(
                visible=True,
                choices=T2TT_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            ),  # target_language
        )
    elif task_name == "ASR":
        return (
            gr.update(visible=True),  # audio_box
            gr.update(visible=False),  # input_text
            gr.update(visible=False),  # source_language
            gr.update(
                visible=True,
                choices=S2TT_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            ),  # target_language
        )
    else:
        raise ValueError(f"Unknown task: {task_name}")


def update_output_ui(task_name: str) -> tuple[dict, dict]:
    task_name = task_name.split()[0]
    if task_name in ["S2ST", "T2ST"]:
        return (
            gr.update(visible=True, value=None),  # output_audio
            gr.update(value=None),  # output_text
        )
    elif task_name in ["S2TT", "T2TT", "ASR"]:
        return (
            gr.update(visible=False, value=None),  # output_audio
            gr.update(value=None),  # output_text
        )
    else:
        raise ValueError(f"Unknown task: {task_name}")


def update_example_ui(task_name: str) -> tuple[dict, dict, dict, dict, dict]:
    task_name = task_name.split()[0]
    return (
        gr.update(visible=task_name == "S2ST"),  # s2st_example_row
        gr.update(visible=task_name == "S2TT"),  # s2tt_example_row
        gr.update(visible=task_name == "T2ST"),  # t2st_example_row
        gr.update(visible=task_name == "T2TT"),  # t2tt_example_row
        gr.update(visible=task_name == "ASR"),  # asr_example_row
    )


css = """
h1 {
  text-align: center;
}

.contain {
  max-width: 730px;
  margin: auto;
  padding-top: 1.5rem;
}
"""
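
# `h1` centers the Markdown title from DESCRIPTION; `.contain` targets
# Gradio's main content wrapper (a Gradio 3.x class name), centering the
# app at a fixed max width.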

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        task_name = gr.Dropdown(
            label="Task",
            choices=TASK_NAMES,
            value=TASK_NAMES[0],
        )
        with gr.Row():
            source_language = gr.Dropdown(
                label="Source language",
                choices=TEXT_SOURCE_LANGUAGE_NAMES,
                value="English",
                visible=False,
            )
            target_language = gr.Dropdown(
                label="Target language",
                choices=S2ST_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            )
        with gr.Row() as audio_box:
            audio_source = gr.Radio(
                label="Audio source",
                choices=["file", "microphone"],
                value="file",
            )
            input_audio_mic = gr.Audio(
                label="Input speech",
                type="filepath",
                source="microphone",
                visible=False,
            )
            input_audio_file = gr.Audio(
                label="Input speech",
                type="filepath",
                source="upload",
                visible=True,
            )
        input_text = gr.Textbox(label="Input text", visible=False)
        btn = gr.Button("Translate")
        with gr.Column():
            output_audio = gr.Audio(
                label="Translated speech",
                autoplay=False,
                streaming=False,
                type="numpy",
            )
            output_text = gr.Textbox(label="Translated text")

    with gr.Row(visible=True) as s2st_example_row:
        s2st_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "French"],
                ["assets/sample_input.mp3", "Mandarin Chinese"],
                ["assets/sample_input_2.mp3", "Hindi"],
                ["assets/sample_input_2.mp3", "Spanish"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_s2st_example,
        )
    with gr.Row(visible=False) as s2tt_example_row:
        s2tt_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "French"],
                ["assets/sample_input.mp3", "Mandarin Chinese"],
                ["assets/sample_input_2.mp3", "Hindi"],
                ["assets/sample_input_2.mp3", "Spanish"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_s2tt_example,
        )
    with gr.Row(visible=False) as t2st_example_row:
        t2st_examples = gr.Examples(
            examples=[
                ["My favorite animal is the elephant.", "English", "French"],
                ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Hindi",
                ],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Spanish",
                ],
            ],
            inputs=[input_text, source_language, target_language],
            outputs=[output_audio, output_text],
            fn=process_t2st_example,
        )
    with gr.Row(visible=False) as t2tt_example_row:
        t2tt_examples = gr.Examples(
            examples=[
                ["My favorite animal is the elephant.", "English", "French"],
                ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Hindi",
                ],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Spanish",
                ],
            ],
            inputs=[input_text, source_language, target_language],
            outputs=[output_audio, output_text],
            fn=process_t2tt_example,
        )
    with gr.Row(visible=False) as asr_example_row:
        asr_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "English"],
                ["assets/sample_input_2.mp3", "English"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_asr_example,
        )

    audio_source.change(
        fn=update_audio_ui,
        inputs=audio_source,
        outputs=[
            input_audio_mic,
            input_audio_file,
        ],
        queue=False,
        api_name=False,
    )
    task_name.change(
        fn=update_input_ui,
        inputs=task_name,
        outputs=[
            audio_box,
            input_text,
            source_language,
            target_language,
        ],
        queue=False,
        api_name=False,
    ).then(
        fn=update_output_ui,
        inputs=task_name,
        outputs=[output_audio, output_text],
        queue=False,
        api_name=False,
    ).then(
        fn=update_example_ui,
        inputs=task_name,
        outputs=[
            s2st_example_row,
            s2tt_example_row,
            t2st_example_row,
            t2tt_example_row,
            asr_example_row,
        ],
        queue=False,
        api_name=False,
    )
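    # The .then() chain above applies the three UI updates in a fixed order
    # on every task switch; queue=False keeps these pure UI callbacks off the
    # inference queue, and api_name=False hides them from the generated API.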

    btn.click(
        fn=predict,
        inputs=[
            task_name,
            audio_source,
            input_audio_mic,
            input_audio_file,
            input_text,
            source_language,
            target_language,
        ],
        outputs=[output_audio, output_text],
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue().launch()
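
# A minimal client-side sketch for the "run" endpoint exposed above, assuming
# the app is served locally on Gradio's default port (gradio_client usage is
# an illustration, not part of this file):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   audio, text = client.predict(
#       "T2TT (Text to Text translation)",  # task_name
#       "",                                 # audio_source (unused for T2TT)
#       None, None,                         # input_audio_mic / input_audio_file
#       "Hello, world!",                    # input_text
#       "English", "French",                # source / target language
#       api_name="/run",
#   )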