# predict.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
  5. import argparse
  6. import logging
  7. import torch
  8. import torchaudio
  9. from seamless_communication.models.inference import Translator
  10. logging.basicConfig(
  11. level=logging.INFO,
  12. format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
  13. )
  14. logger = logging.getLogger(__name__)
  15. def main():
  16. parser = argparse.ArgumentParser(
  17. description="M4T inference on supported tasks using Translator."
  18. )
  19. parser.add_argument("input", type=str, help="Audio WAV file path or text input.")
  20. parser.add_argument("task", type=str, help="Task type")
  21. parser.add_argument(
  22. "tgt_lang", type=str, help="Target language to translate/transcribe into."
  23. )
  24. parser.add_argument(
  25. "--src_lang",
  26. type=str,
  27. help="Source language, only required if input is text.",
  28. default=None,
  29. )
  30. parser.add_argument(
  31. "--output_path",
  32. type=str,
  33. help="Path to save the generated audio.",
  34. default=None,
  35. )
  36. parser.add_argument(
  37. "--model_name",
  38. type=str,
  39. help="Base model name (`seamlessM4T_medium`, `seamlessM4T_large`)",
  40. default="seamlessM4T_large",
  41. )
  42. parser.add_argument(
  43. "--vocoder_name", type=str, help="Vocoder name", default="vocoder_36langs"
  44. )
  45. parser.add_argument(
  46. "--ngram-filtering",
  47. type=bool,
  48. help="Enable ngram_repeat_block (currently hardcoded to 4, during decoding) and ngram filtering over units (postprocessing)",
  49. default=False,
  50. )
  51. args = parser.parse_args()
  52. if args.task.upper() in {"S2ST", "T2ST"} and args.output_path is None:
  53. raise ValueError("output_path must be provided to save the generated audio")
  54. if torch.cuda.is_available():
  55. device = torch.device("cuda:0")
  56. dtype = torch.float16
  57. logger.info(f"Running inference on the GPU in {dtype}.")
  58. else:
  59. device = torch.device("cpu")
  60. dtype = torch.float32
  61. logger.info(f"Running inference on the CPU in {dtype}.")
  62. translator = Translator(args.model_name, args.vocoder_name, device, dtype)
  63. translated_text, wav, sr = translator.predict(
  64. args.input,
  65. args.task,
  66. args.tgt_lang,
  67. src_lang=args.src_lang,
  68. ngram_filtering=args.ngram_filtering,
  69. )
  70. if wav is not None and sr is not None:
  71. logger.info(f"Saving translated audio in {args.tgt_lang}")
  72. torchaudio.save(
  73. args.output_path,
  74. wav[0].cpu(),
  75. sample_rate=sr,
  76. )
  77. logger.info(f"Translated text in {args.tgt_lang}: {translated_text}")
# Script entry point: only run inference when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()