@@ -0,0 +1,356 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from pathlib import Path
+from typing import Optional
+
+import pandas as pd
+import sacrebleu
+import whisper
+from jiwer import wer, cer
+from tqdm import tqdm
+from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
+
+from scripts.eval_utils.lang_mapping import LANG3_LANG2
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+def init_whisper_model(
+    device: str,
+    whisper_model_name: str = "large",
+):
+    return whisper.load_model(name=whisper_model_name, device=device)
+
+
+def transcribe_series(
+    audio_paths_series: pd.Series,
+    asr_model,
+    audio_lang: str,
+    beam_size: int = 1,
+    temperature: float = 0.0,
+):
+    """Transcribes each audio file path in the series and returns a series of transcriptions.
+
+    Args:
+        audio_paths_series (pd.Series): each line contains a path to an audio file.
+        asr_model: ASR model used for transcription, e.g. Whisper.
+        audio_lang (str): language spoken in the given audio, passed to the ASR model.
+        beam_size (int): Whisper beam size. Defaults to 1.
+        temperature (float): Whisper temperature. Defaults to 0.0 to avoid fallback decoding (see details below).
+
+    Returns:
+        pd.Series: each line contains the transcription of the corresponding audio from audio_paths_series.
+
+    The Whisper model implements decoding with fallback: https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L147
+    The core idea is that decoding at each time step may run multiple times if at least one criterion to "fall back", i.e.
+    start over, is triggered. The number of fallback iterations is determined by the schedule of temperature values:
+    https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L41
+    By default this schedule is active, with temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), so even with beam_size=5 decoding
+    might fall back to sampling with temperature > 0, in which case beam search is not used in the fallback iteration.
+    Explicitly setting temperature=0.0 overrides the schedule, so fallback decoding runs a single loop iteration, i.e. there
+    are no fallbacks. This allows reproducible evaluation without sampling variation. Beware that this might introduce
+    repetition loops in the transcriptions and lead to a worse ASR-BLEU score in the end.
+    """
+
+    if len(audio_lang) == 3:
+        # Whisper expects two-letter language codes
+        audio_lang = LANG3_LANG2[audio_lang]
+
+    transcriptions = {}
+
+    for idx, audio_path in tqdm(
+        audio_paths_series.items(),
+        desc=f"Transcribing {audio_paths_series.name} column",
+        total=len(audio_paths_series),
+    ):
+        hypo = asr_model.transcribe(
+            audio_path,
+            temperature=temperature,
+            beam_size=beam_size,
+            language=audio_lang,
+        )["text"].strip()
+        transcriptions[idx] = hypo
+
+    transcriptions_series = pd.Series(transcriptions)
+    transcriptions_series.name = f"{audio_paths_series.name}_transcribed"
+
+    return transcriptions_series
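+
+# A minimal usage sketch (illustrative only; the device and WAV path below are
+# assumptions, not part of this module):
+#
+#   model = init_whisper_model(device="cuda:0", whisper_model_name="large")
+#   paths = pd.Series(["/data/eval/utt0.wav"], name="pred_tgt_audio")
+#   hypos = transcribe_series(paths, model, audio_lang="eng", beam_size=1, temperature=0.0)
+#   # hypos.name == "pred_tgt_audio_transcribed"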
+
+
+def whisper_normalize_series(transcription_series: pd.Series, text_lang: str):
+    """Normalizes a text series using the Whisper normalizer. English has a dedicated normalizer in the whisper package.
+
+    Args:
+        transcription_series (pd.Series): each line contains arbitrary text written in text_lang.
+        text_lang (str): language of the text in the series.
+
+    Returns:
+        pd.Series: series with normalized text.
+    """
+    if text_lang == "eng":
+        normalizer = EnglishTextNormalizer()
+    else:
+        normalizer = BasicTextNormalizer()
+
+    norm_transcriptions = {}
+
+    for idx, text in transcription_series.items():
+        norm_transcriptions[idx] = normalizer(text)
+
+    norm_transcriptions_series = pd.Series(norm_transcriptions)
+    norm_transcriptions_series.name = transcription_series.name
+
+    return norm_transcriptions_series
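+
+# For example (hypothetical strings; "eng" picks EnglishTextNormalizer, any other
+# code falls back to BasicTextNormalizer):
+#
+#   whisper_normalize_series(pd.Series(["I'm HERE!"]), "eng")
+#   # -> roughly "i am here" (lowercased, contractions expanded, punctuation stripped)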
+
+
+def compute_asr_bleu(
+    audio_paths_series: pd.Series,
+    ref_text_series: pd.Series,
+    lang: str,
+    asr_model,
+    whisper_normalize_text: Optional[bool] = True,
+    beam_size: Optional[int] = 1,
+    temperature: Optional[float] = 0.0,
+    return_transcriptions: Optional[bool] = True,
+):
+    """Wraps the functions above to compute corpus-level ASR-BLEU.
+
+    ASR decoding hyper-parameters default to deterministic settings to ensure reproducibility across evaluations.
+
+    Args:
+        audio_paths_series (pd.Series): each line contains a path to audio.
+        ref_text_series (pd.Series): each line contains the text reference to compare the audio against.
+        lang (str): the language of both the audio and ref_text.
+        asr_model: Whisper ASR model.
+        whisper_normalize_text (bool, Optional): normalize both text hypotheses and references if True. Defaults to True.
+        beam_size (int, Optional): beam size for Whisper generation.
+        temperature (float, Optional): temperature sampling value for Whisper generation.
+        return_transcriptions (bool, Optional): include a DataFrame of transcriptions in the return value if True. Defaults to True.
+
+    Returns:
+        (MetricScore, MetricScoreSignature, Optional[pd.DataFrame])
+    """
+
+    audio_transcriptions = transcribe_series(
+        audio_paths_series,
+        asr_model,
+        audio_lang=lang,
+        beam_size=beam_size,
+        temperature=temperature,
+    )
+    asr_bleu, asr_bleu_signature = compute_corpus_metric_score(
+        audio_transcriptions, ref_text_series, lang, whisper_normalize_text
+    )
+    asr_bleu_signature.info["whisper_asr_beam_size"] = beam_size
+    asr_bleu_signature.info["whisper_asr_temperature"] = temperature
+    asr_bleu_signature.info["whisper_asr_language"] = lang
+
+    transcript_df = None
+    if return_transcriptions:
+        transcript_df = pd.concat(
+            [
+                audio_paths_series,
+                audio_transcriptions,
+                ref_text_series,
+            ],
+            axis=1,
+            keys=["audio", "transcript", "reference"],
+        )
+    return asr_bleu, asr_bleu_signature, transcript_df
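+
+# Sketch of a direct call (illustrative; the manifest columns and device are
+# assumptions, not part of this module):
+#
+#   bleu, signature, transcripts = compute_asr_bleu(
+#       audio_paths_series=df["pred_tgt_audio"],
+#       ref_text_series=df["ref_tgt_text"],
+#       lang="fra",
+#       asr_model=init_whisper_model(device="cuda:0"),
+#   )
+#   print(bleu.format(signature=signature.format(), is_json=True))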
+
+
+def get_tokenizer(lang: str, metric: Optional[str] = "bleu"):
+    """Get tokenizer for language.
+
+    Args:
+        lang (str): three-letter code of the language.
+        metric (str, Optional): metric being computed. Valid values are "bleu" and "error_rate".
+    """
+    lang_tok_map = {
+        "cmn": "char",
+        "jpn": "char",
+        "tha": "char",
+        "lao": "char",
+        "mya": "char",
+    }
+    default = (
+        "13a" if metric == "bleu" else "word"
+    )  # 13a is the default tokenizer for BLEU; word-level is the default for error-rate metrics
+    tok = lang_tok_map.get(lang, default)
+    return tok
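+
+# For example: languages written without word separators are scored at the
+# character level, everything else falls back to the metric's default:
+#
+#   get_tokenizer("cmn")                       # -> "char"
+#   get_tokenizer("fra")                       # -> "13a"
+#   get_tokenizer("fra", metric="error_rate")  # -> "word"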
+
+
+def compute_asr_error_rate(
+    hyp_text_series: pd.Series,
+    ref_text_series: pd.Series,
+    lang: str,
+    whisper_normalize_text=True,
+):
+    """Wraps normalization functions and computes the ASR WER/CER score.
+
+    Args:
+        hyp_text_series (pd.Series): each line contains an s2t model prediction or first pass prediction.
+        ref_text_series (pd.Series): each line contains the corresponding reference text.
+        lang (str): language of both hypotheses and references.
+        whisper_normalize_text (bool, optional): normalize both text hypotheses and references if True. Defaults to True.
+
+    Returns:
+        (float, str): the error rate and a string describing which metric was computed.
+    """
+    if whisper_normalize_text:
+        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
+        ref_text_series = whisper_normalize_series(ref_text_series, lang)
+
+    tokenizer_name = get_tokenizer(lang, metric="error_rate")
+    metric_name = wer if tokenizer_name == "word" else cer
+    # jiwer expects references first, then hypotheses
+    metric_score = metric_name(ref_text_series.to_list(), hyp_text_series.to_list())
+    return metric_score, f"{metric_name.__name__} is {metric_score}"
+
+
+def compute_corpus_metric_score(
+    hyp_text_series: pd.Series,
+    ref_text_series: pd.Series,
+    lang: str,
+    whisper_normalize_text=True,
+    metric: Optional[str] = "bleu",
+):
+    """Wraps normalization functions and computes the corpus-level BLEU/chrF++ score.
+
+    Args:
+        hyp_text_series (pd.Series): each line contains an s2t model prediction or first pass prediction.
+        ref_text_series (pd.Series): each line contains the corresponding reference text.
+        lang (str): language of both hypotheses and references.
+        whisper_normalize_text (bool, optional): normalize both text hypotheses and references if True. Defaults to True.
+
+    Returns:
+        (MetricScore, MetricScoreSignature)
+    """
+    if whisper_normalize_text:
+        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
+        ref_text_series = whisper_normalize_series(ref_text_series, lang)
+
+    tokenizer_name = get_tokenizer(lang)
+    if metric == "bleu":
+        corpus_metric_score_metric = sacrebleu.metrics.bleu.BLEU(
+            lowercase=whisper_normalize_text, tokenize=tokenizer_name
+        )  # lowercase is applied if we use whisper_normalize_text
+    elif metric == "chrF++":
+        corpus_metric_score_metric = sacrebleu.CHRF(word_order=2)
+    else:
+        raise ValueError(f"Unsupported metric: {metric}")
+
+    corpus_metric_score = corpus_metric_score_metric.corpus_score(
+        hyp_text_series.to_list(), [ref_text_series.to_list()]
+    )
+    corpus_metric_score_signature = corpus_metric_score_metric.get_signature()
+    corpus_metric_score_signature.info["whisper_normalize"] = whisper_normalize_text
+
+    return corpus_metric_score, corpus_metric_score_signature
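+
+# Sketch (hypothetical strings): corpus-level BLEU over a one-line "corpus":
+#
+#   bleu, sig = compute_corpus_metric_score(
+#       hyp_text_series=pd.Series(["bonjour le monde"]),
+#       ref_text_series=pd.Series(["bonjour tout le monde"]),
+#       lang="fra",
+#   )
+#   print(bleu.score, sig.format())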
+
+
+def compute_quality_metrics(
+    output_manifest_tsv_path: str,
+    output_dir: str,
+    tgt_lang: str,
+    task: str,
+    device: str,
+    whisper_model_name: Optional[str] = "large",
+    whisper_normalize_text_output: Optional[bool] = False,
+    ref_text_col_name: Optional[str] = "ref_tgt_text",
+    pred_text_col_name: Optional[str] = "pred_tgt_text",
+    pred_audio_col_name: Optional[str] = "pred_tgt_audio",
+):
+    """Wraps the ASR and S2T BLEU functions above to run them on a TSV manifest composed on the expressivity side.
+
+    Args:
+        output_manifest_tsv_path (str): output manifest containing the reference text, predicted text, and predicted
+            audio columns (names are configurable via the *_col_name arguments below).
+        output_dir (str): directory to write the metric files to.
+        tgt_lang (str): the language we evaluate on.
+        task (str): the task we are currently evaluating.
+        device (str): device to use for inference.
+        whisper_model_name (str, Optional): Whisper model name. Defaults to "large".
+        whisper_normalize_text_output (bool, Optional): normalize the text output using the Whisper normalizer if set to True.
+        ref_text_col_name (str, Optional): column name in the TSV corresponding to the reference target text.
+        pred_text_col_name (str, Optional): column name in the TSV corresponding to the predicted target text.
+        pred_audio_col_name (str, Optional): column name in the TSV corresponding to the predicted target audio.
+            Setting this value to None skips the speech metrics.
+    """
+    df = pd.read_csv(
+        output_manifest_tsv_path,
+        sep="\t",
+        quoting=3,  # csv.QUOTE_NONE
+        encoding="utf-8",
+        escapechar="\\",
+    )
+    task = task.upper()
+
+    Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+    if task in ["S2TT", "S2ST", "T2TT"] and pred_text_col_name:
+        metric = "chrF++" if task == "T2TT" else "bleu"
+        text_metric, text_metric_signature = compute_corpus_metric_score(
+            hyp_text_series=df[pred_text_col_name],
+            ref_text_series=df[ref_text_col_name],
+            lang=tgt_lang,
+            whisper_normalize_text=whisper_normalize_text_output,
+            metric=metric,
+        )
+        text_metric_json = text_metric.format(
+            signature=text_metric_signature.format(), is_json=True
+        )
+
+        if task == "T2TT":
+            filename = "t2tt_chrf.json"
+            cur_task = "T2TT"
+        else:
+            filename = (
+                "s2tt_bleu_normalized.json"
+                if whisper_normalize_text_output
+                else "s2tt_bleu.json"
+            )
+            cur_task = "S2TT"
+        with open((Path(output_dir) / filename).as_posix(), "w") as f:
+            f.write(text_metric_json)
+
+        logger.info(f"{cur_task} {metric}:\n{text_metric_json}")
+
+    if task in ["T2ST", "S2ST"]:
+        whisper_model = init_whisper_model(device, whisper_model_name)
+        (
+            asr_bleu_normalized,
+            asr_bleu_normalized_signature,
+            transcripts_df,
+        ) = compute_asr_bleu(
+            audio_paths_series=df[pred_audio_col_name],
+            ref_text_series=df[ref_text_col_name],
+            lang=tgt_lang,
+            asr_model=whisper_model,
+            whisper_normalize_text=True,
+        )
+        transcripts_df.to_csv(
+            (Path(output_dir) / "whisper_audio_transcriptions.tsv"),
+            sep="\t",
+            index=False,
+            encoding="utf-8",
+            escapechar="\\",
+        )
+
+        asr_bleu_normalized_signature.info["whisper_asr_model"] = whisper_model_name
+
+        asr_bleu_normalized_json = asr_bleu_normalized.format(
+            signature=asr_bleu_normalized_signature.format(), is_json=True
+        )
+
+        with open(
+            (Path(output_dir) / f"{task.lower()}_asr_bleu_normalized.json").as_posix(),
+            "w",
+        ) as f:
+            f.write(asr_bleu_normalized_json)
+
+        logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")
+
+    if task == "ASR":
+        asr_error_rate, asr_error_rate_signature = compute_asr_error_rate(
+            hyp_text_series=df[pred_text_col_name],
+            ref_text_series=df[ref_text_col_name],
+            lang=tgt_lang,
+            whisper_normalize_text=whisper_normalize_text_output,
+        )
+
+        with open((Path(output_dir) / "asr_error_rate.json").as_posix(), "w") as f:
+            f.write(asr_error_rate_signature)
+
+        logger.info(f"ASR : {asr_error_rate_signature}")
+
+    return
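+
+# End-to-end sketch (illustrative; the manifest path, output directory, and
+# device are assumptions, not part of this module):
+#
+#   compute_quality_metrics(
+#       output_manifest_tsv_path="generate-eval.tsv",
+#       output_dir="eval_results",
+#       tgt_lang="fra",
+#       task="S2ST",
+#       device="cuda:0",
+#   )
+#   # writes s2tt_bleu.json, s2st_asr_bleu_normalized.json, and
+#   # whisper_audio_transcriptions.tsv under eval_results/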