app.py
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

from __future__ import annotations

import os
import pathlib

import gradio as gr
import numpy as np
import torch
import torchaudio
from fairseq2.assets import InProcAssetMetadataProvider, asset_store
from huggingface_hub import snapshot_download
from seamless_communication.inference import Translator

from lang_list import (
    ASR_TARGET_LANGUAGE_NAMES,
    LANGUAGE_NAME_TO_CODE,
    S2ST_TARGET_LANGUAGE_NAMES,
    S2TT_TARGET_LANGUAGE_NAMES,
    T2ST_TARGET_LANGUAGE_NAMES,
    T2TT_TARGET_LANGUAGE_NAMES,
    TEXT_SOURCE_LANGUAGE_NAMES,
)
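
# Download the model checkpoints on first run, then point fairseq2's asset
# store at the local files: clearing env_resolvers and registering an
# in-process metadata provider makes the "@demo" asset cards below resolve to
# the downloaded checkpoints instead of the default remote locations.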
CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))
if not CHECKPOINTS_PATH.exists():
    snapshot_download(repo_id="facebook/seamless-m4t-v2-large", repo_type="model", local_dir=CHECKPOINTS_PATH)
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")
demo_metadata = [
    {
        "name": "seamlessM4T_v2_large@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
    {
        "name": "vocoder_v2@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
    },
]
asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))

DESCRIPTION = """\
# SeamlessM4T

[SeamlessM4T](https://github.com/facebookresearch/seamless_communication) is designed to provide high-quality
translation, allowing people from different linguistic communities to communicate effortlessly through speech and text.
This unified model enables multiple tasks such as Speech-to-Speech Translation (S2ST), Speech-to-Text Translation
(S2TT), Text-to-Speech Translation (T2ST), and more, without relying on multiple separate models.
"""
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()

AUDIO_SAMPLE_RATE = 16000.0
MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
DEFAULT_TARGET_LANGUAGE = "French"

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32
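
# A single Translator instance serves every task. Run in float16 on GPU and
# float32 on CPU; apply_mintox=True enables the MinTox added-toxicity
# mitigation from the Seamless release.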
translator = Translator(
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card="vocoder_v2",
    device=device,
    dtype=dtype,
    apply_mintox=True,
)
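

# Resample the uploaded clip to the model's 16 kHz rate, truncate it to
# MAX_INPUT_AUDIO_LENGTH seconds (warning the user when that happens), and
# overwrite the input file in place.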
def preprocess_audio(input_audio: str) -> None:
    arr, org_sr = torchaudio.load(input_audio)
    new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
    max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
    if new_arr.shape[1] > max_length:
        new_arr = new_arr[:, :max_length]
        gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")
    torchaudio.save(input_audio, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))
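

# Task wrappers. Each one maps the human-readable language names from the UI
# to SeamlessM4T language codes and calls Translator.predict() with the
# matching task string; speech-producing tasks also return the generated
# waveform at AUDIO_SAMPLE_RATE.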
def run_s2st(
    input_audio: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    preprocess_audio(input_audio)
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    out_texts, out_audios = translator.predict(
        input=input_audio,
        task_str="S2ST",
        src_lang=source_language_code,
        tgt_lang=target_language_code,
    )
    out_text = str(out_texts[0])
    out_wav = out_audios.audio_wavs[0].cpu().detach().numpy()
    return (int(AUDIO_SAMPLE_RATE), out_wav), out_text


def run_s2tt(input_audio: str, source_language: str, target_language: str) -> str:
    preprocess_audio(input_audio)
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    out_texts, _ = translator.predict(
        input=input_audio,
        task_str="S2TT",
        src_lang=source_language_code,
        tgt_lang=target_language_code,
    )
    return str(out_texts[0])


def run_t2st(input_text: str, source_language: str, target_language: str) -> tuple[tuple[int, np.ndarray] | None, str]:
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    out_texts, out_audios = translator.predict(
        input=input_text,
        task_str="T2ST",
        src_lang=source_language_code,
        tgt_lang=target_language_code,
    )
    out_text = str(out_texts[0])
    out_wav = out_audios.audio_wavs[0].cpu().detach().numpy()
    return (int(AUDIO_SAMPLE_RATE), out_wav), out_text


def run_t2tt(input_text: str, source_language: str, target_language: str) -> str:
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    out_texts, _ = translator.predict(
        input=input_text,
        task_str="T2TT",
        src_lang=source_language_code,
        tgt_lang=target_language_code,
    )
    return str(out_texts[0])


def run_asr(input_audio: str, target_language: str) -> str:
    # ASR transcribes in place: source and target language are the same.
    preprocess_audio(input_audio)
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    out_texts, _ = translator.predict(
        input=input_audio,
        task_str="ASR",
        src_lang=target_language_code,
        tgt_lang=target_language_code,
    )
    return str(out_texts[0])
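

# One gr.Blocks sub-demo per task; each is rendered as a tab in the top-level
# demo further below.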
with gr.Blocks() as demo_s2st:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                source_language = gr.Dropdown(
                    label="Source language",
                    choices=ASR_TARGET_LANGUAGE_NAMES,
                    value="English",
                )
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=S2ST_TARGET_LANGUAGE_NAMES,
                    value=DEFAULT_TARGET_LANGUAGE,
                )
            btn = gr.Button("Translate")
        with gr.Column():
            with gr.Group():
                output_audio = gr.Audio(
                    label="Translated speech",
                    autoplay=False,
                    streaming=False,
                    type="numpy",
                )
                output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[],
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        fn=run_s2st,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    btn.click(
        fn=run_s2st,
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        api_name="s2st",
    )


with gr.Blocks() as demo_s2tt:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                source_language = gr.Dropdown(
                    label="Source language",
                    choices=ASR_TARGET_LANGUAGE_NAMES,
                    value="English",
                )
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=S2TT_TARGET_LANGUAGE_NAMES,
                    value=DEFAULT_TARGET_LANGUAGE,
                )
            btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[],
        inputs=[input_audio, source_language, target_language],
        outputs=output_text,
        fn=run_s2tt,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    btn.click(
        fn=run_s2tt,
        inputs=[input_audio, source_language, target_language],
        outputs=output_text,
        api_name="s2tt",
    )


with gr.Blocks() as demo_t2st:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_text = gr.Textbox(label="Input text")
                with gr.Row():
                    source_language = gr.Dropdown(
                        label="Source language",
                        choices=TEXT_SOURCE_LANGUAGE_NAMES,
                        value="English",
                    )
                    target_language = gr.Dropdown(
                        label="Target language",
                        choices=T2ST_TARGET_LANGUAGE_NAMES,
                        value=DEFAULT_TARGET_LANGUAGE,
                    )
            btn = gr.Button("Translate")
        with gr.Column():
            with gr.Group():
                output_audio = gr.Audio(
                    label="Translated speech",
                    autoplay=False,
                    streaming=False,
                    type="numpy",
                )
                output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[],
        inputs=[input_text, source_language, target_language],
        outputs=[output_audio, output_text],
        fn=run_t2st,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    gr.on(
        triggers=[input_text.submit, btn.click],
        fn=run_t2st,
        inputs=[input_text, source_language, target_language],
        outputs=[output_audio, output_text],
        api_name="t2st",
    )


with gr.Blocks() as demo_t2tt:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_text = gr.Textbox(label="Input text")
                with gr.Row():
                    source_language = gr.Dropdown(
                        label="Source language",
                        choices=TEXT_SOURCE_LANGUAGE_NAMES,
                        value="English",
                    )
                    target_language = gr.Dropdown(
                        label="Target language",
                        choices=T2TT_TARGET_LANGUAGE_NAMES,
                        value=DEFAULT_TARGET_LANGUAGE,
                    )
            btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[],
        inputs=[input_text, source_language, target_language],
        outputs=output_text,
        fn=run_t2tt,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    gr.on(
        triggers=[input_text.submit, btn.click],
        fn=run_t2tt,
        inputs=[input_text, source_language, target_language],
        outputs=output_text,
        api_name="t2tt",
    )


with gr.Blocks() as demo_asr:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=ASR_TARGET_LANGUAGE_NAMES,
                    value=DEFAULT_TARGET_LANGUAGE,
                )
            btn = gr.Button("Transcribe")
        with gr.Column():
            output_text = gr.Textbox(label="Transcribed text")

    gr.Examples(
        examples=[],
        inputs=[input_audio, target_language],
        outputs=output_text,
        fn=run_asr,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    btn.click(
        fn=run_asr,
        inputs=[input_audio, target_language],
        outputs=output_text,
        api_name="asr",
    )
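

# Top-level app: description, an optional "Duplicate Space" button (shown only
# when SHOW_DUPLICATE_BUTTON=1), and one tab per task.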
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Tabs():
        with gr.Tab(label="S2ST"):
            demo_s2st.render()
        with gr.Tab(label="S2TT"):
            demo_s2tt.render()
        with gr.Tab(label="T2ST"):
            demo_t2st.render()
        with gr.Tab(label="T2TT"):
            demo_t2tt.render()
        with gr.Tab(label="ASR"):
            demo_asr.render()
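
# queue() routes requests through a queue; max_size=50 caps how many jobs may
# wait at once, and later submissions are rejected until slots free up.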
if __name__ == "__main__":
    demo.queue(max_size=50).launch()
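
# A minimal client-side sketch for calling the named endpoints above. This is
# an illustrative assumption, not part of the app: it presumes the server is
# running locally on Gradio's default port and that the caller has the
# gradio_client package installed.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   translated = client.predict(
#       handle_file("sample.wav"), "English", "French", api_name="/s2tt"
#   )
#   print(translated)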