app.py

#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

from __future__ import annotations
import os
import pathlib

import gradio as gr
import numpy as np
import torch
import torchaudio
from fairseq2.assets import InProcAssetMetadataProvider, asset_store
from huggingface_hub import snapshot_download
from seamless_communication.inference import Translator

from lang_list import (
    ASR_TARGET_LANGUAGE_NAMES,
    LANGUAGE_NAME_TO_CODE,
    S2ST_TARGET_LANGUAGE_NAMES,
    S2TT_TARGET_LANGUAGE_NAMES,
    T2ST_TARGET_LANGUAGE_NAMES,
    T2TT_TARGET_LANGUAGE_NAMES,
    TEXT_SOURCE_LANGUAGE_NAMES,
)
# pathlib.Path.home() works on Windows too, unlike hard-coding /home/<user>.
CHECKPOINTS_PATH = pathlib.Path(
    os.getenv("CHECKPOINTS_PATH", str(pathlib.Path.home() / "app" / "models"))
)
if not CHECKPOINTS_PATH.exists():
    snapshot_download(
        repo_id="facebook/seamless-m4t-v2-large",
        repo_type="model",
        local_dir=CHECKPOINTS_PATH,
    )

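# fairseq2 resolves model cards through its asset store. Force every lookup
# into a "demo" environment and register the local checkpoint files for it,
# so the model and vocoder load from CHECKPOINTS_PATH instead of the hub.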
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")
demo_metadata = [
    {
        "name": "seamlessM4T_v2_large@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
    {
        "name": "vocoder_v2@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
    },
]
asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))

DESCRIPTION = """\
# SeamlessM4T

[SeamlessM4T](https://github.com/facebookresearch/seamless_communication) is designed to provide high-quality
translation, allowing people from different linguistic communities to communicate effortlessly through speech and text.
This unified model enables multiple tasks, such as Speech-to-Speech (S2ST), Speech-to-Text (S2TT), and
Text-to-Speech (T2ST) translation, without relying on a separate model for each task.
"""

CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()

AUDIO_SAMPLE_RATE = 16000.0
MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
DEFAULT_TARGET_LANGUAGE = "French"

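# Use half precision on GPU for speed and memory; CPU inference stays in float32.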
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32

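# A single Translator instance serves all five tasks (S2ST, S2TT, T2ST, T2TT,
# ASR). apply_mintox=True enables MinTox mitigation of added toxicity in the output.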
translator = Translator(
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card="vocoder_v2",
    device=device,
    dtype=dtype,
    apply_mintox=True,
)


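# Normalize the uploaded file in place: resample to 16 kHz and truncate
# anything longer than MAX_INPUT_AUDIO_LENGTH seconds, warning the user.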
def preprocess_audio(input_audio: str) -> None:
    arr, org_sr = torchaudio.load(input_audio)
    new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=int(AUDIO_SAMPLE_RATE))
    max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
    if new_arr.shape[1] > max_length:
        new_arr = new_arr[:, :max_length]
        gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")
    torchaudio.save(input_audio, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))


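# Task runners: each wraps translator.predict() for one task and returns what
# the corresponding Gradio outputs expect (a (sample_rate, wav) tuple for
# speech, a plain string for text).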
def run_s2st(
    input_audio: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    preprocess_audio(input_audio)
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    out_texts, out_audios = translator.predict(
        input=input_audio,
        task_str="S2ST",
        src_lang=source_language_code,
        tgt_lang=target_language_code,
    )
    out_text = str(out_texts[0])
    out_wav = out_audios.audio_wavs[0].cpu().detach().numpy()
    return (int(AUDIO_SAMPLE_RATE), out_wav), out_text


def run_s2tt(input_audio: str, source_language: str, target_language: str) -> str:
    preprocess_audio(input_audio)
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    out_texts, _ = translator.predict(
        input=input_audio,
        task_str="S2TT",
        src_lang=source_language_code,
        tgt_lang=target_language_code,
    )
    return str(out_texts[0])


def run_t2st(
    input_text: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    out_texts, out_audios = translator.predict(
        input=input_text,
        task_str="T2ST",
        src_lang=source_language_code,
        tgt_lang=target_language_code,
    )
    out_text = str(out_texts[0])
    out_wav = out_audios.audio_wavs[0].cpu().detach().numpy()
    return (int(AUDIO_SAMPLE_RATE), out_wav), out_text


def run_t2tt(input_text: str, source_language: str, target_language: str) -> str:
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    out_texts, _ = translator.predict(
        input=input_text,
        task_str="T2TT",
        src_lang=source_language_code,
        tgt_lang=target_language_code,
    )
    return str(out_texts[0])


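# ASR is transcription rather than translation, so the selected language is
# passed as both src_lang and tgt_lang.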
def run_asr(input_audio: str, target_language: str) -> str:
    preprocess_audio(input_audio)
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    out_texts, _ = translator.predict(
        input=input_audio,
        task_str="ASR",
        src_lang=target_language_code,
        tgt_lang=target_language_code,
    )
    return str(out_texts[0])


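# One gr.Blocks UI per task; each is rendered into a tab of the top-level
# demo at the bottom of this file.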
with gr.Blocks() as demo_s2st:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                source_language = gr.Dropdown(
                    label="Source language",
                    choices=ASR_TARGET_LANGUAGE_NAMES,
                    value="English",
                )
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=S2ST_TARGET_LANGUAGE_NAMES,
                    value=DEFAULT_TARGET_LANGUAGE,
                )
            btn = gr.Button("Translate")
        with gr.Column():
            with gr.Group():
                output_audio = gr.Audio(
                    label="Translated speech",
                    autoplay=False,
                    streaming=False,
                    type="numpy",
                )
                output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[],
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        fn=run_s2st,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    btn.click(
        fn=run_s2st,
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        api_name="s2st",
    )


with gr.Blocks() as demo_s2tt:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                source_language = gr.Dropdown(
                    label="Source language",
                    choices=ASR_TARGET_LANGUAGE_NAMES,
                    value="English",
                )
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=S2TT_TARGET_LANGUAGE_NAMES,
                    value=DEFAULT_TARGET_LANGUAGE,
                )
            btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[],
        inputs=[input_audio, source_language, target_language],
        outputs=output_text,
        fn=run_s2tt,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    btn.click(
        fn=run_s2tt,
        inputs=[input_audio, source_language, target_language],
        outputs=output_text,
        api_name="s2tt",
    )


with gr.Blocks() as demo_t2st:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_text = gr.Textbox(label="Input text")
                with gr.Row():
                    source_language = gr.Dropdown(
                        label="Source language",
                        choices=TEXT_SOURCE_LANGUAGE_NAMES,
                        value="English",
                    )
                    target_language = gr.Dropdown(
                        label="Target language",
                        choices=T2ST_TARGET_LANGUAGE_NAMES,
                        value=DEFAULT_TARGET_LANGUAGE,
                    )
            btn = gr.Button("Translate")
        with gr.Column():
            with gr.Group():
                output_audio = gr.Audio(
                    label="Translated speech",
                    autoplay=False,
                    streaming=False,
                    type="numpy",
                )
                output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[],
        inputs=[input_text, source_language, target_language],
        outputs=[output_audio, output_text],
        fn=run_t2st,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    gr.on(
        triggers=[input_text.submit, btn.click],
        fn=run_t2st,
        inputs=[input_text, source_language, target_language],
        outputs=[output_audio, output_text],
        api_name="t2st",
    )


with gr.Blocks() as demo_t2tt:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_text = gr.Textbox(label="Input text")
                with gr.Row():
                    source_language = gr.Dropdown(
                        label="Source language",
                        choices=TEXT_SOURCE_LANGUAGE_NAMES,
                        value="English",
                    )
                    target_language = gr.Dropdown(
                        label="Target language",
                        choices=T2TT_TARGET_LANGUAGE_NAMES,
                        value=DEFAULT_TARGET_LANGUAGE,
                    )
            btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[],
        inputs=[input_text, source_language, target_language],
        outputs=output_text,
        fn=run_t2tt,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    gr.on(
        triggers=[input_text.submit, btn.click],
        fn=run_t2tt,
        inputs=[input_text, source_language, target_language],
        outputs=output_text,
        api_name="t2tt",
    )


with gr.Blocks() as demo_asr:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=ASR_TARGET_LANGUAGE_NAMES,
                    value=DEFAULT_TARGET_LANGUAGE,
                )
            btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[],
        inputs=[input_audio, target_language],
        outputs=output_text,
        fn=run_asr,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    btn.click(
        fn=run_asr,
        inputs=[input_audio, target_language],
        outputs=output_text,
        api_name="asr",
    )


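# Top-level app: description, a "duplicate this Space" button, and one tab
# per task UI defined above.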
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Tabs():
        with gr.Tab(label="S2ST"):
            demo_s2st.render()
        with gr.Tab(label="S2TT"):
            demo_s2tt.render()
        with gr.Tab(label="T2ST"):
            demo_t2st.render()
        with gr.Tab(label="T2TT"):
            demo_t2tt.render()
        with gr.Tab(label="ASR"):
            demo_asr.render()


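# queue() limits the number of pending requests (max_size); launch() starts the app.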
if __name__ == "__main__":
    demo.queue(max_size=50).launch()