predict.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import argparse
  6. import logging
  7. import torch
  8. import torchaudio
  9. from seamless_communication.models.inference import Translator
# Configure the root logger so INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger; main() reports device choice and results through it.
logger = logging.getLogger(__name__)
  12. def main():
  13. parser = argparse.ArgumentParser(
  14. description="M4T inference on supported tasks using Translator."
  15. )
  16. parser.add_argument("input", type=str, help="Audio WAV file path or text input.")
  17. parser.add_argument("task", type=str, help="Task type")
  18. parser.add_argument(
  19. "tgt_lang", type=str, help="Target language to translate/transcribe into."
  20. )
  21. parser.add_argument(
  22. "--src_lang",
  23. type=str,
  24. help="Source language, only required if input is text.",
  25. default=None,
  26. )
  27. parser.add_argument(
  28. "--output_path",
  29. type=str,
  30. help="Path to save the generated audio.",
  31. default=None,
  32. )
  33. parser.add_argument(
  34. "--model_name", type=str, help="Base model name (`seamlessM4T_medium`, `seamlessM4T_large`)", default="seamlessM4T_large"
  35. )
  36. parser.add_argument(
  37. "--vocoder_name", type=str, help="Vocoder name", default="vocoder_36langs"
  38. )
  39. args = parser.parse_args()
  40. if args.task.upper() in {"S2ST", "T2ST"} and args.output_path is None:
  41. raise ValueError("output_path must be provided to save the generated audio")
  42. if torch.cuda.is_available():
  43. device = torch.device("cuda:0")
  44. logger.info("Running inference on the GPU.")
  45. else:
  46. device = torch.device("cpu")
  47. logger.info("Running inference on the CPU.")
  48. translator = Translator(args.model_name, args.vocoder_name, device)
  49. translated_text, wav, sr = translator.predict(
  50. args.input, args.task, args.tgt_lang, src_lang=args.src_lang
  51. )
  52. if wav is not None and sr is not None:
  53. logger.info(f"Saving translated audio in {args.tgt_lang}")
  54. torchaudio.save(
  55. args.output_path,
  56. wav[0].cpu(),
  57. sample_rate=sr,
  58. )
  59. logger.info(f"Translated text in {args.tgt_lang}: {translated_text}")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()