# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
from argparse import Namespace
from typing import Tuple

import torch
import torchaudio
from fairseq2.generation import SequenceGeneratorOptions

from seamless_communication.models.inference import (
    NGramRepeatBlockProcessor,
    Translator,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

logger = logging.getLogger(__name__)


def add_inference_arguments(
    parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
    parser.add_argument(
        "task", type=str, help="Task type (e.g. S2ST, S2TT, T2ST, T2TT, ASR)."
    )
    parser.add_argument(
        "tgt_lang", type=str, help="Target language to translate/transcribe into."
    )
    parser.add_argument(
        "--src_lang",
        type=str,
        help="Source language, only required if input is text.",
        default=None,
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Path to save the generated audio.",
        default=None,
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help=(
            "Base model name (`seamlessM4T_medium`, "
            "`seamlessM4T_large`, `seamlessM4T_v2_large`)."
        ),
        default="seamlessM4T_v2_large",
    )
    parser.add_argument(
        "--vocoder_name",
        type=str,
        help="Vocoder model name.",
        default="vocoder_commercial",
    )
    # Text generation args.
    parser.add_argument(
        "--text_generation_beam_size",
        type=int,
        help="Beam size for incremental text decoding.",
        default=5,
    )
    parser.add_argument(
        "--text_generation_max_len_a",
        type=int,
        help="`a` in `ax + b` for incremental text decoding.",
        default=1,
    )
    parser.add_argument(
        "--text_generation_max_len_b",
        type=int,
        help="`b` in `ax + b` for incremental text decoding.",
        default=200,
    )
    parser.add_argument(
        "--text_generation_ngram_blocking",
        action="store_true",
        help=(
            "Enable ngram_repeat_block for incremental text decoding. "
            "This blocks hypotheses with repeating ngram tokens."
        ),
    )
    parser.add_argument(
        "--no_repeat_ngram_size",
        type=int,
        help="Size of ngram repeat block for both text & unit decoding.",
        default=4,
    )
    # Unit generation args.
    parser.add_argument(
        "--unit_generation_beam_size",
        type=int,
        help=(
            "Beam size for incremental unit decoding; "
            "not applicable for the NAR T2U decoder."
        ),
        default=5,
    )
    parser.add_argument(
        "--unit_generation_max_len_a",
        type=int,
        help=(
            "`a` in `ax + b` for incremental unit decoding; "
            "not applicable for the NAR T2U decoder."
        ),
        default=25,
    )
    parser.add_argument(
        "--unit_generation_max_len_b",
        type=int,
        help=(
            "`b` in `ax + b` for incremental unit decoding; "
            "not applicable for the NAR T2U decoder."
        ),
        default=50,
    )
    parser.add_argument(
        "--unit_generation_ngram_blocking",
        action="store_true",
        help=(
            "Enable ngram_repeat_block for incremental unit decoding. "
            "This blocks hypotheses with repeating ngram tokens."
        ),
    )
    parser.add_argument(
        "--unit_generation_ngram_filtering",
        action="store_true",
        help=(
            "If set, removes consecutive repeated ngrams "
            "from the decoded unit output."
        ),
    )
    return parser
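

# Note: `soft_max_seq_len=(a, b)` below follows the `ax + b` convention named
# in the help strings above: for a source sequence of length `x`, decoding is
# capped at roughly `a * x + b` output positions. Illustrative arithmetic with
# the defaults above: text decoding (a=1, b=200) caps a 30-token source at
# 1 * 30 + 200 = 230 text tokens, while unit decoding (a=25, b=50) caps a
# 30-token text sequence at 25 * 30 + 50 = 800 units, reflecting that each
# text token expands into many discrete acoustic units.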
def set_generation_opts(
    args: Namespace,
) -> Tuple[SequenceGeneratorOptions, SequenceGeneratorOptions]:
    # Set text and unit generation options.
    text_generation_opts = SequenceGeneratorOptions(
        beam_size=args.text_generation_beam_size,
        soft_max_seq_len=(
            args.text_generation_max_len_a,
            args.text_generation_max_len_b,
        ),
    )
    if args.text_generation_ngram_blocking:
        text_generation_opts.logits_processor = NGramRepeatBlockProcessor(
            no_repeat_ngram_size=args.no_repeat_ngram_size
        )

    unit_generation_opts = SequenceGeneratorOptions(
        beam_size=args.unit_generation_beam_size,
        soft_max_seq_len=(
            args.unit_generation_max_len_a,
            args.unit_generation_max_len_b,
        ),
    )
    if args.unit_generation_ngram_blocking:
        unit_generation_opts.logits_processor = NGramRepeatBlockProcessor(
            no_repeat_ngram_size=args.no_repeat_ngram_size
        )
    return text_generation_opts, unit_generation_opts


def main() -> None:
    parser = argparse.ArgumentParser(
        description="M4T inference on supported tasks using Translator."
    )
    parser.add_argument("input", type=str, help="Audio WAV file path or text input.")
    parser = add_inference_arguments(parser)
    args = parser.parse_args()

    if args.task.upper() in {"S2ST", "T2ST"} and args.output_path is None:
        raise ValueError("output_path must be provided to save the generated audio.")

    # Prefer GPU with half precision when available; otherwise CPU with fp32.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        dtype = torch.float16
    else:
        device = torch.device("cpu")
        dtype = torch.float32
    logger.info(f"Running inference on {device=} with {dtype=}.")

    translator = Translator(args.model_name, args.vocoder_name, device, dtype=dtype)

    text_generation_opts, unit_generation_opts = set_generation_opts(args)
    logger.info(f"{text_generation_opts=}")
    logger.info(f"{unit_generation_opts=}")
    logger.info(
        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
    )

    text_output, speech_output = translator.predict(
        args.input,
        args.task,
        args.tgt_lang,
        src_lang=args.src_lang,
        text_generation_opts=text_generation_opts,
        unit_generation_opts=unit_generation_opts,
        unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
    )

    if speech_output is not None:
        logger.info(f"Saving translated audio in {args.tgt_lang}")
        torchaudio.save(
            args.output_path,
            speech_output.audio_wavs[0][0].to(torch.float32).cpu(),
            sample_rate=speech_output.sample_rate,
        )
    logger.info(f"Translated text in {args.tgt_lang}: {text_output[0]}")


if __name__ == "__main__":
    main()
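

# Illustrative invocations (a sketch; assumes this script is saved as
# `predict.py`, with task names and three-letter language codes following
# seamless_communication conventions, e.g. `s2st`, `t2tt`, `eng`, `fra`, `spa`):
#
#   Speech-to-speech translation, saving Spanish audio to out.wav:
#       python predict.py input.wav s2st spa --output_path out.wav
#
#   Text-to-text translation of English text into French:
#       python predict.py "Hello, world." t2tt fra --src_lang eng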