evaluate.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
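"""SeamlessExpressive batch inference / evaluation script.

Reads a TSV manifest of source audio, runs expressive speech-to-speech
translation (text and discrete units via ``Translator``, waveform synthesis
via ``PretsselGenerator``), and writes per-sample text hypotheses, unit
sequences and predicted waveforms; optionally a merged result TSV as well.
"""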
import argparse
import contextlib
import logging
from argparse import Namespace
from pathlib import Path
from typing import Optional

import pandas as pd
import torch
import torchaudio
from fairseq2.data import Collater, DataPipeline, FileMapper
from fairseq2.data.audio import (
    AudioDecoder,
    WaveformToFbankConverter,
    WaveformToFbankOutput,
)
from fairseq2.data.text import StrSplitter, read_text
from fairseq2.typing import DataType, Device
from torch import Tensor
from tqdm import tqdm

from seamless_communication.cli.m4t.evaluate.evaluate import (
    adjust_output_for_corrupted_inputs,
    count_lines,
)
from seamless_communication.cli.m4t.predict import (
    add_inference_arguments,
    set_generation_opts,
)
from seamless_communication.inference.pretssel_generator import (
    PretsselGenerator,
)
from seamless_communication.inference import BatchedSpeechOutput, Translator
from seamless_communication.models.unity import (
    load_gcmvn_stats,
    load_unity_unit_tokenizer,
)
from seamless_communication.store import add_gated_assets

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

logger = logging.getLogger(__name__)


def build_data_pipeline(
    args: Namespace,
    device: Device,
    dtype: DataType,
    gcmvn_mean: Tensor,
    gcmvn_std: Tensor,
) -> DataPipeline:
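    """Build the fairseq2 data pipeline over the evaluation TSV.

    Maps each row's audio path to a decoded waveform, converts it to
    80-bin mel filterbank features, applies per-utterance and global (gcmvn)
    mean/variance normalization, then buckets rows into batches of
    ``args.batch_size``, collates them with padding, and prefetches.
    """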
    with open(args.data_file, "r") as f:
        header = f.readline().strip("\n").split("\t")
        assert (
            args.audio_field in header
        ), f"Input file does not contain {args.audio_field} field"

    n_parallel = 4

    split_tsv = StrSplitter(names=header)

    pipeline_builder = read_text(args.data_file, rtrim=True).skip(1).map(split_tsv)

    assert args.audio_root_dir is not None
    map_file = FileMapper(root_dir=args.audio_root_dir, cached_fd_count=10)

    pipeline_builder.map(
        map_file, selector=args.audio_field, num_parallel_calls=n_parallel
    )

    decode_audio = AudioDecoder(dtype=torch.float32, device=device)

    convert_to_fbank = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=False,
        device=device,
        dtype=dtype,
    )

    def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
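        """Apply per-utterance CMVN to "fbank" and add a globally normalized
        "gcmvn_fbank" copy (used downstream as the prosody encoder input)."""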
        fbank = data["fbank"]
        std, mean = torch.std_mean(fbank, dim=0)
        data["fbank"] = fbank.subtract(mean).divide(std)
        data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
        return data

    pipeline_builder.map(
        [decode_audio, convert_to_fbank, normalize_fbank],
        selector=f"{args.audio_field}.data",
        num_parallel_calls=n_parallel,
    )

    pipeline_builder.bucket(bucket_size=args.batch_size)

    collate = Collater(pad_value=0, pad_to_multiple=1)

    pipeline_builder.map(collate, num_parallel_calls=n_parallel)

    pipeline_builder.prefetch(4)

    return pipeline_builder.and_return()
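

# After collation, each batch yielded by the pipeline roughly looks like:
#
#   {
#       "id": [...],                      # sample ids from the TSV
#       <audio_field>: {
#           "data": {
#               "fbank":       {"seqs": Tensor, "seq_lens": Tensor},  # per-utterance CMVN
#               "gcmvn_fbank": {"seqs": Tensor, "seq_lens": Tensor},  # global CMVN
#           },
#       },
#       ...                               # remaining TSV columns (e.g. the reference field)
#   }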


def main() -> None:
    parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference")
    parser.add_argument(
        "data_file", type=Path, help="Data file (.tsv) to be evaluated."
    )

    parser = add_inference_arguments(parser)
    parser.add_argument(
        "--gated-model-dir",
        type=Path,
        required=False,
        help="SeamlessExpressive model directory.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Inference batch size.",
        default=4,
    )
    parser.add_argument(
        "--audio_root_dir",
        type=Path,
        help="Root directory for the audio filenames in the data file.",
        default="",
    )
    parser.add_argument(
        "--audio_field",
        type=str,
        help="Field that includes the input audio file paths.",
        default="src_audio",
    )
    parser.add_argument(
        "--ref_field",
        type=str,
        help="Reference target text field to compute the BLEU score against.",
        default=None,
    )
    parser.add_argument(
        "--duration_factor",
        type=float,
        help="The duration factor for NAR T2U model.",
        default=1.0,
    )
    parser.add_argument(
        "--output_result_tsv",
        # Note: argparse's type=bool treats any non-empty string (including
        # "False") as True; only an empty value parses as False.
        type=bool,
        help="Whether to output results in tsv format (for full-blown evaluation).",
        default=True,
    )
    args = parser.parse_args()

    if args.gated_model_dir:
        add_gated_assets(args.gated_model_dir)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        dtype = torch.float16
    else:
        device = torch.device("cpu")
        dtype = torch.float32

    unit_tokenizer = load_unity_unit_tokenizer(args.model_name)

    _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
    gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
    gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)

    pipeline = build_data_pipeline(args, device, dtype, gcmvn_mean, gcmvn_std)
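
    # The Translator is created without a vocoder (vocoder_name_or_card=None):
    # waveform synthesis is handled separately by the PretsselGenerator below,
    # which conditions on the globally normalized (gcmvn) features.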
    translator = Translator(
        args.model_name,
        vocoder_name_or_card=None,
        device=device,
        dtype=dtype,
    )

    text_generation_opts, unit_generation_opts = set_generation_opts(args)

    logger.info(f"{text_generation_opts=}")
    logger.info(f"{unit_generation_opts=}")
    logger.info(
        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
    )

    pretssel_generator = PretsselGenerator(
        args.vocoder_name,
        vocab_info=unit_tokenizer.vocab_info,
        device=device,
        dtype=dtype,
    )
    total_steps = count_lines(args.data_file) - 1
    progress_bar = tqdm(total=total_steps)

    output_path = args.output_path / args.data_file.stem
    output_path.mkdir(parents=True, exist_ok=True)

    waveforms_dir = output_path / "waveform"
    waveforms_dir.mkdir(parents=True, exist_ok=True)

    hyps = []
    refs = []
    audio_hyps = []

    with contextlib.ExitStack() as stack:
        hyp_file = stack.enter_context(
            open(output_path / f"text_output-{args.data_file.stem}.txt", "w")
        )
        unit_file = stack.enter_context(
            open(output_path / f"unit_output-{args.data_file.stem}.txt", "w")
        )

        sample_id = 0
        for example in pipeline:
            valid_sequences: Optional[Tensor] = None
            src = example[args.audio_field]["data"]["fbank"]
            # Skip corrupted audio tensors.
            valid_sequences = ~torch.any(
                torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
            )
            if not valid_sequences.all():
                logger.warning(
                    f"Sample IDs {sample_id} to {sample_id + args.batch_size} have some corrupted input."
                )
                src["seqs"] = src["seqs"][valid_sequences]
                src["seq_lens"] = src["seq_lens"][valid_sequences]

            # Skip performing inference when the input is entirely corrupted.
            if src["seqs"].numel() > 0:
                prosody_encoder_input = example[args.audio_field]["data"]["gcmvn_fbank"]
                text_output, unit_output = translator.predict(
                    src,
                    "s2st",
                    args.tgt_lang,
                    src_lang=args.src_lang,
                    text_generation_opts=text_generation_opts,
                    unit_generation_opts=unit_generation_opts,
                    unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
                    duration_factor=args.duration_factor,
                    prosody_encoder_input=prosody_encoder_input,
                )

                assert unit_output is not None
                speech_output = pretssel_generator.predict(
                    unit_output.units,
                    tgt_lang=args.tgt_lang,
                    prosody_encoder_input=prosody_encoder_input,
                )
            else:
                text_output = []
                speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])

            if valid_sequences is not None and not valid_sequences.all():
                text_output, speech_output = adjust_output_for_corrupted_inputs(  # type: ignore[assignment]
                    valid_sequences,
                    text_output,
                    speech_output,
                )

            hyps += [str(s) for s in text_output]
            if args.ref_field is not None and args.ref_field in example:
                refs += [str(s) for s in example[args.ref_field]]

            for i in range(len(text_output)):
                t = text_output[i]
                idx = str(example["id"][i])
                hyp_file.write(f"{t}\n")

                u = speech_output.units[i]
                str_units = [str(i) for i in u]
                unit_file.write(" ".join(str_units) + "\n")

                torchaudio.save(
                    waveforms_dir / f"{idx}_pred.wav",
                    speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
                    sample_rate=speech_output.sample_rate,
                )

                audio_hyps.append((waveforms_dir / f"{idx}_pred.wav").as_posix())

                sample_id += 1
                progress_bar.update(1)

    progress_bar.close()
    logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
    if args.output_result_tsv:
        output_tsv_file = output_path / f"generate-{args.data_file.stem}.tsv"
        output_tsv = pd.read_csv(args.data_file, quoting=3, sep="\t")

        text_out = []
        with open(hyp_file.name) as file:
            for line in file:
                text_out.append(line.strip())

        unit_out = []
        with open(unit_file.name) as file:
            for line in file:
                unit_out.append(line.strip())

        output_tsv["hypo_audio"] = audio_hyps
        output_tsv["s2t_out"] = text_out
        output_tsv["orig_unit"] = unit_out
        output_tsv.to_csv(output_tsv_file, quoting=3, sep="\t", index=False)
        logger.info(f"Output results in {output_tsv_file}")
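

# Example invocation (illustrative only; apart from the flags defined above,
# names such as --model_name, --vocoder_name, --tgt_lang and --output_path are
# assumed to come from add_inference_arguments and may differ):
#
#   python evaluate.py data.tsv \
#       --model_name seamless_expressivity \
#       --vocoder_name vocoder_pretssel \
#       --tgt_lang fra \
#       --output_path out/ \
#       --audio_root_dir /path/to/audio \
#       --batch_size 4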


if __name__ == "__main__":
    main()