# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging

import torch
import torchaudio

from seamless_communication.models.inference import Translator

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)
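
# A minimal usage sketch, assuming this file is saved as predict.py and that the
# language codes shown ("eng", "fra", "spa") are supported by the chosen model;
# the paths, tasks, and codes are illustrative, not exhaustive:
#
#   python predict.py "Hello, world." T2TT fra --src_lang eng
#   python predict.py input.wav S2ST spa --output_path out.wav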


def main():
    parser = argparse.ArgumentParser(
        description="M4T inference on supported tasks using Translator."
    )
    parser.add_argument("input", type=str, help="Audio WAV file path or text input.")
    parser.add_argument(
        "task", type=str, help="Task type (e.g. S2ST, S2TT, T2ST, T2TT, ASR)."
    )
    parser.add_argument(
        "tgt_lang", type=str, help="Target language to translate/transcribe into."
    )
    parser.add_argument(
        "--src_lang",
        type=str,
        help="Source language, only required if the input is text.",
        default=None,
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Path to save the generated audio.",
        default=None,
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="Base model name (`seamlessM4T_medium`, `seamlessM4T_large`).",
        default="seamlessM4T_large",
    )
    parser.add_argument(
        "--vocoder_name", type=str, help="Vocoder name.", default="vocoder_36langs"
    )
    parser.add_argument(
        "--ngram-filtering",
        # `type=bool` would treat any non-empty string (even "False") as True,
        # so expose this as a plain flag instead; it still defaults to False.
        action="store_true",
        help=(
            "Enable ngram_repeat_block during decoding (currently hardcoded to 4) "
            "and ngram filtering over units as a postprocessing step."
        ),
    )

    args = parser.parse_args()

    # Speech-generating tasks need somewhere to write the synthesized audio.
    if args.task.upper() in {"S2ST", "T2ST"} and args.output_path is None:
        raise ValueError("output_path must be provided to save the generated audio")

    # Run on the GPU when available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        logger.info("Running inference on the GPU.")
    else:
        device = torch.device("cpu")
        logger.info("Running inference on the CPU.")

    translator = Translator(args.model_name, args.vocoder_name, device)
    translated_text, wav, sr = translator.predict(
        args.input,
        args.task,
        args.tgt_lang,
        src_lang=args.src_lang,
        ngram_filtering=args.ngram_filtering,
    )

    # For speech-output tasks, save the generated waveform alongside the text.
    if wav is not None and sr is not None:
        logger.info(f"Saving translated audio in {args.tgt_lang}")
        torchaudio.save(
            args.output_path,
            wav[0].cpu(),
            sample_rate=sr,
        )
    logger.info(f"Translated text in {args.tgt_lang}: {translated_text}")


if __name__ == "__main__":
    main()