# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging

import torch
import torchaudio

from seamless_communication.models.inference import Translator

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser(
        description="M4T inference on supported tasks using Translator."
    )
    parser.add_argument("input", type=str, help="Audio WAV file path or text input.")
    parser.add_argument("task", type=str, help="Task type")
    parser.add_argument(
        "tgt_lang", type=str, help="Target language to translate/transcribe into."
    )
    parser.add_argument(
        "--src_lang",
        type=str,
        help="Source language, only required if input is text.",
        default=None,
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Path to save the generated audio.",
        default=None,
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="Base model name (`seamlessM4T_medium`, `seamlessM4T_large`)",
        default="seamlessM4T_large",
    )
    parser.add_argument(
        "--vocoder_name", type=str, help="Vocoder name", default="vocoder_36langs"
    )
    parser.add_argument(
        "--ngram-filtering",
        # A store_true flag is used here: argparse's `type=bool` would treat any
        # non-empty string (including "False") as True.
        action="store_true",
        help="Enable ngram_repeat_block (currently hardcoded to 4, during decoding) and ngram filtering over units (postprocessing).",
    )
    args = parser.parse_args()

    # Speech output is only produced for the *2ST tasks, so a save path is mandatory.
    if args.task.upper() in {"S2ST", "T2ST"} and args.output_path is None:
        raise ValueError("output_path must be provided to save the generated audio")

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        logger.info("Running inference on the GPU.")
    else:
        device = torch.device("cpu")
        logger.info("Running inference on the CPU.")

    translator = Translator(args.model_name, args.vocoder_name, device)
    translated_text, wav, sr = translator.predict(
        args.input,
        args.task,
        args.tgt_lang,
        src_lang=args.src_lang,
        ngram_filtering=args.ngram_filtering,
    )

    # wav and sr are only returned for tasks that generate speech.
    if wav is not None and sr is not None:
        logger.info(f"Saving translated audio in {args.tgt_lang}")
        torchaudio.save(
            args.output_path,
            wav[0].cpu(),
            sample_rate=sr,
        )
    logger.info(f"Translated text in {args.tgt_lang}: {translated_text}")


if __name__ == "__main__":
    main()
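
# Example invocations (a sketch, not part of the original script; the file name
# "predict.py", the audio paths, and the "eng"/"fra" language codes are assumptions):
#
#   # Speech-to-speech translation (S2ST): translate input.wav into French speech.
#   python predict.py input.wav s2st fra --output_path out.wav
#
#   # Text-to-speech translation (T2ST): --src_lang is required for text input.
#   python predict.py "Hello, world." t2st fra --src_lang eng --output_path out.wav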
 
 