# evaluate.py
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import argparse
  7. import contextlib
  8. import itertools
  9. import logging
  10. import subprocess
  11. import torch
  12. import torchaudio
  13. from dataclasses import dataclass
  14. from pathlib import Path
  15. from torch import Tensor
  16. from tqdm import tqdm
  17. from typing import List, Optional, Tuple
  18. from sacrebleu.metrics import BLEU
  19. from fairseq2.data import Collater, DataPipeline, FileMapper
  20. from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
  21. from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
  22. from fairseq2.data.typing import StringLike
  23. from fairseq2.generation import SequenceGeneratorOptions
  24. from fairseq2.typing import Device, DataType
  25. from m4t_scripts.predict import add_inference_arguments, set_generation_opts
  26. from seamless_communication.models.inference import (
  27. BatchedSpeechOutput,
  28. Modality,
  29. Translator,
  30. )
  31. from seamless_communication.models.unity import load_unity_text_tokenizer
  32. logging.basicConfig(
  33. level=logging.INFO,
  34. format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
  35. )
  36. logger = logging.getLogger(__name__)
  37. @dataclass
  38. class EvalContext:
  39. task: str
  40. """String representing the task. Valid choices are
  41. "S2ST", "S2TT", "T2ST", "T2TT", "ASR"."""
  42. input_modality: Modality
  43. """The input modality of the task."""
  44. output_modality: Modality
  45. """The output modality of the task."""
  46. model_name: str
  47. """The name of the S2T UnitY model."""
  48. data_file: Path
  49. """The pathname of the test TSV data file."""
  50. audio_root_dir: Path
  51. """The pathname of the directory under which
  52. audio files are stored."""
  53. target_lang: str
  54. """The target translation language."""
  55. source_lang: Optional[str]
  56. """The source language."""
  57. batch_size: int
  58. """The batch size for model input."""
  59. device: Device
  60. """The device on which to run inference."""
  61. dtype: DataType
  62. """The data type with which to run inference."""
  63. output_path: Path
  64. """The pathname of the output directory to save
  65. the evaluation results."""
  66. ref_field: str
  67. """The reference target text field to compute
  68. the BLEU score against."""
  69. text_generation_opts: SequenceGeneratorOptions
  70. """Text generation hyperparameters."""
  71. unit_generation_opts: Optional[SequenceGeneratorOptions]
  72. """Unit generation hyperparameters, not applicable
  73. for the NAR T2U decoder."""
  74. unit_generation_ngram_filtering: bool
  75. """If True, removes consecutive repeating ngrams
  76. from the decoded unit output."""
  77. def count_lines(filename: Path) -> int:
  78. result = subprocess.run(["wc", "-l", filename], stdout=subprocess.PIPE)
  79. return int(result.stdout.decode().split()[0])
  80. def build_data_pipeline(
  81. ctx: EvalContext,
  82. text_tokenizer: TextTokenizer,
  83. ) -> DataPipeline:
  84. with open(ctx.data_file, "r") as f:
  85. header = f.readline().strip("\n").split("\t")
  86. first_example = f.readline().strip("\n").split("\t")
  87. # TODO: This will be soon auto-tuned. Right now hand-tuned for devfair.
  88. n_parallel = 4
  89. split_tsv = StrSplitter(names=header)
  90. pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(split_tsv)
  91. if ctx.input_modality == Modality.SPEECH:
  92. map_file = FileMapper(root_dir=ctx.audio_root_dir, cached_fd_count=10)
  93. pipeline_builder.map(map_file, selector="audio", num_parallel_calls=n_parallel)
  94. decode_audio = AudioDecoder(dtype=torch.float32, device=ctx.device)
  95. convert_to_fbank = WaveformToFbankConverter(
  96. num_mel_bins=80,
  97. waveform_scale=2**15,
  98. channel_last=True,
  99. standardize=True,
  100. device=ctx.device,
  101. dtype=ctx.dtype,
  102. )
  103. pipeline_builder.map(
  104. [decode_audio, convert_to_fbank],
  105. selector="audio.data",
  106. num_parallel_calls=n_parallel,
  107. )
  108. else:
  109. if "src_lang" in header:
  110. source_lang = first_example[header.index("src_lang")]
  111. ctx.source_lang = source_lang
  112. elif ctx.source_lang is None:
  113. raise ValueError(
  114. (
  115. "'src_lang' is missing in the data_file"
  116. "header and in the arguments."
  117. )
  118. )
  119. token_encoder = text_tokenizer.create_encoder(
  120. task="translation", lang=source_lang, mode="source", device=ctx.device
  121. )
  122. pipeline_builder.map(
  123. [token_encoder],
  124. selector="src_text",
  125. num_parallel_calls=n_parallel,
  126. )
  127. pipeline_builder.bucket(bucket_size=ctx.batch_size)
  128. collate = Collater(pad_value=0, pad_to_multiple=1)
  129. pipeline_builder.map(collate, num_parallel_calls=n_parallel)
  130. pipeline_builder.prefetch(4)
  131. return pipeline_builder.and_return()
  132. def adjust_output_for_corrupted_inputs(
  133. valid_sequences: Tensor,
  134. text_output: List[StringLike],
  135. speech_output: Optional[BatchedSpeechOutput],
  136. ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
  137. adjusted_text_output: List[StringLike] = []
  138. adjusted_speech_output: Optional[BatchedSpeechOutput] = None
  139. if speech_output is not None:
  140. assert (
  141. len(text_output)
  142. == len(speech_output.units)
  143. == len(speech_output.audio_wavs)
  144. )
  145. adjusted_speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
  146. batch_counter = 0
  147. for is_valid in valid_sequences:
  148. if is_valid:
  149. adjusted_text_output.append(text_output[batch_counter])
  150. if speech_output is not None:
  151. assert adjusted_speech_output is not None
  152. adjusted_speech_output.units.append(speech_output.units[batch_counter])
  153. adjusted_speech_output.audio_wavs.append(
  154. speech_output.audio_wavs[batch_counter]
  155. )
  156. batch_counter += 1
  157. else:
  158. # For the corrupted inputs, we save the following dummy outputs:
  159. # empty string for text, empty list for units, 1 second of silence for audio.
  160. adjusted_text_output.append("")
  161. if adjusted_speech_output is not None:
  162. sample_rate = adjusted_speech_output.sample_rate
  163. adjusted_speech_output.units.append([])
  164. adjusted_speech_output.audio_wavs.append(
  165. torch.zeros(sample_rate).unsqueeze(0).unsqueeze(0)
  166. )
  167. return (
  168. adjusted_text_output,
  169. adjusted_speech_output,
  170. )
  171. def run_eval(
  172. translator: Translator, text_tokenizer: TextTokenizer, ctx: EvalContext
  173. ) -> None:
  174. pipeline = build_data_pipeline(ctx, text_tokenizer)
  175. total_steps = count_lines(ctx.data_file) - 1
  176. progress_bar = tqdm(total=total_steps)
  177. output_path = ctx.output_path / ctx.data_file.stem
  178. output_path.mkdir(parents=True, exist_ok=True)
  179. if ctx.output_modality == Modality.SPEECH:
  180. waveforms_dir = output_path / f"waveform_{ctx.data_file.stem}"
  181. waveforms_dir.mkdir(parents=True, exist_ok=True)
  182. hyps = []
  183. refs = []
  184. with open(
  185. output_path / f"text_output-{ctx.data_file.stem}.txt", "w"
  186. ) as hyp_file, open(
  187. output_path / f"unit_output-{ctx.data_file.stem}.txt", "w"
  188. ) if ctx.output_modality == Modality.SPEECH else contextlib.nullcontext(
  189. itertools.repeat(None)
  190. ) as unit_file:
  191. sample_id = 0
  192. for example in pipeline:
  193. valid_sequences: Optional[Tensor] = None
  194. if ctx.input_modality == Modality.SPEECH:
  195. src = example["audio"]["data"]["fbank"]
  196. # Skip corrupted audio tensors.
  197. valid_sequences = ~torch.any(
  198. torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
  199. )
  200. if not valid_sequences.all():
  201. logger.warning(
  202. f"Sample IDs {sample_id} to {sample_id + ctx.batch_size} has some corrupted input."
  203. )
  204. src["seqs"] = src["seqs"][valid_sequences]
  205. src["seq_lens"] = src["seq_lens"][valid_sequences]
  206. else:
  207. src = example["src_text"]
  208. # Skip performing inference when the input is entirely corrupted.
  209. if src["seqs"].numel() > 0:
  210. (text_output, speech_output,) = translator.predict(
  211. src,
  212. ctx.task,
  213. ctx.target_lang,
  214. src_lang=ctx.source_lang,
  215. text_generation_opts=ctx.text_generation_opts,
  216. unit_generation_opts=ctx.unit_generation_opts,
  217. unit_generation_ngram_filtering=ctx.unit_generation_ngram_filtering,
  218. )
  219. else:
  220. text_output = []
  221. if ctx.output_modality == Modality.SPEECH:
  222. speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
  223. else:
  224. speech_output = None
  225. if valid_sequences is not None and not valid_sequences.all():
  226. (text_output, speech_output,) = adjust_output_for_corrupted_inputs(
  227. valid_sequences,
  228. text_output,
  229. speech_output,
  230. )
  231. hyps += [str(s) for s in text_output]
  232. refs += [str(s) for s in example[ctx.ref_field]]
  233. for i in range(len(text_output)):
  234. t = text_output[i]
  235. hyp_file.write(f"{t}\n")
  236. if ctx.output_modality == Modality.SPEECH:
  237. assert speech_output is not None
  238. u = speech_output.units[i]
  239. str_units = [str(i) for i in u]
  240. unit_file.write(" ".join(str_units) + "\n")
  241. torchaudio.save(
  242. waveforms_dir / f"{sample_id}_pred.wav",
  243. speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
  244. sample_rate=speech_output.sample_rate,
  245. )
  246. sample_id += 1
  247. progress_bar.update(1)
  248. progress_bar.close()
  249. logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
  250. assert len(hyps) == len(refs)
  251. if len(hyps) > 0:
  252. if ctx.target_lang in ("cmn", "jpn", "lao", "mya", "tha"):
  253. tokenizer = "char"
  254. else:
  255. tokenizer = "13a"
  256. bleu = BLEU(tokenize=tokenizer)
  257. score = bleu.corpus_score(hyps, [refs])
  258. bleu_filename = output_path / f"{ctx.data_file.stem}_text_output_bleu.json"
  259. with open(bleu_filename, "w") as f:
  260. f.write(score.format(signature=str(bleu.get_signature()), is_json=True))
  261. logger.info(score.format(signature=bleu.get_signature()))
  262. def main():
  263. parser = argparse.ArgumentParser(
  264. description="M4T evaluation for tasks supported by Translator."
  265. )
  266. parser.add_argument("data_file", type=str, help="Data file (.tsv) to be evaluated.")
  267. parser = add_inference_arguments(parser)
  268. parser.add_argument(
  269. "--batch_size",
  270. type=int,
  271. help="Inference batch size.",
  272. default=4,
  273. )
  274. parser.add_argument(
  275. "--audio_root_dir",
  276. type=str,
  277. help="Root directory for the audio filenames in the data file.",
  278. required=True,
  279. )
  280. parser.add_argument(
  281. "--ref_field",
  282. type=str,
  283. help="Reference target text field to compute the BLEU score against.",
  284. default="tgt_text",
  285. )
  286. args = parser.parse_args()
  287. input_modality, output_modality = Translator.get_modalities_from_task_str(args.task)
  288. if torch.cuda.is_available():
  289. device = torch.device("cuda:0")
  290. dtype = torch.float16
  291. else:
  292. device = torch.device("cpu")
  293. dtype = torch.float32
  294. text_tokenizer = load_unity_text_tokenizer(args.model_name)
  295. # TODO: Avoid loading the T2U model, vocoder when the output
  296. # modality is text.
  297. translator = Translator(
  298. args.model_name,
  299. args.vocoder_name,
  300. device,
  301. text_tokenizer=text_tokenizer,
  302. dtype=dtype,
  303. )
  304. text_generation_opts, unit_generation_opts = set_generation_opts(args)
  305. logger.info(f"{text_generation_opts=}")
  306. logger.info(f"{unit_generation_opts=}")
  307. logger.info(
  308. f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
  309. )
  310. # fmt: off
  311. ctx = EvalContext(
  312. task=args.task,
  313. input_modality=input_modality,
  314. output_modality=output_modality,
  315. model_name=args.model_name,
  316. data_file=Path(args.data_file),
  317. audio_root_dir=Path(args.audio_root_dir),
  318. target_lang=args.tgt_lang,
  319. source_lang=args.src_lang,
  320. batch_size=args.batch_size,
  321. device=device,
  322. dtype=dtype,
  323. ref_field=args.ref_field,
  324. text_generation_opts=text_generation_opts,
  325. unit_generation_opts=unit_generation_opts,
  326. unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
  327. output_path=Path(args.output_path),
  328. )
  329. # fmt: on
  330. logger.info(f"Running inference on {device=} with {dtype=}, {ctx.batch_size=}.")
  331. run_eval(translator, text_tokenizer, ctx)
  332. if __name__ == "__main__":
  333. main()