| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 | 
							- # Copyright (c) Meta Platforms, Inc. and affiliates.
 
- #
 
- # This source code is licensed under the license found in the
 
- # LICENSE file in the root directory of this source tree.
 
- import argparse
 
- import logging
 
- import torch
 
- import torchaudio
 
- from argparse import Namespace
 
- from fairseq2.generation import SequenceGeneratorOptions
 
- from seamless_communication.models.inference import (
 
-     NGramRepeatBlockProcessor,
 
-     Translator,
 
- )
 
- from typing import Tuple
 
- logging.basicConfig(
 
-     level=logging.INFO,
 
-     format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
 
- )
 
- logger = logging.getLogger(__name__)
 
- def add_inference_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
 
-     parser.add_argument("task", type=str, help="Task type")
 
-     parser.add_argument(
 
-         "tgt_lang", type=str, help="Target language to translate/transcribe into."
 
-     )
 
-     parser.add_argument(
 
-         "--src_lang",
 
-         type=str,
 
-         help="Source language, only required if input is text.",
 
-         default=None,
 
-     )
 
-     parser.add_argument(
 
-         "--output_path",
 
-         type=str,
 
-         help="Path to save the generated audio.",
 
-         default=None,
 
-     )
 
-     parser.add_argument(
 
-         "--model_name",
 
-         type=str,
 
-         help=(
 
-             "Base model name (`seamlessM4T_medium`, "
 
-             "`seamlessM4T_large`, `seamlessM4T_v2_large`)"
 
-         ),
 
-         default="seamlessM4T_v2_large",
 
-     )
 
-     parser.add_argument(
 
-         "--vocoder_name",
 
-         type=str,
 
-         help="Vocoder model name",
 
-         default="vocoder_commercial",
 
-     )
 
-     # Text generation args.
 
-     parser.add_argument(
 
-         "--text_generation_beam_size",
 
-         type=int,
 
-         help="Beam size for incremental text decoding.",
 
-         default=5,
 
-     )
 
-     parser.add_argument(
 
-         "--text_generation_max_len_a",
 
-         type=int,
 
-         help="`a` in `ax + b` for incremental text decoding.",
 
-         default=1,
 
-     )
 
-     parser.add_argument(
 
-         "--text_generation_max_len_b",
 
-         type=int,
 
-         help="`b` in `ax + b` for incremental text decoding.",
 
-         default=200,
 
-     )
 
-     parser.add_argument(
 
-         "--text_generation_ngram_blocking",
 
-         type=bool,
 
-         help=(
 
-             "Enable ngram_repeat_block for incremental text decoding."
 
-             "This blocks hypotheses with repeating ngram tokens."
 
-         ),
 
-         default=False,
 
-     )
 
-     parser.add_argument(
 
-         "--no_repeat_ngram_size",
 
-         type=int,
 
-         help="Size of ngram repeat block for both text & unit decoding.",
 
-         default=4,
 
-     )
 
-     # Unit generation args.
 
-     parser.add_argument(
 
-         "--unit_generation_beam_size",
 
-         type=int,
 
-         help=(
 
-             "Beam size for incremental unit decoding"
 
-             "not applicable for the NAR T2U decoder."
 
-         ),
 
-         default=5,
 
-     )
 
-     parser.add_argument(
 
-         "--unit_generation_max_len_a",
 
-         type=int,
 
-         help=(
 
-             "`a` in `ax + b` for incremental unit decoding"
 
-             "not applicable for the NAR T2U decoder."
 
-         ),
 
-         default=25,
 
-     )
 
-     parser.add_argument(
 
-         "--unit_generation_max_len_b",
 
-         type=int,
 
-         help=(
 
-             "`b` in `ax + b` for incremental unit decoding"
 
-             "not applicable for the NAR T2U decoder."
 
-         ),
 
-         default=50,
 
-     )
 
-     parser.add_argument(
 
-         "--unit_generation_ngram_blocking",
 
-         type=bool,
 
-         help=(
 
-             "Enable ngram_repeat_block for incremental unit decoding."
 
-             "This blocks hypotheses with repeating ngram tokens."
 
-         ),
 
-         default=False,
 
-     )
 
-     parser.add_argument(
 
-         "--unit_generation_ngram_filtering",
 
-         type=bool,
 
-         help=(
 
-             "If True, removes consecutive repeated ngrams"
 
-             "from the decoded unit output."
 
-         ),
 
-         default=False,
 
-     )
 
-     return parser
 
- def set_generation_opts(
 
-     args: Namespace,
 
- ) -> Tuple[SequenceGeneratorOptions, SequenceGeneratorOptions]:
 
-     # Set text, unit generation opts.
 
-     text_generation_opts = SequenceGeneratorOptions(
 
-         beam_size=args.text_generation_beam_size,
 
-         soft_max_seq_len=(
 
-             args.text_generation_max_len_a,
 
-             args.text_generation_max_len_b,
 
-         ),
 
-     )
 
-     if args.text_generation_ngram_blocking:
 
-         text_generation_opts.logits_processor = NGramRepeatBlockProcessor(
 
-             no_repeat_ngram_size=args.no_repeat_ngram_size
 
-         )
 
-     unit_generation_opts = SequenceGeneratorOptions(
 
-         beam_size=args.unit_generation_beam_size,
 
-         soft_max_seq_len=(
 
-             args.unit_generation_max_len_a,
 
-             args.unit_generation_max_len_b,
 
-         ),
 
-     )
 
-     if args.unit_generation_ngram_blocking:
 
-         unit_generation_opts.logits_processor = NGramRepeatBlockProcessor(
 
-             no_repeat_ngram_size=args.no_repeat_ngram_size
 
-         )
 
-     return text_generation_opts, unit_generation_opts
 
- def main():
 
-     parser = argparse.ArgumentParser(
 
-         description="M4T inference on supported tasks using Translator."
 
-     )
 
-     parser.add_argument("input", type=str, help="Audio WAV file path or text input.")
 
-     parser = add_inference_arguments(parser)
 
-     args = parser.parse_args()
 
-     if args.task.upper() in {"S2ST", "T2ST"} and args.output_path is None:
 
-         raise ValueError("output_path must be provided to save the generated audio")
 
-     if torch.cuda.is_available():
 
-         device = torch.device("cuda:0")
 
-         dtype = torch.float16
 
-     else:
 
-         device = torch.device("cpu")
 
-         dtype = torch.float32
 
-     logger.info(f"Running inference on {device=} with {dtype=}.")
 
-     translator = Translator(args.model_name, args.vocoder_name, device, dtype=dtype)
 
-     text_generation_opts, unit_generation_opts = set_generation_opts(args)
 
-     logger.info(f"{text_generation_opts=}")
 
-     logger.info(f"{unit_generation_opts=}")
 
-     logger.info(
 
-         f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
 
-     )
 
-     text_output, speech_output = translator.predict(
 
-         args.input,
 
-         args.task,
 
-         args.tgt_lang,
 
-         src_lang=args.src_lang,
 
-         text_generation_opts=text_generation_opts,
 
-         unit_generation_opts=unit_generation_opts,
 
-         unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
 
-     )
 
-     if speech_output is not None:
 
-         logger.info(f"Saving translated audio in {args.tgt_lang}")
 
-         torchaudio.save(
 
-             args.output_path,
 
-             speech_output.audio_wavs[0][0].to(torch.float32).cpu(),
 
-             sample_rate=speech_output.sample_rate,
 
-         )
 
-     logger.info(f"Translated text in {args.tgt_lang}: {text_output[0]}")
 
- if __name__ == "__main__":
 
-     main()
 
 
  |