from __future__ import annotations

import gradio as gr
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from seamless_communication.models.inference.translator import Translator

DESCRIPTION = """# SeamlessM4T

[SeamlessM4T](https://github.com/facebookresearch/seamless_communication) is designed to provide high-quality
translation, allowing people from different linguistic communities to communicate effortlessly through speech and text.

This unified model covers Speech-to-Speech translation (S2ST), Speech-to-Text
translation (S2TT), Text-to-Speech translation (T2ST), Text-to-Text translation
(T2TT), and Automatic Speech Recognition (ASR), without relying on multiple
separate models.
"""

TASK_NAMES = [
    "S2ST (Speech to Speech translation)",
    "S2TT (Speech to Text translation)",
    "T2ST (Text to Speech translation)",
    "T2TT (Text to Text translation)",
    "ASR (Automatic Speech Recognition)",
]
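# predict() recovers the short task code by splitting each label on its first
# space, so the leading token ("S2ST", "S2TT", ...) must stay intact.
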
# Language dict
language_code_to_name = {
    "afr": "Afrikaans",
    "amh": "Amharic",
    "arb": "Modern Standard Arabic",
    "ary": "Moroccan Arabic",
    "arz": "Egyptian Arabic",
    "asm": "Assamese",
    "ast": "Asturian",
    "azj": "North Azerbaijani",
    "bel": "Belarusian",
    "ben": "Bengali",
    "bos": "Bosnian",
    "bul": "Bulgarian",
    "cat": "Catalan",
    "ceb": "Cebuano",
    "ces": "Czech",
    "ckb": "Central Kurdish",
    "cmn": "Mandarin Chinese",
    "cym": "Welsh",
    "dan": "Danish",
    "deu": "German",
    "ell": "Greek",
    "eng": "English",
    "est": "Estonian",
    "eus": "Basque",
    "fin": "Finnish",
    "fra": "French",
    "gaz": "West Central Oromo",
    "gle": "Irish",
    "glg": "Galician",
    "guj": "Gujarati",
    "heb": "Hebrew",
    "hin": "Hindi",
    "hrv": "Croatian",
    "hun": "Hungarian",
    "hye": "Armenian",
    "ibo": "Igbo",
    "ind": "Indonesian",
    "isl": "Icelandic",
    "ita": "Italian",
    "jav": "Javanese",
    "jpn": "Japanese",
    "kam": "Kamba",
    "kan": "Kannada",
    "kat": "Georgian",
    "kaz": "Kazakh",
    "kea": "Kabuverdianu",
    "khk": "Halh Mongolian",
    "khm": "Khmer",
    "kir": "Kyrgyz",
    "kor": "Korean",
    "lao": "Lao",
    "lit": "Lithuanian",
    "ltz": "Luxembourgish",
    "lug": "Ganda",
    "luo": "Luo",
    "lvs": "Standard Latvian",
    "mai": "Maithili",
    "mal": "Malayalam",
    "mar": "Marathi",
    "mkd": "Macedonian",
    "mlt": "Maltese",
    "mni": "Meitei",
    "mya": "Burmese",
    "nld": "Dutch",
    "nno": "Norwegian Nynorsk",
    "nob": "Norwegian Bokmål",
    "npi": "Nepali",
    "nya": "Nyanja",
    "oci": "Occitan",
    "ory": "Odia",
    "pan": "Punjabi",
    "pbt": "Southern Pashto",
    "pes": "Western Persian",
    "pol": "Polish",
    "por": "Portuguese",
    "ron": "Romanian",
    "rus": "Russian",
    "slk": "Slovak",
    "slv": "Slovenian",
    "sna": "Shona",
    "snd": "Sindhi",
    "som": "Somali",
    "spa": "Spanish",
    "srp": "Serbian",
    "swe": "Swedish",
    "swh": "Swahili",
    "tam": "Tamil",
    "tel": "Telugu",
    "tgk": "Tajik",
    "tgl": "Tagalog",
    "tha": "Thai",
    "tur": "Turkish",
    "ukr": "Ukrainian",
    "urd": "Urdu",
    "uzn": "Northern Uzbek",
    "vie": "Vietnamese",
    "xho": "Xhosa",
    "yor": "Yoruba",
    "yue": "Cantonese",
    "zlm": "Colloquial Malay",
    "zsm": "Standard Malay",
    "zul": "Zulu",
}
LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}

# Source languages: S2ST / S2TT / ASR don't need a source language;
# T2TT / T2ST pick theirs from this list.
text_source_language_codes = [
    "afr",
    "amh",
    "arb",
    "ary",
    "arz",
    "asm",
    "azj",
    "bel",
    "ben",
    "bos",
    "bul",
    "cat",
    "ceb",
    "ces",
    "ckb",
    "cmn",
    "cym",
    "dan",
    "deu",
    "ell",
    "eng",
    "est",
    "eus",
    "fin",
    "fra",
    "gaz",
    "gle",
    "glg",
    "guj",
    "heb",
    "hin",
    "hrv",
    "hun",
    "hye",
    "ibo",
    "ind",
    "isl",
    "ita",
    "jav",
    "jpn",
    "kan",
    "kat",
    "kaz",
    "khk",
    "khm",
    "kir",
    "kor",
    "lao",
    "lit",
    "lug",
    "luo",
    "lvs",
    "mai",
    "mal",
    "mar",
    "mkd",
    "mlt",
    "mni",
    "mya",
    "nld",
    "nno",
    "nob",
    "npi",
    "nya",
    "ory",
    "pan",
    "pbt",
    "pes",
    "pol",
    "por",
    "ron",
    "rus",
    "slk",
    "slv",
    "sna",
    "snd",
    "som",
    "spa",
    "srp",
    "swe",
    "swh",
    "tam",
    "tel",
    "tgk",
    "tgl",
    "tha",
    "tur",
    "ukr",
    "urd",
    "uzn",
    "vie",
    "yor",
    "yue",
    "zsm",
    "zul",
]
TEXT_SOURCE_LANGUAGE_NAMES = sorted(
    language_code_to_name[code] for code in text_source_language_codes
)

# Target langs:
# S2ST / T2ST
s2st_target_language_codes = [
    "eng",
    "arb",
    "ben",
    "cat",
    "ces",
    "cmn",
    "cym",
    "dan",
    "deu",
    "est",
    "fin",
    "fra",
    "hin",
    "ind",
    "ita",
    "jpn",
    "kor",
    "mlt",
    "nld",
    "pes",
    "pol",
    "por",
    "ron",
    "rus",
    "slk",
    "spa",
    "swe",
    "swh",
    "tel",
    "tgl",
    "tha",
    "tur",
    "ukr",
    "urd",
    "uzn",
    "vie",
]
S2ST_TARGET_LANGUAGE_NAMES = sorted(
    language_code_to_name[code] for code in s2st_target_language_codes
)
# S2TT / ASR
S2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
# T2TT
T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES

# Download sample input audio files
filenames = ["assets/sample_input.mp3", "assets/sample_input_2.mp3"]
for filename in filenames:
    hf_hub_download(
        repo_id="facebook/seamless_m4t",
        repo_type="space",
        filename=filename,
        local_dir=".",
    )

AUDIO_SAMPLE_RATE = 16000.0
MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
DEFAULT_TARGET_LANGUAGE = "French"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
translator = Translator(
    model_name_or_card="seamlessM4T_large",
    vocoder_name_or_card="vocoder_36langs",
    device=device,
    dtype=torch.float16 if "cuda" in device.type else torch.float32,
)

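
# The Translator exposes a single predict() entry point for every task. As a
# rough sketch (illustrative values, not executed here), a direct text-to-text
# call would look like:
#
#     text_out, wav, sr = translator.predict(
#         input="My favorite animal is the elephant.",
#         task_str="T2TT",
#         tgt_lang="fra",
#         src_lang="eng",
#     )
#
# For text-only tasks the speech outputs are expected to be empty; predict()
# below only unpacks (wav, sr) for S2ST / T2ST.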
def predict(
    task_name: str,
    audio_source: str,
    input_audio_mic: str | None,
    input_audio_file: str | None,
    input_text: str | None,
    source_language: str | None,
    target_language: str,
) -> tuple[tuple[int, np.ndarray] | None, str]:
    # Recover the short task code ("S2ST", "S2TT", ...) from the dropdown label.
    task_name = task_name.split()[0]
    source_language_code = (
        LANGUAGE_NAME_TO_CODE[source_language] if source_language else None
    )
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]

    if task_name in ["S2ST", "S2TT", "ASR"]:
        if audio_source == "microphone":
            input_data = input_audio_mic
        else:
            input_data = input_audio_file
        # Resample to the model's expected rate and truncate over-long clips,
        # overwriting the uploaded file in place.
        arr, org_sr = torchaudio.load(input_data)
        new_arr = torchaudio.functional.resample(
            arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE
        )
        max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
        if new_arr.shape[1] > max_length:
            new_arr = new_arr[:, :max_length]
            gr.Warning(
                f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds are used."
            )
        torchaudio.save(input_data, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))
    else:
        input_data = input_text
    text_out, wav, sr = translator.predict(
        input=input_data,
        task_str=task_name,
        tgt_lang=target_language_code,
        src_lang=source_language_code,
        ngram_filtering=True,
    )
    if task_name in ["S2ST", "T2ST"]:
        return (sr, wav.cpu().detach().numpy()), text_out
    else:
        return None, text_out

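
# Thin wrappers around predict() so each gr.Examples block below can call it
# with the task fixed; Gradio passes only the example row's columns to `fn`.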
def process_s2st_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    return predict(
        task_name="S2ST",
        audio_source="file",
        input_audio_mic=None,
        input_audio_file=input_audio_file,
        input_text=None,
        source_language=None,
        target_language=target_language,
    )


def process_s2tt_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    return predict(
        task_name="S2TT",
        audio_source="file",
        input_audio_mic=None,
        input_audio_file=input_audio_file,
        input_text=None,
        source_language=None,
        target_language=target_language,
    )


def process_t2st_example(
    input_text: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    return predict(
        task_name="T2ST",
        audio_source="",
        input_audio_mic=None,
        input_audio_file=None,
        input_text=input_text,
        source_language=source_language,
        target_language=target_language,
    )


def process_t2tt_example(
    input_text: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    return predict(
        task_name="T2TT",
        audio_source="",
        input_audio_mic=None,
        input_audio_file=None,
        input_text=input_text,
        source_language=source_language,
        target_language=target_language,
    )


def process_asr_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    return predict(
        task_name="ASR",
        audio_source="file",
        input_audio_mic=None,
        input_audio_file=input_audio_file,
        input_text=None,
        source_language=None,
        target_language=target_language,
    )

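
# Visibility callbacks: each returns gr.update(...) dicts that Gradio applies
# to the output components listed in the corresponding event wiring below.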
def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
    mic = audio_source == "microphone"
    return (
        gr.update(visible=mic, value=None),  # input_audio_mic
        gr.update(visible=not mic, value=None),  # input_audio_file
    )


def update_input_ui(task_name: str) -> tuple[dict, dict, dict, dict]:
    task_name = task_name.split()[0]
    if task_name == "S2ST":
        return (
            gr.update(visible=True),  # audio_box
            gr.update(visible=False),  # input_text
            gr.update(visible=False),  # source_language
            gr.update(
                visible=True,
                choices=S2ST_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            ),  # target_language
        )
    elif task_name == "S2TT":
        return (
            gr.update(visible=True),  # audio_box
            gr.update(visible=False),  # input_text
            gr.update(visible=False),  # source_language
            gr.update(
                visible=True,
                choices=S2TT_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            ),  # target_language
        )
    elif task_name == "T2ST":
        return (
            gr.update(visible=False),  # audio_box
            gr.update(visible=True),  # input_text
            gr.update(visible=True),  # source_language
            gr.update(
                visible=True,
                choices=S2ST_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            ),  # target_language
        )
    elif task_name == "T2TT":
        return (
            gr.update(visible=False),  # audio_box
            gr.update(visible=True),  # input_text
            gr.update(visible=True),  # source_language
            gr.update(
                visible=True,
                choices=T2TT_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            ),  # target_language
        )
    elif task_name == "ASR":
        return (
            gr.update(visible=True),  # audio_box
            gr.update(visible=False),  # input_text
            gr.update(visible=False),  # source_language
            gr.update(
                visible=True,
                choices=S2TT_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            ),  # target_language
        )
    else:
        raise ValueError(f"Unknown task: {task_name}")


def update_output_ui(task_name: str) -> tuple[dict, dict]:
    task_name = task_name.split()[0]
    if task_name in ["S2ST", "T2ST"]:
        return (
            gr.update(visible=True, value=None),  # output_audio
            gr.update(value=None),  # output_text
        )
    elif task_name in ["S2TT", "T2TT", "ASR"]:
        return (
            gr.update(visible=False, value=None),  # output_audio
            gr.update(value=None),  # output_text
        )
    else:
        raise ValueError(f"Unknown task: {task_name}")


def update_example_ui(task_name: str) -> tuple[dict, dict, dict, dict, dict]:
    task_name = task_name.split()[0]
    return (
        gr.update(visible=task_name == "S2ST"),  # s2st_example_row
        gr.update(visible=task_name == "S2TT"),  # s2tt_example_row
        gr.update(visible=task_name == "T2ST"),  # t2st_example_row
        gr.update(visible=task_name == "T2TT"),  # t2tt_example_row
        gr.update(visible=task_name == "ASR"),  # asr_example_row
    )

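
# Minimal styling: center the page title and constrain the content width.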
css = """
h1 {
  text-align: center;
}

.contain {
  max-width: 730px;
  margin: auto;
  padding-top: 1.5rem;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        task_name = gr.Dropdown(
            label="Task",
            choices=TASK_NAMES,
            value=TASK_NAMES[0],
        )
        with gr.Row():
            source_language = gr.Dropdown(
                label="Source language",
                choices=TEXT_SOURCE_LANGUAGE_NAMES,
                value="English",
                visible=False,
            )
            target_language = gr.Dropdown(
                label="Target language",
                choices=S2ST_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            )
        with gr.Row() as audio_box:
            audio_source = gr.Radio(
                label="Audio source",
                choices=["file", "microphone"],
                value="file",
            )
            input_audio_mic = gr.Audio(
                label="Input speech",
                type="filepath",
                source="microphone",
                visible=False,
            )
            input_audio_file = gr.Audio(
                label="Input speech",
                type="filepath",
                source="upload",
                visible=True,
            )
        input_text = gr.Textbox(label="Input text", visible=False)
        btn = gr.Button("Translate")

    with gr.Column():
        output_audio = gr.Audio(
            label="Translated speech",
            autoplay=False,
            streaming=False,
            type="numpy",
        )
        output_text = gr.Textbox(label="Translated text")

    with gr.Row(visible=True) as s2st_example_row:
        s2st_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "French"],
                ["assets/sample_input.mp3", "Mandarin Chinese"],
                ["assets/sample_input_2.mp3", "Hindi"],
                ["assets/sample_input_2.mp3", "Spanish"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_s2st_example,
        )
    with gr.Row(visible=False) as s2tt_example_row:
        s2tt_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "French"],
                ["assets/sample_input.mp3", "Mandarin Chinese"],
                ["assets/sample_input_2.mp3", "Hindi"],
                ["assets/sample_input_2.mp3", "Spanish"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_s2tt_example,
        )
    with gr.Row(visible=False) as t2st_example_row:
        t2st_examples = gr.Examples(
            examples=[
                ["My favorite animal is the elephant.", "English", "French"],
                ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Hindi",
                ],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Spanish",
                ],
            ],
            inputs=[input_text, source_language, target_language],
            outputs=[output_audio, output_text],
            fn=process_t2st_example,
        )
    with gr.Row(visible=False) as t2tt_example_row:
        t2tt_examples = gr.Examples(
            examples=[
                ["My favorite animal is the elephant.", "English", "French"],
                ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Hindi",
                ],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Spanish",
                ],
            ],
            inputs=[input_text, source_language, target_language],
            outputs=[output_audio, output_text],
            fn=process_t2tt_example,
        )
    with gr.Row(visible=False) as asr_example_row:
        asr_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "English"],
                ["assets/sample_input_2.mp3", "English"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_asr_example,
        )

    audio_source.change(
        fn=update_audio_ui,
        inputs=audio_source,
        outputs=[
            input_audio_mic,
            input_audio_file,
        ],
        queue=False,
        api_name=False,
    )
    task_name.change(
        fn=update_input_ui,
        inputs=task_name,
        outputs=[
            audio_box,
            input_text,
            source_language,
            target_language,
        ],
        queue=False,
        api_name=False,
    ).then(
        fn=update_output_ui,
        inputs=task_name,
        outputs=[output_audio, output_text],
        queue=False,
        api_name=False,
    ).then(
        fn=update_example_ui,
        inputs=task_name,
        outputs=[
            s2st_example_row,
            s2tt_example_row,
            t2st_example_row,
            t2tt_example_row,
            asr_example_row,
        ],
        queue=False,
        api_name=False,
    )
    btn.click(
        fn=predict,
        inputs=[
            task_name,
            audio_source,
            input_audio_mic,
            input_audio_file,
            input_text,
            source_language,
            target_language,
        ],
        outputs=[output_audio, output_text],
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue().launch()