# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
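
# Illustrative invocation of this script (asset and path names below are only
# examples and depend on the gated SeamlessExpressive assets available in your
# environment):
#
#   python <this_script>.py <data.tsv> \
#       --gated-model-dir <gated_assets_dir> \
#       --model_name seamless_expressivity --vocoder_name vocoder_pretssel \
#       --tgt_lang <tgt_lang> --audio_root_dir <audio_root> --output_path <out_dir>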
import argparse
import contextlib
import logging
from argparse import Namespace
from pathlib import Path
from typing import Optional

import pandas as pd
import torch
import torchaudio
from fairseq2.data import Collater, DataPipeline, FileMapper
from fairseq2.data.audio import (
    AudioDecoder,
    WaveformToFbankConverter,
    WaveformToFbankOutput,
)
from fairseq2.data.text import StrSplitter, read_text
from fairseq2.typing import DataType, Device
from torch import Tensor
from tqdm import tqdm

from seamless_communication.cli.m4t.evaluate.evaluate import (
    adjust_output_for_corrupted_inputs,
    count_lines,
)
from seamless_communication.cli.m4t.predict import (
    add_inference_arguments,
    set_generation_opts,
)
from seamless_communication.inference import BatchedSpeechOutput, Translator
from seamless_communication.inference.pretssel_generator import (
    PretsselGenerator,
)
from seamless_communication.models.unity import (
    load_gcmvn_stats,
    load_unity_unit_tokenizer,
)
from seamless_communication.store import add_gated_assets

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

logger = logging.getLogger(__name__)


def build_data_pipeline(
    args: Namespace,
    device: Device,
    dtype: DataType,
    gcmvn_mean: Tensor,
    gcmvn_std: Tensor,
) -> DataPipeline:
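    """Build a fairseq2 pipeline that reads the TSV manifest, decodes each audio
    file under ``args.audio_root_dir``, extracts 80-dim fbank features, applies
    per-utterance and global (gcmvn) normalization, and collates padded batches
    of ``args.batch_size`` examples."""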
    with open(args.data_file, "r") as f:
        header = f.readline().strip("\n").split("\t")
        assert (
            args.audio_field in header
        ), f"Input file does not contain {args.audio_field} field"

    n_parallel = 4

    split_tsv = StrSplitter(names=header)

    pipeline_builder = read_text(args.data_file, rtrim=True).skip(1).map(split_tsv)

    assert args.audio_root_dir is not None
    map_file = FileMapper(root_dir=args.audio_root_dir, cached_fd_count=10)

    pipeline_builder.map(
        map_file, selector=args.audio_field, num_parallel_calls=n_parallel
    )

    decode_audio = AudioDecoder(dtype=torch.float32, device=device)

    convert_to_fbank = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=False,
        device=device,
        dtype=dtype,
    )
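
    # Two normalized views of the same fbank: per-utterance mean/variance
    # normalization for the translator input, and gcmvn (global) normalization
    # for the PRETSSEL prosody encoder input.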
    def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
        fbank = data["fbank"]
        std, mean = torch.std_mean(fbank, dim=0)
        data["fbank"] = fbank.subtract(mean).divide(std)
        data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
        return data

    pipeline_builder.map(
        [decode_audio, convert_to_fbank, normalize_fbank],
        selector=f"{args.audio_field}.data",
        num_parallel_calls=n_parallel,
    )

    pipeline_builder.bucket(bucket_size=args.batch_size)

    collate = Collater(pad_value=0, pad_to_multiple=1)
    pipeline_builder.map(collate, num_parallel_calls=n_parallel)

    pipeline_builder.prefetch(4)

    return pipeline_builder.and_return()


def main() -> None:
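    """Run SeamlessExpressive S2ST inference over a TSV manifest and write
    hypothesis text, units, waveforms, and (optionally) a results TSV."""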
    parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference")
    parser.add_argument(
        "data_file", type=Path, help="Data file (.tsv) to be evaluated."
    )

    parser = add_inference_arguments(parser)
    parser.add_argument(
        "--gated-model-dir",
        type=Path,
        required=False,
        help="SeamlessExpressive model directory.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Inference batch size.",
        default=4,
    )
    parser.add_argument(
        "--audio_root_dir",
        type=Path,
        help="Root directory for the audio filenames in the data file.",
        default="",
    )
    parser.add_argument(
        "--audio_field",
        type=str,
        help="Field that includes the input audio file paths.",
        default="src_audio",
    )
    parser.add_argument(
        "--ref_field",
        type=str,
        help="Reference target text field to compute the BLEU score against.",
        default=None,
    )
    parser.add_argument(
        "--duration_factor",
        type=float,
        help="The duration factor for the NAR T2U model.",
        default=1.0,
    )
    parser.add_argument(
        "--output_result_tsv",
        # Note: argparse's type=bool treats any non-empty string as True, so
        # the flag value is parsed explicitly here.
        type=lambda s: str(s).lower() in ("true", "1", "yes"),
        help="Whether to output results in TSV format (for full-blown evaluation).",
        default=True,
    )
    args = parser.parse_args()

    if args.gated_model_dir:
        add_gated_assets(args.gated_model_dir)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        dtype = torch.float16
    else:
        device = torch.device("cpu")
        dtype = torch.float32

    unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
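
    # Global mean/variance (gcmvn) fbank stats associated with the vocoder
    # card; applied in normalize_fbank() to build the prosody encoder input.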
    _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
    gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
    gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)

    pipeline = build_data_pipeline(args, device, dtype, gcmvn_mean, gcmvn_std)
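
    # No vocoder is attached to the Translator; waveform synthesis happens
    # further below via PretsselGenerator from the predicted units and the
    # prosody encoder input.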
    translator = Translator(
        args.model_name,
        vocoder_name_or_card=None,
        device=device,
        dtype=dtype,
    )

    text_generation_opts, unit_generation_opts = set_generation_opts(args)

    logger.info(f"{text_generation_opts=}")
    logger.info(f"{unit_generation_opts=}")
    logger.info(
        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
    )

    pretssel_generator = PretsselGenerator(
        args.vocoder_name,
        vocab_info=unit_tokenizer.vocab_info,
        device=device,
        dtype=dtype,
    )
    total_steps = count_lines(args.data_file) - 1
    progress_bar = tqdm(total=total_steps)

    output_path = args.output_path / args.data_file.stem
    output_path.mkdir(parents=True, exist_ok=True)

    waveforms_dir = output_path / "waveform"
    waveforms_dir.mkdir(parents=True, exist_ok=True)

    hyps = []
    refs = []
    audio_hyps = []

    with contextlib.ExitStack() as stack:
        hyp_file = stack.enter_context(
            open(output_path / f"text_output-{args.data_file.stem}.txt", "w")
        )
        unit_file = stack.enter_context(
            open(output_path / f"unit_output-{args.data_file.stem}.txt", "w")
        )

        sample_id = 0
        for example in pipeline:
            valid_sequences: Optional[Tensor] = None
            src = example[args.audio_field]["data"]["fbank"]
            # Skip corrupted audio tensors.
            valid_sequences = ~torch.any(
                torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
            )
            if not valid_sequences.all():
                logger.warning(
                    f"Sample IDs {sample_id} to {sample_id + args.batch_size} have some corrupted input."
                )
- src["seqs"] = src["seqs"][valid_sequences]
- src["seq_lens"] = src["seq_lens"][valid_sequences]
- # Skip performing inference when the input is entirely corrupted.
- if src["seqs"].numel() > 0:
- prosody_encoder_input = example[args.audio_field]["data"]["gcmvn_fbank"]
- text_output, unit_output = translator.predict(
- src,
- "s2st",
- args.tgt_lang,
- src_lang=args.src_lang,
- text_generation_opts=text_generation_opts,
- unit_generation_opts=unit_generation_opts,
- unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
- duration_factor=args.duration_factor,
- prosody_encoder_input=prosody_encoder_input,
- )
- assert unit_output is not None
- speech_output = pretssel_generator.predict(
- unit_output.units,
- tgt_lang=args.tgt_lang,
- prosody_encoder_input=prosody_encoder_input,
- )
- else:
- text_output = []
- speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
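
            # Re-insert placeholder outputs for the corrupted rows so that the
            # outputs stay aligned with the input TSV order.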
            if valid_sequences is not None and not valid_sequences.all():
                text_output, speech_output = adjust_output_for_corrupted_inputs(  # type: ignore[assignment]
                    valid_sequences,
                    text_output,
                    speech_output,
                )
            hyps += [str(s) for s in text_output]
            if args.ref_field is not None and args.ref_field in example:
                refs += [str(s) for s in example[args.ref_field]]

            for i in range(len(text_output)):
                t = text_output[i]
                idx = str(example["id"][i])

                hyp_file.write(f"{t}\n")

                u = speech_output.units[i]
                str_units = [str(unit) for unit in u]
                unit_file.write(" ".join(str_units) + "\n")

                torchaudio.save(
                    waveforms_dir / f"{idx}_pred.wav",
                    speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
                    sample_rate=speech_output.sample_rate,
                )
                audio_hyps.append((waveforms_dir / f"{idx}_pred.wav").as_posix())

                sample_id += 1
                progress_bar.update(1)

    progress_bar.close()
    logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
    if args.output_result_tsv:
        output_tsv_file = output_path / f"generate-{args.data_file.stem}.tsv"
        output_tsv = pd.read_csv(args.data_file, quoting=3, sep="\t")

        text_out = []
        with open(hyp_file.name) as file:
            for line in file:
                text_out.append(line.strip())

        unit_out = []
        with open(unit_file.name) as file:
            for line in file:
                unit_out.append(line.strip())

        output_tsv["hypo_audio"] = audio_hyps
        output_tsv["s2t_out"] = text_out
        output_tsv["orig_unit"] = unit_out
        output_tsv.to_csv(output_tsv_file, quoting=3, sep="\t", index=False)
        logger.info(f"Output results in {output_tsv_file}")


if __name__ == "__main__":
    main()