123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- import argparse
- import logging
- import torch
- import torchaudio
- from argparse import Namespace
- from fairseq2.generation import SequenceGeneratorOptions
- from seamless_communication.models.inference import (
- NGramRepeatBlockProcessor,
- Translator,
- )
- from typing import Tuple
# Configure the root logger once at import time: INFO and above, with each
# record showing timestamp, level, logger name and message.
logging.basicConfig(
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
    level=logging.INFO,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _str2bool(value: str) -> bool:
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(value)``, which is ``True``
    for every non-empty string (including ``"False"`` and ``"0"``), so such a
    flag can never be switched off explicitly. This converter interprets the
    usual spellings instead.

    Args:
        value: Raw string taken from the command line.

    Returns:
        The boolean the string stands for.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised
            boolean spelling.
    """
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "on", "1"}:
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if normalized in {"false", "f", "no", "n", "off", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def add_inference_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the inference options on ``parser``.

    Adds the positional ``task`` and ``tgt_lang`` arguments, the model and
    output options, and the text/unit generation options.

    Args:
        parser: Parser to extend; it is modified in place.

    Returns:
        The same ``parser`` object, to allow call chaining.
    """
    parser.add_argument("task", type=str, help="Task type")
    parser.add_argument(
        "tgt_lang", type=str, help="Target language to translate/transcribe into."
    )
    parser.add_argument(
        "--src_lang",
        type=str,
        help="Source language, only required if input is text.",
        default=None,
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Path to save the generated audio.",
        default=None,
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help=(
            "Base model name (`seamlessM4T_medium`, "
            "`seamlessM4T_large`, `seamlessM4T_v2_large`)"
        ),
        default="seamlessM4T_v2_large",
    )
    parser.add_argument(
        "--vocoder_name",
        type=str,
        help="Vocoder model name",
        default="vocoder_commercial",
    )

    # Text generation args.
    parser.add_argument(
        "--text_generation_beam_size",
        type=int,
        help="Beam size for incremental text decoding.",
        default=5,
    )
    parser.add_argument(
        "--text_generation_max_len_a",
        type=int,
        help="`a` in `ax + b` for incremental text decoding.",
        default=1,
    )
    parser.add_argument(
        "--text_generation_max_len_b",
        type=int,
        help="`b` in `ax + b` for incremental text decoding.",
        default=200,
    )
    parser.add_argument(
        "--text_generation_ngram_blocking",
        # `type=bool` would turn "False" into True; parse the text instead.
        type=_str2bool,
        help=(
            "Enable ngram_repeat_block for incremental text decoding. "
            "This blocks hypotheses with repeating ngram tokens."
        ),
        default=False,
    )
    parser.add_argument(
        "--no_repeat_ngram_size",
        type=int,
        help="Size of ngram repeat block for both text & unit decoding.",
        default=4,
    )

    # Unit generation args.
    parser.add_argument(
        "--unit_generation_beam_size",
        type=int,
        help=(
            "Beam size for incremental unit decoding, "
            "not applicable for the NAR T2U decoder."
        ),
        default=5,
    )
    parser.add_argument(
        "--unit_generation_max_len_a",
        type=int,
        help=(
            "`a` in `ax + b` for incremental unit decoding, "
            "not applicable for the NAR T2U decoder."
        ),
        default=25,
    )
    parser.add_argument(
        "--unit_generation_max_len_b",
        type=int,
        help=(
            "`b` in `ax + b` for incremental unit decoding, "
            "not applicable for the NAR T2U decoder."
        ),
        default=50,
    )
    parser.add_argument(
        "--unit_generation_ngram_blocking",
        type=_str2bool,
        help=(
            "Enable ngram_repeat_block for incremental unit decoding. "
            "This blocks hypotheses with repeating ngram tokens."
        ),
        default=False,
    )
    parser.add_argument(
        "--unit_generation_ngram_filtering",
        type=_str2bool,
        help=(
            "If True, removes consecutive repeated ngrams "
            "from the decoded unit output."
        ),
        default=False,
    )
    return parser
def set_generation_opts(
    args: Namespace,
) -> Tuple[SequenceGeneratorOptions, SequenceGeneratorOptions]:
    """Build the text and unit generation options from parsed CLI arguments.

    For each of the two decoders a ``SequenceGeneratorOptions`` is created
    from the matching ``*_beam_size`` and ``(*_max_len_a, *_max_len_b)``
    arguments. When the corresponding ``*_ngram_blocking`` argument is
    truthy, an ``NGramRepeatBlockProcessor`` sized by
    ``args.no_repeat_ngram_size`` is attached as the logits processor.

    Args:
        args: Parsed command-line namespace holding the generation options.

    Returns:
        A ``(text_options, unit_options)`` pair.
    """
    # Text decoder options.
    text_length_limits = (
        args.text_generation_max_len_a,
        args.text_generation_max_len_b,
    )
    text_opts = SequenceGeneratorOptions(
        beam_size=args.text_generation_beam_size,
        soft_max_seq_len=text_length_limits,
    )
    if args.text_generation_ngram_blocking:
        text_opts.logits_processor = NGramRepeatBlockProcessor(
            no_repeat_ngram_size=args.no_repeat_ngram_size
        )

    # Unit decoder options.
    unit_length_limits = (
        args.unit_generation_max_len_a,
        args.unit_generation_max_len_b,
    )
    unit_opts = SequenceGeneratorOptions(
        beam_size=args.unit_generation_beam_size,
        soft_max_seq_len=unit_length_limits,
    )
    if args.unit_generation_ngram_blocking:
        unit_opts.logits_processor = NGramRepeatBlockProcessor(
            no_repeat_ngram_size=args.no_repeat_ngram_size
        )

    return text_opts, unit_opts
def main():
    """Run a single M4T inference request from the command line.

    Parses the CLI arguments, picks ``cuda:0`` with float16 when CUDA is
    available (CPU with float32 otherwise), builds a ``Translator``, runs
    ``predict`` and logs the first text output. When ``predict`` returns a
    speech output, its ``audio_wavs[0][0]`` entry is saved to
    ``--output_path`` as float32.

    Raises:
        ValueError: If the task is ``S2ST`` or ``T2ST`` (case-insensitive)
            and no ``--output_path`` was supplied.
    """
    cli = argparse.ArgumentParser(
        description="M4T inference on supported tasks using Translator."
    )
    cli.add_argument("input", type=str, help="Audio WAV file path or text input.")
    args = add_inference_arguments(cli).parse_args()

    # Check for an output path before anything is loaded.
    produces_speech = args.task.upper() in {"S2ST", "T2ST"}
    if produces_speech and args.output_path is None:
        raise ValueError("output_path must be provided to save the generated audio")

    # `device` and `dtype` keep their names: the log line below prints them
    # through the `{name=}` f-string form.
    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda:0") if cuda_available else torch.device("cpu")
    dtype = torch.float16 if cuda_available else torch.float32
    logger.info(f"Running inference on {device=} with {dtype=}.")

    model = Translator(args.model_name, args.vocoder_name, device, dtype=dtype)

    text_generation_opts, unit_generation_opts = set_generation_opts(args)
    logger.info(f"{text_generation_opts=}")
    logger.info(f"{unit_generation_opts=}")
    logger.info(
        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
    )

    text_result, speech_result = model.predict(
        args.input,
        args.task,
        args.tgt_lang,
        src_lang=args.src_lang,
        text_generation_opts=text_generation_opts,
        unit_generation_opts=unit_generation_opts,
        unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
    )

    if speech_result is not None:
        logger.info(f"Saving translated audio in {args.tgt_lang}")
        # Cast to float32 and move to the CPU before handing the audio to
        # torchaudio for writing.
        waveform = speech_result.audio_wavs[0][0].to(torch.float32).cpu()
        torchaudio.save(
            args.output_path,
            waveform,
            sample_rate=speech_result.sample_rate,
        )

    logger.info(f"Translated text in {args.tgt_lang}: {text_result[0]}")


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
|