  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from __future__ import annotations
  7. import gradio as gr
  8. import numpy as np
  9. import torch
  10. import torchaudio
  11. from huggingface_hub import hf_hub_download
  12. from seamless_communication.models.inference.translator import Translator
# Markdown rendered at the top of the demo page.
DESCRIPTION = """# SeamlessM4T
[SeamlessM4T](https://github.com/facebookresearch/seamless_communication) is designed to provide high-quality
translation, allowing people from different linguistic communities to communicate effortlessly through speech and text.
This unified model enables multiple tasks like Speech-to-Speech (S2ST), Speech-to-Text (S2TT), Text-to-Speech (T2ST)
translation and more, without relying on multiple separate models.
"""

# Task labels shown in the task dropdown. predict() and the update_* UI
# callbacks key off the leading acronym (text before the first space).
TASK_NAMES = [
    "S2ST (Speech to Speech translation)",
    "S2TT (Speech to Text translation)",
    "T2ST (Text to Speech translation)",
    "T2TT (Text to Text translation)",
    "ASR (Automatic Speech Recognition)",
]
# Language dict
# Maps SeamlessM4T three-letter language codes to human-readable names;
# the display names are what appear in the UI dropdowns.
language_code_to_name = {
    "afr": "Afrikaans",
    "amh": "Amharic",
    "arb": "Modern Standard Arabic",
    "ary": "Moroccan Arabic",
    "arz": "Egyptian Arabic",
    "asm": "Assamese",
    "ast": "Asturian",
    "azj": "North Azerbaijani",
    "bel": "Belarusian",
    "ben": "Bengali",
    "bos": "Bosnian",
    "bul": "Bulgarian",
    "cat": "Catalan",
    "ceb": "Cebuano",
    "ces": "Czech",
    "ckb": "Central Kurdish",
    "cmn": "Mandarin Chinese",
    "cym": "Welsh",
    "dan": "Danish",
    "deu": "German",
    "ell": "Greek",
    "eng": "English",
    "est": "Estonian",
    "eus": "Basque",
    "fin": "Finnish",
    "fra": "French",
    "gaz": "West Central Oromo",
    "gle": "Irish",
    "glg": "Galician",
    "guj": "Gujarati",
    "heb": "Hebrew",
    "hin": "Hindi",
    "hrv": "Croatian",
    "hun": "Hungarian",
    "hye": "Armenian",
    "ibo": "Igbo",
    "ind": "Indonesian",
    "isl": "Icelandic",
    "ita": "Italian",
    "jav": "Javanese",
    "jpn": "Japanese",
    "kam": "Kamba",
    "kan": "Kannada",
    "kat": "Georgian",
    "kaz": "Kazakh",
    "kea": "Kabuverdianu",
    "khk": "Halh Mongolian",
    "khm": "Khmer",
    "kir": "Kyrgyz",
    "kor": "Korean",
    "lao": "Lao",
    "lit": "Lithuanian",
    "ltz": "Luxembourgish",
    "lug": "Ganda",
    "luo": "Luo",
    "lvs": "Standard Latvian",
    "mai": "Maithili",
    "mal": "Malayalam",
    "mar": "Marathi",
    "mkd": "Macedonian",
    "mlt": "Maltese",
    "mni": "Meitei",
    "mya": "Burmese",
    "nld": "Dutch",
    "nno": "Norwegian Nynorsk",
    "nob": "Norwegian Bokm\u00e5l",
    "npi": "Nepali",
    "nya": "Nyanja",
    "oci": "Occitan",
    "ory": "Odia",
    "pan": "Punjabi",
    "pbt": "Southern Pashto",
    "pes": "Western Persian",
    "pol": "Polish",
    "por": "Portuguese",
    "ron": "Romanian",
    "rus": "Russian",
    "slk": "Slovak",
    "slv": "Slovenian",
    "sna": "Shona",
    "snd": "Sindhi",
    "som": "Somali",
    "spa": "Spanish",
    "srp": "Serbian",
    "swe": "Swedish",
    "swh": "Swahili",
    "tam": "Tamil",
    "tel": "Telugu",
    "tgk": "Tajik",
    "tgl": "Tagalog",
    "tha": "Thai",
    "tur": "Turkish",
    "ukr": "Ukrainian",
    "urd": "Urdu",
    "uzn": "Northern Uzbek",
    "vie": "Vietnamese",
    "xho": "Xhosa",
    "yor": "Yoruba",
    "yue": "Cantonese",
    "zlm": "Colloquial Malay",
    "zsm": "Standard Malay",
    "zul": "Zulu",
}
# Reverse lookup: display name -> language code, used by predict() to turn
# the dropdown selections back into model codes.
LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}
# Source langs: S2ST / S2TT / ASR don't need source lang
# T2TT / T2ST use this
# NOTE(review): a subset of language_code_to_name — presumably the languages
# the model accepts as *text* input; confirm against the model card.
text_source_language_codes = [
    "afr",
    "amh",
    "arb",
    "ary",
    "arz",
    "asm",
    "azj",
    "bel",
    "ben",
    "bos",
    "bul",
    "cat",
    "ceb",
    "ces",
    "ckb",
    "cmn",
    "cym",
    "dan",
    "deu",
    "ell",
    "eng",
    "est",
    "eus",
    "fin",
    "fra",
    "gaz",
    "gle",
    "glg",
    "guj",
    "heb",
    "hin",
    "hrv",
    "hun",
    "hye",
    "ibo",
    "ind",
    "isl",
    "ita",
    "jav",
    "jpn",
    "kan",
    "kat",
    "kaz",
    "khk",
    "khm",
    "kir",
    "kor",
    "lao",
    "lit",
    "lug",
    "luo",
    "lvs",
    "mai",
    "mal",
    "mar",
    "mkd",
    "mlt",
    "mni",
    "mya",
    "nld",
    "nno",
    "nob",
    "npi",
    "nya",
    "ory",
    "pan",
    "pbt",
    "pes",
    "pol",
    "por",
    "ron",
    "rus",
    "slk",
    "slv",
    "sna",
    "snd",
    "som",
    "spa",
    "srp",
    "swe",
    "swh",
    "tam",
    "tel",
    "tgk",
    "tgl",
    "tha",
    "tur",
    "ukr",
    "urd",
    "uzn",
    "vie",
    "yor",
    "yue",
    "zsm",
    "zul",
]
# Alphabetically sorted display names for the source-language dropdown.
TEXT_SOURCE_LANGUAGE_NAMES = sorted(
    [language_code_to_name[code] for code in text_source_language_codes]
)
# Target langs:
# S2ST / T2ST
# Languages the vocoder can synthesize speech for — a smaller set than the
# text targets.
s2st_target_language_codes = [
    "eng",
    "arb",
    "ben",
    "cat",
    "ces",
    "cmn",
    "cym",
    "dan",
    "deu",
    "est",
    "fin",
    "fra",
    "hin",
    "ind",
    "ita",
    "jpn",
    "kor",
    "mlt",
    "nld",
    "pes",
    "pol",
    "por",
    "ron",
    "rus",
    "slk",
    "spa",
    "swe",
    "swh",
    "tel",
    "tgl",
    "tha",
    "tur",
    "ukr",
    "urd",
    "uzn",
    "vie",
]
S2ST_TARGET_LANGUAGE_NAMES = sorted(
    [language_code_to_name[code] for code in s2st_target_language_codes]
)
# S2TT / ASR
# Text-output tasks can target any text language.
S2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
# T2TT
T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
# Download sample input audio files
# Fetched from the HF Space repo into the working dir so the Examples rows
# can reference them by relative path.
filenames = ["assets/sample_input.mp3", "assets/sample_input_2.mp3"]
for filename in filenames:
    hf_hub_download(
        repo_id="facebook/seamless_m4t",
        repo_type="space",
        filename=filename,
        local_dir=".",
    )

AUDIO_SAMPLE_RATE = 16000.0  # sample rate (Hz) audio is resampled to before inference
MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
DEFAULT_TARGET_LANGUAGE = "French"

# Use the first GPU when available; half precision on CUDA to save memory,
# full precision on CPU (fp16 is slow/unsupported on most CPUs).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
translator = Translator(
    model_name_or_card="seamlessM4T_large",
    vocoder_name_or_card="vocoder_36langs",
    device=device,
    dtype=torch.float16 if "cuda" in device.type else torch.float32,
)
  300. def predict(
  301. task_name: str,
  302. audio_source: str,
  303. input_audio_mic: str | None,
  304. input_audio_file: str | None,
  305. input_text: str | None,
  306. source_language: str | None,
  307. target_language: str,
  308. ) -> tuple[tuple[int, np.ndarray] | None, str]:
  309. task_name = task_name.split()[0]
  310. source_language_code = (
  311. LANGUAGE_NAME_TO_CODE[source_language] if source_language else None
  312. )
  313. target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
  314. if task_name in ["S2ST", "S2TT", "ASR"]:
  315. if audio_source == "microphone":
  316. input_data = input_audio_mic
  317. else:
  318. input_data = input_audio_file
  319. arr, org_sr = torchaudio.load(input_data)
  320. new_arr = torchaudio.functional.resample(
  321. arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE
  322. )
  323. max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
  324. if new_arr.shape[1] > max_length:
  325. new_arr = new_arr[:, :max_length]
  326. gr.Warning(
  327. f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used."
  328. )
  329. torchaudio.save(input_data, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))
  330. else:
  331. input_data = input_text
  332. text_out, wav, sr = translator.predict(
  333. input=input_data,
  334. task_str=task_name,
  335. tgt_lang=target_language_code,
  336. src_lang=source_language_code,
  337. ngram_filtering=True,
  338. )
  339. if task_name in ["S2ST", "T2ST"]:
  340. return (sr, wav.cpu().detach().numpy()), text_out
  341. else:
  342. return None, text_out
  343. def process_s2st_example(
  344. input_audio_file: str, target_language: str
  345. ) -> tuple[tuple[int, np.ndarray] | None, str]:
  346. return predict(
  347. task_name="S2ST",
  348. audio_source="file",
  349. input_audio_mic=None,
  350. input_audio_file=input_audio_file,
  351. input_text=None,
  352. source_language=None,
  353. target_language=target_language,
  354. )
  355. def process_s2tt_example(
  356. input_audio_file: str, target_language: str
  357. ) -> tuple[tuple[int, np.ndarray] | None, str]:
  358. return predict(
  359. task_name="S2TT",
  360. audio_source="file",
  361. input_audio_mic=None,
  362. input_audio_file=input_audio_file,
  363. input_text=None,
  364. source_language=None,
  365. target_language=target_language,
  366. )
  367. def process_t2st_example(
  368. input_text: str, source_language: str, target_language: str
  369. ) -> tuple[tuple[int, np.ndarray] | None, str]:
  370. return predict(
  371. task_name="T2ST",
  372. audio_source="",
  373. input_audio_mic=None,
  374. input_audio_file=None,
  375. input_text=input_text,
  376. source_language=source_language,
  377. target_language=target_language,
  378. )
  379. def process_t2tt_example(
  380. input_text: str, source_language: str, target_language: str
  381. ) -> tuple[tuple[int, np.ndarray] | None, str]:
  382. return predict(
  383. task_name="T2TT",
  384. audio_source="",
  385. input_audio_mic=None,
  386. input_audio_file=None,
  387. input_text=input_text,
  388. source_language=source_language,
  389. target_language=target_language,
  390. )
  391. def process_asr_example(
  392. input_audio_file: str, target_language: str
  393. ) -> tuple[tuple[int, np.ndarray] | None, str]:
  394. return predict(
  395. task_name="ASR",
  396. audio_source="file",
  397. input_audio_mic=None,
  398. input_audio_file=input_audio_file,
  399. input_text=None,
  400. source_language=None,
  401. target_language=target_language,
  402. )
  403. def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
  404. mic = audio_source == "microphone"
  405. return (
  406. gr.update(visible=mic, value=None), # input_audio_mic
  407. gr.update(visible=not mic, value=None), # input_audio_file
  408. )
  409. def update_input_ui(task_name: str) -> tuple[dict, dict, dict, dict]:
  410. task_name = task_name.split()[0]
  411. if task_name == "S2ST":
  412. return (
  413. gr.update(visible=True), # audio_box
  414. gr.update(visible=False), # input_text
  415. gr.update(visible=False), # source_language
  416. gr.update(
  417. visible=True,
  418. choices=S2ST_TARGET_LANGUAGE_NAMES,
  419. value=DEFAULT_TARGET_LANGUAGE,
  420. ), # target_language
  421. )
  422. elif task_name == "S2TT":
  423. return (
  424. gr.update(visible=True), # audio_box
  425. gr.update(visible=False), # input_text
  426. gr.update(visible=False), # source_language
  427. gr.update(
  428. visible=True,
  429. choices=S2TT_TARGET_LANGUAGE_NAMES,
  430. value=DEFAULT_TARGET_LANGUAGE,
  431. ), # target_language
  432. )
  433. elif task_name == "T2ST":
  434. return (
  435. gr.update(visible=False), # audio_box
  436. gr.update(visible=True), # input_text
  437. gr.update(visible=True), # source_language
  438. gr.update(
  439. visible=True,
  440. choices=S2ST_TARGET_LANGUAGE_NAMES,
  441. value=DEFAULT_TARGET_LANGUAGE,
  442. ), # target_language
  443. )
  444. elif task_name == "T2TT":
  445. return (
  446. gr.update(visible=False), # audio_box
  447. gr.update(visible=True), # input_text
  448. gr.update(visible=True), # source_language
  449. gr.update(
  450. visible=True,
  451. choices=T2TT_TARGET_LANGUAGE_NAMES,
  452. value=DEFAULT_TARGET_LANGUAGE,
  453. ), # target_language
  454. )
  455. elif task_name == "ASR":
  456. return (
  457. gr.update(visible=True), # audio_box
  458. gr.update(visible=False), # input_text
  459. gr.update(visible=False), # source_language
  460. gr.update(
  461. visible=True,
  462. choices=S2TT_TARGET_LANGUAGE_NAMES,
  463. value=DEFAULT_TARGET_LANGUAGE,
  464. ), # target_language
  465. )
  466. else:
  467. raise ValueError(f"Unknown task: {task_name}")
  468. def update_output_ui(task_name: str) -> tuple[dict, dict]:
  469. task_name = task_name.split()[0]
  470. if task_name in ["S2ST", "T2ST"]:
  471. return (
  472. gr.update(visible=True, value=None), # output_audio
  473. gr.update(value=None), # output_text
  474. )
  475. elif task_name in ["S2TT", "T2TT", "ASR"]:
  476. return (
  477. gr.update(visible=False, value=None), # output_audio
  478. gr.update(value=None), # output_text
  479. )
  480. else:
  481. raise ValueError(f"Unknown task: {task_name}")
  482. def update_example_ui(task_name: str) -> tuple[dict, dict, dict, dict, dict]:
  483. task_name = task_name.split()[0]
  484. return (
  485. gr.update(visible=task_name == "S2ST"), # s2st_example_row
  486. gr.update(visible=task_name == "S2TT"), # s2tt_example_row
  487. gr.update(visible=task_name == "T2ST"), # t2st_example_row
  488. gr.update(visible=task_name == "T2TT"), # t2tt_example_row
  489. gr.update(visible=task_name == "ASR"), # asr_example_row
  490. )
# Page-level CSS: center the title and constrain the content column width.
css = """
h1 {
  text-align: center;
}
.contain {
  max-width: 730px;
  margin: auto;
  padding-top: 1.5rem;
}
"""
# Assemble the Gradio UI: task selector, language dropdowns, audio/text
# inputs, outputs, one example row per task, and the event wiring that keeps
# widget visibility in sync with the selected task.
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        # Task selector; drives the visibility of every widget below via the
        # update_* callbacks wired further down.
        task_name = gr.Dropdown(
            label="Task",
            choices=TASK_NAMES,
            value=TASK_NAMES[0],
        )
        with gr.Row():
            # Source language is only shown for text tasks (T2ST / T2TT).
            source_language = gr.Dropdown(
                label="Source language",
                choices=TEXT_SOURCE_LANGUAGE_NAMES,
                value="English",
                visible=False,
            )
            target_language = gr.Dropdown(
                label="Target language",
                choices=S2ST_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            )
        with gr.Row() as audio_box:
            # Audio input: upload or microphone; update_audio_ui toggles
            # which of the two Audio widgets is visible.
            audio_source = gr.Radio(
                label="Audio source",
                choices=["file", "microphone"],
                value="file",
            )
            input_audio_mic = gr.Audio(
                label="Input speech",
                type="filepath",
                source="microphone",
                visible=False,
            )
            input_audio_file = gr.Audio(
                label="Input speech",
                type="filepath",
                source="upload",
                visible=True,
            )
        input_text = gr.Textbox(label="Input text", visible=False)
        btn = gr.Button("Translate")
        with gr.Column():
            output_audio = gr.Audio(
                label="Translated speech",
                autoplay=False,
                streaming=False,
                type="numpy",
            )
            output_text = gr.Textbox(label="Translated text")
    # One example row per task; update_example_ui shows only the active one.
    with gr.Row(visible=True) as s2st_example_row:
        s2st_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "French"],
                ["assets/sample_input.mp3", "Mandarin Chinese"],
                ["assets/sample_input_2.mp3", "Hindi"],
                ["assets/sample_input_2.mp3", "Spanish"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_s2st_example,
        )
    with gr.Row(visible=False) as s2tt_example_row:
        s2tt_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "French"],
                ["assets/sample_input.mp3", "Mandarin Chinese"],
                ["assets/sample_input_2.mp3", "Hindi"],
                ["assets/sample_input_2.mp3", "Spanish"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_s2tt_example,
        )
    with gr.Row(visible=False) as t2st_example_row:
        t2st_examples = gr.Examples(
            examples=[
                ["My favorite animal is the elephant.", "English", "French"],
                ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Hindi",
                ],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Spanish",
                ],
            ],
            inputs=[input_text, source_language, target_language],
            outputs=[output_audio, output_text],
            fn=process_t2st_example,
        )
    with gr.Row(visible=False) as t2tt_example_row:
        t2tt_examples = gr.Examples(
            examples=[
                ["My favorite animal is the elephant.", "English", "French"],
                ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Hindi",
                ],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Spanish",
                ],
            ],
            inputs=[input_text, source_language, target_language],
            outputs=[output_audio, output_text],
            fn=process_t2tt_example,
        )
    with gr.Row(visible=False) as asr_example_row:
        asr_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "English"],
                ["assets/sample_input_2.mp3", "English"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_asr_example,
        )
    # Swap mic/file widgets when the audio source changes.
    audio_source.change(
        fn=update_audio_ui,
        inputs=audio_source,
        outputs=[
            input_audio_mic,
            input_audio_file,
        ],
        queue=False,
        api_name=False,
    )
    # Changing the task updates inputs, then outputs, then example rows —
    # chained with .then() so each step runs after the previous one.
    task_name.change(
        fn=update_input_ui,
        inputs=task_name,
        outputs=[
            audio_box,
            input_text,
            source_language,
            target_language,
        ],
        queue=False,
        api_name=False,
    ).then(
        fn=update_output_ui,
        inputs=task_name,
        outputs=[output_audio, output_text],
        queue=False,
        api_name=False,
    ).then(
        fn=update_example_ui,
        inputs=task_name,
        outputs=[
            s2st_example_row,
            s2tt_example_row,
            t2st_example_row,
            t2tt_example_row,
            asr_example_row,
        ],
        queue=False,
        api_name=False,
    )
    # Main inference entry point; exposed in the auto-generated API as "run".
    btn.click(
        fn=predict,
        inputs=[
            task_name,
            audio_source,
            input_audio_mic,
            input_audio_file,
            input_text,
            source_language,
            target_language,
        ],
        outputs=[output_audio, output_text],
        api_name="run",
    )

if __name__ == "__main__":
    # Enable the request queue (long-running inference) and start the server.
    demo.queue().launch()