# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging

import torch
import torchaudio

from seamless_communication.models.inference import Translator

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser(
        description="M4T inference on supported tasks using Translator."
    )
    parser.add_argument("input", type=str, help="Audio WAV file path or text input.")
    parser.add_argument(
        "task", type=str, help="Task type (S2ST, S2TT, T2ST, T2TT or ASR)."
    )
    parser.add_argument(
        "tgt_lang", type=str, help="Target language to translate/transcribe into."
    )
    parser.add_argument(
        "--src_lang",
        type=str,
        help="Source language, only required if input is text.",
        default=None,
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Path to save the generated audio.",
        default=None,
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="Base model name (`seamlessM4T_medium`, `seamlessM4T_large`)",
        default="seamlessM4T_large",
    )
    parser.add_argument(
        "--vocoder_name", type=str, help="Vocoder name", default="vocoder_36langs"
    )
    # `action="store_true"` instead of `type=bool`: argparse passes the raw
    # string to bool(), so any non-empty value (including "False") would parse
    # as True. A store_true flag defaults to False and avoids that trap.
    parser.add_argument(
        "--ngram-filtering",
        action="store_true",
        help="Enable ngram_repeat_block during decoding (block size currently hardcoded to 4) and ngram filtering over units (postprocessing).",
    )
    args = parser.parse_args()

    # Speech-producing tasks need somewhere to write the synthesized audio.
    if args.task.upper() in {"S2ST", "T2ST"} and args.output_path is None:
        raise ValueError("output_path must be provided to save the generated audio")

    # Prefer fp16 on GPU for speed and memory; fall back to fp32 on CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        dtype = torch.float16
        logger.info(f"Running inference on the GPU in {dtype}.")
    else:
        device = torch.device("cpu")
        dtype = torch.float32
        logger.info(f"Running inference on the CPU in {dtype}.")

    translator = Translator(args.model_name, args.vocoder_name, device, dtype)
    translated_text, wav, sr = translator.predict(
        args.input,
        args.task,
        args.tgt_lang,
        src_lang=args.src_lang,
        ngram_filtering=args.ngram_filtering,
    )

    # Text-only tasks return no waveform; save audio only when one was produced.
    if wav is not None and sr is not None:
        logger.info(f"Saving translated audio in {args.tgt_lang}")
        torchaudio.save(
            args.output_path,
            # torchaudio.save expects a float32 tensor on the CPU.
            wav[0].to(torch.float32).cpu(),
            sample_rate=sr,
        )
    logger.info(f"Translated text in {args.tgt_lang}: {translated_text}")


if __name__ == "__main__":
    main()
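
# Example invocations (sketches; the file names, input text, and language
# codes below are placeholders, so substitute real paths and codes supported
# by the chosen model):
#
#   # Speech-to-speech translation into French, saving the synthesized audio:
#   python predict.py input.wav S2ST fra --output_path out.wav
#
#   # Text-to-text translation; --src_lang is required for text input:
#   python predict.py "Hello, world!" T2TT deu --src_lang eng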
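
# Programmatic use, for reference: a minimal sketch built only from the calls
# this script already makes (input text and language codes are placeholders):
#
#   import torch
#   from seamless_communication.models.inference import Translator
#
#   translator = Translator(
#       "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
#   )
#   text, wav, sr = translator.predict(
#       "Bonjour le monde", "T2TT", "eng", src_lang="fra"
#   )
#   print(text)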