# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from pathlib import Path
from typing import Optional

import pandas as pd
import sacrebleu
import whisper
from jiwer import cer, wer
from tqdm import tqdm
from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer

from scripts.eval_utils.lang_mapping import LANG3_LANG2

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)


def init_whisper_model(
    device: str,
    whisper_model_name: str = "large",
):
    """Load a whisper checkpoint onto the given device."""
    return whisper.load_model(name=whisper_model_name, device=device)


def transcribe_series(
    audio_paths_series: pd.Series,
    asr_model,
    audio_lang: str,
    beam_size: int = 1,
    temperature: float = 0.0,
):
    """Transcribes each audio filepath from the series and returns a series of
    transcriptions.

    Args:
        audio_paths_series (pd.Series): each line contains a path to an audio file.
        asr_model: ASR model that does the transcribing, e.g. Whisper.
        audio_lang (str): language spoken in the given audio, passed to the ASR model.
        beam_size (int): whisper beam size. Defaults to 1.
        temperature (float): whisper temperature. Defaults to 0.0 to avoid fallback
            decoding (see details below).

    Returns:
        pd.Series: each line holds the transcription of the corresponding audio from
            audio_paths_series.

    The whisper model implements decoding with fallback:
    https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L147
    The core idea is that decoding at each time step may happen multiple times if at
    least one criterion to "fall back", i.e. start over, fires. The number of fallback
    iterations is determined by the schedule of temperature values:
    https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L41
    By default this schedule is active, with temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    i.e. even with beam_size=5 decoding might fall back and switch to sampling with
    temperature > 0; beam search is then not used in the fallback iteration.
    Explicitly setting temperature=0.0 overrides the schedule, so the fallback loop
    runs exactly once, i.e. there are no fallbacks. This makes evaluation reproducible,
    with no sampling variation. Beware that it can also introduce repetition loops in
    the transcriptions and lead to a worse ASR-BLEU score in the end.
    """
    if len(audio_lang) == 3:
        # convert 3-letter ISO codes to the 2-letter codes whisper expects
        audio_lang = LANG3_LANG2[audio_lang]

    transcriptions = {}
    for idx, audio_path in tqdm(
        audio_paths_series.items(),
        desc=f"Transcribing {audio_paths_series.name} column",
        total=len(audio_paths_series),
    ):
        hypo = asr_model.transcribe(
            audio_path,
            temperature=temperature,
            beam_size=beam_size,
            language=audio_lang,
        )["text"].strip()
        transcriptions[idx] = hypo

    transcriptions_series = pd.Series(transcriptions)
    transcriptions_series.name = f"{audio_paths_series.name}_transcribed"
    return transcriptions_series
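
# A minimal usage sketch (not part of the original module; "audio/0.wav" is a
# hypothetical path). Note that whisper's transcribe() also accepts a tuple of
# temperatures, which would re-enable the fallback schedule described above at the
# cost of reproducibility:
#
#   model = init_whisper_model(device="cuda", whisper_model_name="large")
#   paths = pd.Series(["audio/0.wav"], name="pred_tgt_audio")
#   hypos = transcribe_series(paths, model, audio_lang="eng")
#   # hypos.name == "pred_tgt_audio_transcribed"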


def whisper_normalize_series(transcription_series: pd.Series, text_lang: str):
    """Normalizes the text series using a whisper normalizer. English has a dedicated
    normalizer in the whisper package.

    Args:
        transcription_series (pd.Series): each line contains arbitrary text written
            in text_lang.
        text_lang (str): language of the text in the series.

    Returns:
        pd.Series: series with normalized text.
    """
    if text_lang == "eng":
        normalizer = EnglishTextNormalizer()
    else:
        normalizer = BasicTextNormalizer()

    norm_transcriptions = {}
    for idx, text in transcription_series.items():
        norm_transcriptions[idx] = normalizer(text)

    norm_transcriptions_series = pd.Series(norm_transcriptions)
    norm_transcriptions_series.name = transcription_series.name
    return norm_transcriptions_series
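
# For reference, the whisper normalizers lowercase, expand English contractions and
# strip punctuation, e.g. (a sketch; the exact output depends on the installed
# whisper version):
#
#   whisper_normalize_series(pd.Series(["Mr. Smith isn't here!"]), "eng")
#   # -> pd.Series(["mister smith is not here"])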


def compute_asr_bleu(
    audio_paths_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    asr_model,
    whisper_normalize_text: bool = True,
    beam_size: int = 1,
    temperature: float = 0.0,
    return_transcriptions: bool = True,
):
    """Wraps the functions above to compute corpus-level ASR-BLEU.

    ASR decoding hyper-parameters are hard-coded to ensure reproducibility across
    evaluations.

    Args:
        audio_paths_series (pd.Series): each line contains a path to an audio file.
        ref_text_series (pd.Series): each line contains the text reference the audio
            is compared against.
        lang (str): the language of both the audio and ref_text.
        asr_model: whisper ASR model.
        whisper_normalize_text (bool, optional): normalize both text hypotheses and
            references if True. Defaults to True.
        beam_size (int, optional): beam size for whisper generation.
        temperature (float, optional): sampling temperature for whisper generation.
        return_transcriptions (bool, optional): also return a dataframe with audio
            paths, transcriptions and references if True.
    """
    audio_transcriptions = transcribe_series(
        audio_paths_series,
        asr_model,
        audio_lang=lang,
        beam_size=beam_size,
        temperature=temperature,
    )
    asr_bleu, asr_bleu_signature = compute_corpus_metric_score(
        audio_transcriptions, ref_text_series, lang, whisper_normalize_text
    )
    asr_bleu_signature.info["whisper_asr_beam_size"] = beam_size
    asr_bleu_signature.info["whisper_asr_temperature"] = temperature
    asr_bleu_signature.info["whisper_asr_language"] = lang

    transcript_df = None
    if return_transcriptions:
        transcript_df = pd.concat(
            [
                audio_paths_series,
                audio_transcriptions,
                ref_text_series,
            ],
            axis=1,
            keys=["audio", "transcript", "reference"],
        )
    return asr_bleu, asr_bleu_signature, transcript_df
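
# A hedged end-to-end sketch ("manifest.tsv" and its columns are hypothetical and
# must exist on disk):
#
#   df = pd.read_csv("manifest.tsv", sep="\t")
#   model = init_whisper_model("cuda")
#   bleu, signature, transcripts = compute_asr_bleu(
#       df["pred_tgt_audio"], df["ref_tgt_text"], lang="eng", asr_model=model
#   )
#   print(bleu.score, signature.format())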


def get_tokenizer(lang: str, metric: str = "bleu"):
    """Get the sacrebleu tokenizer name for a language.

    Args:
        lang (str): three-letter code of the language.
        metric (str, optional): metric being computed; "bleu" selects the BLEU
            tokenizer, any other value (the error-rate caller passes "error_rate")
            selects the word-level default.
    """
    lang_tok_map = {
        "cmn": "char",
        "jpn": "char",
        "tha": "char",
        "lao": "char",
        "mya": "char",
    }
    default = (
        "13a" if metric == "bleu" else "word"
    )  # 13a is sacrebleu's default tokenizer for BLEU; "word" is used for error rates
    tok = lang_tok_map.get(lang, default)
    return tok
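
# For example: get_tokenizer("jpn") -> "char" (scripts written without spaces are
# scored per character), get_tokenizer("eng") -> "13a", and
# get_tokenizer("eng", metric="error_rate") -> "word".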


def compute_asr_error_rate(
    hyp_text_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    whisper_normalize_text=True,
):
    """Wraps the normalization functions and computes the ASR WER/CER score.

    Args:
        hyp_text_series (pd.Series): each line contains an s2t model prediction or
            first-pass prediction.
        ref_text_series (pd.Series): each line contains the corresponding text
            reference.
        lang (str): the language of both hypotheses and references.
        whisper_normalize_text (bool, optional): normalize both text hypotheses and
            references if True. Defaults to True.

    Returns:
        (float, str): the error rate and a short signature string describing it.
    """
    if whisper_normalize_text:
        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
        ref_text_series = whisper_normalize_series(ref_text_series, lang)

    tokenizer_name = get_tokenizer(lang, metric="error_rate")
    metric_name = wer if tokenizer_name == "word" else cer
    # jiwer expects the reference first: wer(reference, hypothesis)
    metric_score = metric_name(ref_text_series.to_list(), hyp_text_series.to_list())
    return metric_score, f"{metric_name.__name__} is {metric_score}"
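
# A small sketch: with whisper normalization on, punctuation and casing differences
# do not count as errors:
#
#   hyp = pd.Series(["hello world"])
#   ref = pd.Series(["Hello, world!"])
#   compute_asr_error_rate(hyp, ref, lang="eng")  # -> (0.0, "wer is 0.0")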


def compute_corpus_metric_score(
    hyp_text_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    whisper_normalize_text=True,
    metric: str = "bleu",
):
    """Wraps the normalization functions and computes a corpus-level BLEU/chrF++ score.

    Args:
        hyp_text_series (pd.Series): each line contains an s2t model prediction or
            first-pass prediction.
        ref_text_series (pd.Series): each line contains the corresponding text
            reference.
        lang (str): the language of both hypotheses and references.
        whisper_normalize_text (bool, optional): normalize both text hypotheses and
            references if True. Defaults to True.
        metric (str, optional): "bleu" or "chrF++". Defaults to "bleu".

    Returns:
        (MetricScore, MetricScoreSignature)
    """
    if whisper_normalize_text:
        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
        ref_text_series = whisper_normalize_series(ref_text_series, lang)

    tokenizer_name = get_tokenizer(lang)
    if metric == "bleu":
        corpus_metric_score_metric = sacrebleu.metrics.bleu.BLEU(
            lowercase=whisper_normalize_text, tokenize=tokenizer_name
        )  # lowercase is applied when whisper_normalize_text is set
    elif metric == "chrF++":
        corpus_metric_score_metric = sacrebleu.CHRF(word_order=2)
    else:
        raise ValueError(f"Unknown metric: {metric}, expected 'bleu' or 'chrF++'")

    corpus_metric_score = corpus_metric_score_metric.corpus_score(
        hyp_text_series.to_list(), [ref_text_series.to_list()]
    )
    corpus_metric_score_signature = corpus_metric_score_metric.get_signature()
    corpus_metric_score_signature.info["whisper_normalize"] = whisper_normalize_text

    return corpus_metric_score, corpus_metric_score_signature
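
# A minimal sketch of what this wraps (outputs illustrative, not computed):
#
#   hyp = pd.Series(["the cat sat on the mat"])
#   ref = pd.Series(["the cat sat on a mat"])
#   score, sig = compute_corpus_metric_score(hyp, ref, lang="eng")
#   print(score.score)   # corpus-level BLEU
#   print(sig.format())  # sacrebleu signature plus the whisper_normalize flag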


def compute_quality_metrics(
    output_manifest_tsv_path: str,
    output_dir: str,
    tgt_lang: str,
    task: str,
    device: str,
    whisper_model_name: str = "large",
    whisper_normalize_text_output: bool = False,
    ref_text_col_name: str = "ref_tgt_text",
    pred_text_col_name: Optional[str] = "pred_tgt_text",
    pred_audio_col_name: str = "pred_tgt_audio",
):
    """Wraps the ASR and S2T BLEU functions and calls them with a TSV manifest
    composed on the expressivity side.

    Args:
        output_manifest_tsv_path (str): output manifest containing the reference text,
            predicted text and predicted audio columns named by the *_col_name
            arguments below.
        output_dir (str): directory to write the metric files to.
        tgt_lang (str): the language we evaluate on.
        task (str): the task we are currently evaluating.
        device (str): device to use for inference.
        whisper_model_name (str, optional): whisper model name. Defaults to "large".
        whisper_normalize_text_output (bool, optional): normalize the text output
            using the whisper normalizer if True. Defaults to False.
        ref_text_col_name (str, optional): column in the TSV holding the reference
            target text.
        pred_text_col_name (str, optional): column in the TSV holding the predicted
            target text. Setting this to None skips the text metrics.
        pred_audio_col_name (str, optional): column in the TSV holding the predicted
            target audio.
    """
    df = pd.read_csv(
        output_manifest_tsv_path, sep="\t", quoting=3, encoding="utf-8", escapechar="\\"
    )
    task = task.upper()

    if not Path(output_dir).exists():
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    if task in ["S2TT", "S2ST", "T2TT"] and pred_text_col_name:
        metric = "chrF++" if task == "T2TT" else "bleu"
        text_metric, text_metric_signature = compute_corpus_metric_score(
            hyp_text_series=df[pred_text_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            whisper_normalize_text=whisper_normalize_text_output,
            metric=metric,
        )
        text_metric_json = text_metric.format(
            signature=text_metric_signature.format(), is_json=True
        )

        if task == "T2TT":
            filename = "t2tt_chrf.json"
            cur_task = "T2TT"
        else:
            filename = (
                "s2tt_bleu_normalized.json"
                if whisper_normalize_text_output
                else "s2tt_bleu.json"
            )
            cur_task = "S2TT"

        with open((Path(output_dir) / filename).as_posix(), "w") as f:
            f.write(text_metric_json)

        logger.info(f"{cur_task} {metric}:\n{text_metric_json}")
    if task in ["T2ST", "S2ST"]:
        whisper_model = init_whisper_model(device, whisper_model_name)
        (
            asr_bleu_normalized,
            asr_bleu_normalized_signature,
            transcripts_df,
        ) = compute_asr_bleu(
            audio_paths_series=df[pred_audio_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            asr_model=whisper_model,
            whisper_normalize_text=True,
        )
        transcripts_df.to_csv(
            (Path(output_dir) / "whisper_audio_transcriptions.tsv"),
            sep="\t",
            index=False,
            encoding="utf-8",
            escapechar="\\",
        )
        asr_bleu_normalized_signature.info["whisper_asr_model"] = whisper_model_name

        asr_bleu_normalized_json = asr_bleu_normalized.format(
            signature=asr_bleu_normalized_signature.format(), is_json=True
        )
        with open(
            (Path(output_dir) / f"{task.lower()}_asr_bleu_normalized.json").as_posix(),
            "w",
        ) as f:
            f.write(asr_bleu_normalized_json)

        logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")
    if task == "ASR":
        asr_error_rate, asr_error_rate_signature = compute_asr_error_rate(
            hyp_text_series=df[pred_text_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            whisper_normalize_text=whisper_normalize_text_output,
        )
        with open((Path(output_dir) / "asr_error_rate.json").as_posix(), "w") as f:
            f.write(asr_error_rate_signature)
        logger.info(f"ASR : {asr_error_rate_signature}")

    return
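
# A hedged usage sketch (the manifest path is hypothetical; its columns must match
# the *_col_name defaults above):
#
#   compute_quality_metrics(
#       output_manifest_tsv_path="generation/output_manifest.tsv",
#       output_dir="generation/metrics",
#       tgt_lang="eng",
#       task="S2ST",
#       device="cuda",
#   )
#   # For S2ST this writes s2tt_bleu.json, s2st_asr_bleu_normalized.json and
#   # whisper_audio_transcriptions.tsv under generation/metrics/.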