evaluate.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import contextlib
import logging
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torchaudio
from fairseq2.data import Collater, CString, DataPipeline, FileMapper
from fairseq2.data.audio import (
    AudioDecoder,
    WaveformToFbankConverter,
    WaveformToFbankOutput,
)
from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
from fairseq2.data.typing import PathLike, StringLike
from fairseq2.generation import SequenceGeneratorOptions
from fairseq2.typing import DataType, Device
from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
from torch import Tensor
from tqdm import tqdm

from seamless_communication.cli.m4t.predict import (
    add_inference_arguments,
    set_generation_opts,
)
from seamless_communication.inference import BatchedSpeechOutput, Modality, Translator
from seamless_communication.models.unity import load_unity_text_tokenizer

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

logger = logging.getLogger(__name__)


@dataclass
class EvalContext:
    task: str
    """String representing the task. Valid choices are
    "S2ST", "S2TT", "T2ST", "T2TT", "ASR"."""

    output_modality: Modality
    """The output modality of the task."""

    model_name: str
    """The name of the S2T UnitY model."""

    data_file: Path
    """The pathname of the test TSV data file."""

    audio_root_dir: Optional[Path]
    """The pathname of the directory under which
    audio files are stored."""

    target_lang: str
    """The target translation language."""

    source_lang: Optional[str]
    """The source language."""

    batch_size: int
    """The batch size for model input."""

    device: Device
    """The device on which to run inference."""

    dtype: DataType
    """The data type with which to run inference."""

    output_path: Path
    """The pathname of the output directory to save
    the evaluation results."""

    ref_field: str
    """The reference target text field to compute
    the BLEU score against."""

    text_generation_opts: SequenceGeneratorOptions
    """Text generation hyperparameters."""

    unit_generation_opts: Optional[SequenceGeneratorOptions]
    """Unit generation hyperparameters, not applicable
    for the NAR T2U decoder."""

    unit_generation_ngram_filtering: bool
    """If True, removes consecutive repeating ngrams
    from the decoded unit output."""
    gcmvn_stats: Optional[PathLike] = None
    """The pathname of the gcmvn (global CMVN) stats file,
    used by the prosody encoder."""


def count_lines(filename: Path) -> int:
    result = subprocess.run(["wc", "-l", filename], stdout=subprocess.PIPE)
    return int(result.stdout.decode().split()[0])


def build_data_pipeline(
    ctx: EvalContext,
    text_tokenizer: TextTokenizer,
) -> DataPipeline:
    with open(ctx.data_file, "r") as f:
        header = f.readline().strip("\n").split("\t")

    # TODO: This will be soon auto-tuned. Right now hand-tuned for devfair.
    n_parallel = 4

    split_tsv = StrSplitter(names=header)

    if ctx.gcmvn_stats is not None:
        if isinstance(ctx.gcmvn_stats, CString):
            ctx.gcmvn_stats = str(ctx.gcmvn_stats)
        gcmvn_stats: Dict[str, np.ndarray] = np.load(ctx.gcmvn_stats)  # type: ignore[type-arg]
        gcmvn_mean = torch.tensor(
            gcmvn_stats["mean"], device=ctx.device, dtype=ctx.dtype
        )
        gcmvn_std = torch.tensor(gcmvn_stats["std"], device=ctx.device, dtype=ctx.dtype)

    pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(split_tsv)

    assert ctx.audio_root_dir is not None

    map_file = FileMapper(root_dir=ctx.audio_root_dir, cached_fd_count=10)

    pipeline_builder.map(map_file, selector="audio", num_parallel_calls=n_parallel)

    decode_audio = AudioDecoder(dtype=torch.float32, device=ctx.device)

    convert_to_fbank = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=False,
        device=ctx.device,
        dtype=ctx.dtype,
    )

    def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
        fbank = data["fbank"]
        std, mean = torch.std_mean(fbank, dim=0)
        data["fbank"] = fbank.subtract(mean).divide(std)
        if ctx.gcmvn_stats is not None:
            data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
        return data

    pipeline_builder.map(
        [decode_audio, convert_to_fbank, normalize_fbank],
        selector="audio.data",
        num_parallel_calls=n_parallel,
    )

    pipeline_builder.bucket(bucket_size=ctx.batch_size)

    collate = Collater(pad_value=0, pad_to_multiple=1)

    pipeline_builder.map(collate, num_parallel_calls=n_parallel)

    pipeline_builder.prefetch(4)

    return pipeline_builder.and_return()
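
# Note on the batch layout: each element yielded by the pipeline above is a nested
# dict, and run_eval below only relies on the fields sketched here. This is a rough
# sketch inferred from the accesses in run_eval (the exact layout is produced by
# fairseq2's FileMapper, WaveformToFbankConverter, and Collater), not a guaranteed
# schema:
#
#   example["audio"]["data"]["fbank"] = {
#       "seqs":     Tensor of shape [batch, max_frames, 80],  # padded fbank features
#       "seq_lens": Tensor of shape [batch],                  # lengths before padding
#   }
#   example["audio"]["data"].get("gcmvn_fbank")  # present only when gcmvn_stats is set
#   example[ctx.ref_field]                       # per-row reference target strings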


def adjust_output_for_corrupted_inputs(
    valid_sequences: Tensor,
    text_output: List[StringLike],
    speech_output: Optional[BatchedSpeechOutput],
) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
    adjusted_text_output: List[StringLike] = []
    adjusted_speech_output: Optional[BatchedSpeechOutput] = None

    if speech_output is not None:
        assert (
            len(text_output)
            == len(speech_output.units)
            == len(speech_output.audio_wavs)
        )
        adjusted_speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])

    batch_counter = 0
    for is_valid in valid_sequences:
        if is_valid:
            adjusted_text_output.append(text_output[batch_counter])
            if speech_output is not None:
                assert adjusted_speech_output is not None
                adjusted_speech_output.units.append(speech_output.units[batch_counter])
                adjusted_speech_output.audio_wavs.append(
                    speech_output.audio_wavs[batch_counter]
                )
            batch_counter += 1
        else:
            # For the corrupted inputs, we save the following dummy outputs:
            # empty string for text, empty list for units, 1 second of silence for audio.
            adjusted_text_output.append("")
            if adjusted_speech_output is not None:
                sample_rate = adjusted_speech_output.sample_rate
                adjusted_speech_output.units.append([])
                adjusted_speech_output.audio_wavs.append(
                    torch.zeros(sample_rate).unsqueeze(0).unsqueeze(0)
                )

    return (
        adjusted_text_output,
        adjusted_speech_output,
    )


def run_eval(
    translator: Translator, text_tokenizer: TextTokenizer, ctx: EvalContext
) -> None:
    pipeline = build_data_pipeline(ctx, text_tokenizer)

    total_steps = count_lines(ctx.data_file) - 1
    progress_bar = tqdm(total=total_steps)

    output_path = ctx.output_path / ctx.data_file.stem
    output_path.mkdir(parents=True, exist_ok=True)

    if ctx.output_modality == Modality.SPEECH:
        waveforms_dir = output_path / f"waveform_{ctx.data_file.stem}"
        waveforms_dir.mkdir(parents=True, exist_ok=True)

    hyps = []
    refs = []

    with contextlib.ExitStack() as stack:
        hyp_file = stack.enter_context(
            open(output_path / f"text_output-{ctx.data_file.stem}.txt", "w")
        )
        if ctx.output_modality == Modality.SPEECH:
            unit_file = stack.enter_context(
                open(output_path / f"unit_output-{ctx.data_file.stem}.txt", "w")
            )

        sample_id = 0
        for example in pipeline:
            valid_sequences: Optional[Tensor] = None
            src = example["audio"]["data"]["fbank"]
            # Skip corrupted audio tensors.
            valid_sequences = ~torch.any(
                torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
            )
            if not valid_sequences.all():
                logger.warning(
                    f"Sample IDs {sample_id} to {sample_id + ctx.batch_size} contain some corrupted input."
                )
                src["seqs"] = src["seqs"][valid_sequences]
                src["seq_lens"] = src["seq_lens"][valid_sequences]

            # Skip performing inference when the input is entirely corrupted.
            if src["seqs"].numel() > 0:
                (
                    text_output,
                    speech_output,
                ) = translator.predict(
                    src,
                    ctx.task,
                    ctx.target_lang,
                    src_lang=ctx.source_lang,
                    text_generation_opts=ctx.text_generation_opts,
                    unit_generation_opts=ctx.unit_generation_opts,
                    unit_generation_ngram_filtering=ctx.unit_generation_ngram_filtering,
                    gcmvn_fbank=example["audio"]["data"].get("gcmvn_fbank", None),
                )
            else:
                text_output = []
                if ctx.output_modality == Modality.SPEECH:
                    speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
                else:
                    speech_output = None

            if valid_sequences is not None and not valid_sequences.all():
                (
                    text_output,
                    speech_output,
                ) = adjust_output_for_corrupted_inputs(
                    valid_sequences,
                    text_output,
                    speech_output,
                )

            hyps += [str(s) for s in text_output]
            refs += [str(s) for s in example[ctx.ref_field]]

            for i in range(len(text_output)):
                t = text_output[i]
                hyp_file.write(f"{t}\n")

                if ctx.output_modality == Modality.SPEECH:
                    assert speech_output is not None
                    u = speech_output.units[i]
                    str_units = [str(i) for i in u]
                    unit_file.write(" ".join(str_units) + "\n")
                    torchaudio.save(
                        waveforms_dir / f"{sample_id}_pred.wav",
                        speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
                        sample_rate=speech_output.sample_rate,
                    )

                sample_id += 1
                progress_bar.update(1)

    progress_bar.close()
    logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")

    assert len(hyps) == len(refs)
    if len(hyps) > 0:
        if ctx.target_lang in ("cmn", "jpn", "lao", "mya", "tha"):
            tokenizer = "char"
        else:
            tokenizer = "13a"

        bleu = BLEU(tokenize=tokenizer)
        score = bleu.corpus_score(hyps, [refs])
        bleu_filename = output_path / f"{ctx.data_file.stem}_text_output_bleu.json"
        with open(bleu_filename, "w") as f:
            f.write(score.format(signature=str(bleu.get_signature()), is_json=True))
        logger.info(score.format(signature=bleu.get_signature()))


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Expressivity evaluation for tasks supported by Translator."
    )
    parser.add_argument("data_file", type=str, help="Data file (.tsv) to be evaluated.")

    parser = add_inference_arguments(parser)
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Inference batch size.",
        default=4,
    )
    parser.add_argument(
        "--audio_root_dir",
        type=str,
        help="Root directory for the audio filenames in the data file.",
        default="",
    )
    parser.add_argument(
        "--ref_field",
        type=str,
        help="Reference target text field to compute the BLEU score against.",
        default="tgt_text",
    )
    parser.add_argument(
        "--gcmvn_stats",
        type=str,
        help="Path to the gcmvn fbank stats. If provided, the DataPipeline keeps an additional copy of gcmvn fbank features (for the P2V encoder).",
        default=None,
    )
    args = parser.parse_args()

    input_modality, output_modality = Translator.get_modalities_from_task_str(args.task)

    if input_modality == Modality.SPEECH and not Path(args.audio_root_dir).exists():
        raise ValueError(
            f"Invalid audio_root_dir: {args.audio_root_dir} for speech input."
        )

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        dtype = torch.float32
    else:
        device = torch.device("cpu")
        dtype = torch.float32

    text_tokenizer = load_unity_text_tokenizer(args.model_name)

    # TODO: Avoid loading the T2U model, vocoder when the output
    # modality is text.
    translator = Translator(
        args.model_name,
        args.vocoder_name,
        device,
        text_tokenizer=text_tokenizer,
        dtype=dtype,
    )

    text_generation_opts, unit_generation_opts = set_generation_opts(args)

    logger.info(f"{text_generation_opts=}")
    logger.info(f"{unit_generation_opts=}")
    logger.info(
        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
    )

    # fmt: off
    ctx = EvalContext(
        task=args.task,
        output_modality=output_modality,
        model_name=args.model_name,
        data_file=Path(args.data_file),
        audio_root_dir=Path(args.audio_root_dir),
        target_lang=args.tgt_lang,
        source_lang=args.src_lang,
        batch_size=args.batch_size,
        device=device,
        dtype=dtype,
        ref_field=args.ref_field,
        text_generation_opts=text_generation_opts,
        unit_generation_opts=unit_generation_opts,
        unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
        output_path=Path(args.output_path),
        gcmvn_stats=args.gcmvn_stats,
    )
    # fmt: on

    logger.info(f"Running inference on {device=} with {dtype=}, {ctx.batch_size=}.")

    run_eval(translator, text_tokenizer, ctx)


if __name__ == "__main__":
    main()
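
# Example invocation (a sketch, not an official command line). The flags below are
# the attributes that main() reads from args; --task, --tgt_lang, --src_lang,
# --model_name, --vocoder_name, and --output_path are presumably registered by
# add_inference_arguments(), so check that helper for the exact names and defaults.
# The model/vocoder names and paths here are placeholder values:
#
#   python evaluate.py data/test_manifest.tsv \
#       --task S2ST \
#       --tgt_lang spa \
#       --model_name seamless_expressivity \
#       --vocoder_name vocoder_pretssel \
#       --audio_root_dir /path/to/audio \
#       --output_path ./eval_results \
#       --batch_size 4 \
#       --gcmvn_stats /path/to/gcmvn_stats.npy
#
# Outputs land under <output_path>/<data_file stem>/: text_output-<stem>.txt, and for
# speech tasks unit_output-<stem>.txt plus per-sample waveforms, alongside a corpus
# BLEU JSON computed against --ref_field.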