@@ -6,16 +6,21 @@
 
 import logging
 from pathlib import Path
-from typing import Optional
+from typing import Tuple, Union
 
 import pandas as pd
-import sacrebleu
 import whisper
+
+from fairseq2.typing import Device
 from jiwer import cer, wer
+from sacrebleu.metrics.base import Score, Signature
+from sacrebleu.metrics.bleu import BLEU
+from sacrebleu.metrics.chrf import CHRF
+from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 from tqdm import tqdm
+from whisper import Whisper
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 
-from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 
 logging.basicConfig(
     level=logging.INFO,
@@ -26,19 +31,19 @@ logger = logging.getLogger(__name__)
 
 
 def init_whisper_model(
-    device: str,
+    device: Device,
     whisper_model_name: str = "large",
-):
+) -> Whisper:
     return whisper.load_model(name=whisper_model_name, device=device)
 
 
 def transcribe_series(
     audio_paths_series: pd.Series,
-    asr_model,
+    asr_model: Whisper,
     audio_lang: str,
     beam_size: int = 1,
     temperature: float = 0.0,
-):
+) -> pd.Series:
     """Transcribes each audio filepath from series and returns series of transcriptions
     Args:
         audio_paths_series (pd.Series): each line contains path to audio file.
@@ -84,7 +89,9 @@ def transcribe_series(
     return transcriptions_series
 
 
-def whisper_normalize_series(transcription_series: pd.Series, text_lang: str):
+def whisper_normalize_series(
+    transcription_series: pd.Series, text_lang: str
+) -> pd.Series:
     """Normalizes the text series using the whisper normalizer. English has a specific one in the whisper package.
     Args:
         transcription_series (pd.Series): Each line contains arbitrary text written in text_lang
@@ -112,12 +119,12 @@ def compute_asr_bleu(
     audio_paths_series: pd.Series,
     ref_text_series: pd.Series,
     lang: str,
-    asr_model,
-    whisper_normalize_text: Optional[bool] = True,
-    beam_size: Optional[int] = 1,
-    temperature: Optional[float] = 0.0,
-    return_transcriptions: Optional[bool] = True,
-):
+    asr_model: Whisper,
+    whisper_normalize_text: bool = True,
+    beam_size: int = 1,
+    temperature: float = 0.0,
+    return_transcriptions: bool = True,
+) -> Tuple[Score, Signature, pd.DataFrame]:
     """Wraps functions above to compute corpus-level ASR-BLEU
     ASR decoding hyper-parameters are hard-coded to ensure reproducibility across evaluations
     Args:
@@ -125,10 +132,10 @@ def compute_asr_bleu(
         ref_text_series (pd.Series): each line contains the text reference to compare audio with
         lang (str): the language of both audio and ref_text
         asr_model: whisper ASR model
-        whisper_normalize_text (bool, Optional): normalize both text hypotheses and reference if True. Defaults to True.
-        beam_size (int, Optional): beam_size for whisper generation
-        temperature (float, Optional): Temperature sampling value for whisper generation
-        return_transcriptions (bool, Optional)
+        whisper_normalize_text (bool): normalize both text hypotheses and reference if True. Defaults to True.
+        beam_size (int): beam_size for whisper generation
+        temperature (float): Temperature sampling value for whisper generation
+        return_transcriptions (bool): Defaults to True.
     """
 
     audio_transcriptions = transcribe_series(
@@ -159,11 +166,11 @@ def compute_asr_bleu(
     return asr_bleu, asr_bleu_signature, transcript_df
 
 
-def get_tokenizer(lang: str, metric: Optional[str] = "bleu"):
+def get_tokenizer(lang: str, metric: str = "bleu") -> str:
     """Get tokenizer for language
     Args:
         lang (str): Three letter code of the language
-        metric (str, Optional): Metric being computed. Valid values are "bleu" and "asr"
+        metric (str): Metric being computed. Valid values are "bleu" and "asr"
     """
     lang_tok_map = {
         "cmn": "char",
@@ -183,8 +190,8 @@ def compute_asr_error_rate(
     hyp_text_series: pd.Series,
     ref_text_series: pd.Series,
     lang: str,
-    whisper_normalize_text=True,
-):
+    whisper_normalize_text: bool = True,
+) -> Tuple[float, str]:
     """Wraps normalization functions and computes ASR WER/CER score
     Args:
         hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
@@ -208,9 +215,9 @@ def compute_corpus_metric_score(
     hyp_text_series: pd.Series,
     ref_text_series: pd.Series,
     lang: str,
-    whisper_normalize_text=True,
-    metric: Optional[str] = "bleu",
-):
+    whisper_normalize_text: bool = True,
+    metric: str = "bleu",
+) -> Tuple[Score, Signature]:
     """Wraps normalization functions and computes corpus-level BLEU/chrF++ score
     Args:
         hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
@@ -225,12 +232,13 @@ def compute_corpus_metric_score(
         ref_text_series = whisper_normalize_series(ref_text_series, lang)
 
     tokenizer_name = get_tokenizer(lang)
+    corpus_metric_score_metric: Union[BLEU, CHRF]
     if metric == "bleu":
-        corpus_metric_score_metric = sacrebleu.metrics.bleu.BLEU(
+        corpus_metric_score_metric = BLEU(
             lowercase=whisper_normalize_text, tokenize=tokenizer_name
         )  # lowercase applied if we use whisper_normalize_text
     elif metric == "chrF++":
-        corpus_metric_score_metric = sacrebleu.CHRF(word_order=2)
+        corpus_metric_score_metric = CHRF(word_order=2)
 
     corpus_metric_score = corpus_metric_score_metric.corpus_score(
         hyp_text_series.to_list(), [ref_text_series.to_list()]
@@ -242,29 +250,29 @@ def compute_corpus_metric_score(
 
 
 def compute_quality_metrics(
-    output_manifest_tsv_path: str,
-    output_dir: str,
+    output_manifest_tsv_path: Path,
+    output_path: Path,
     tgt_lang: str,
     task: str,
-    device: str,
-    whisper_model_name: Optional[str] = "large",
-    whisper_normalize_text_output: Optional[bool] = False,
-    ref_text_col_name: Optional[str] = "ref_tgt_text",
-    pred_text_col_name: Optional[str] = "pred_tgt_text",
-    pred_audio_col_name: Optional[str] = "pred_tgt_audio",
-):
+    device: Device,
+    whisper_model_name: str = "large",
+    whisper_normalize_text_output: bool = False,
+    ref_text_col_name: str = "ref_tgt_text",
+    pred_text_col_name: str = "pred_tgt_text",
+    pred_audio_col_name: str = "pred_tgt_audio",
+) -> None:
     """Wraps ASR and S2T BLEU functions to call them with a TSV manifest composed on the expressivity side
     Args:
-        output_manifest_tsv_path (str): output manifest which has "ref_text", "hypo_audio", "s2t_out" column names
-        output_dir (str): Directory to write files with metrics
+        output_manifest_tsv_path (Path): output manifest which has "ref_text", "hypo_audio", "s2t_out" column names
+        output_path (Path): Directory to write files with metrics
         tgt_lang (str): the language we evaluate on
        task (str): the task we are currently evaluating
-        device (str): Device to use for inference
-        whisper_model_name (str, Optional): Whisper model name. Defaults to "large".
-        whisper_normalize_text_output (bool, Optional): Normalizes text output using whisper_normalizer if set to true
-        ref_text_col_name (str, Optional): Column name in the tsv corresponding to reference target text
-        pred_text_col_name (str, Optional): Column name in the tsv corresponding to predicted target text
-        pred_audio_col_name (str, Optional): Column name in the tsv corresponding to predicted target audio.
+        device (Device): Device to use for inference
+        whisper_model_name (str): Whisper model name. Defaults to "large".
+        whisper_normalize_text_output (bool): Normalizes text output using whisper_normalizer if set to True
+        ref_text_col_name (str): Column name in the tsv corresponding to reference target text
+        pred_text_col_name (str): Column name in the tsv corresponding to predicted target text
+        pred_audio_col_name (str): Column name in the tsv corresponding to predicted target audio.
             Setting this value to None will skip speech metrics
     """
     df = pd.read_csv(
@@ -272,8 +280,8 @@ def compute_quality_metrics(
     )
     task = task.upper()
 
-    if not Path(output_dir).exists():
-        Path(output_dir).mkdir(parents=True, exist_ok=True)
+    if not output_path.exists():
+        output_path.mkdir(parents=True, exist_ok=True)
 
     if task in ["S2TT", "S2ST", "T2TT"] and pred_text_col_name:
         metric = "chrF++" if task == "T2TT" else "bleu"
@@ -298,7 +306,7 @@ def compute_quality_metrics(
             else "s2tt_bleu.json"
         )
         cur_task = "S2TT"
-        with open((Path(output_dir) / filename).as_posix(), "w") as f:
+        with open((output_path / filename).as_posix(), "w") as f:
             f.write(text_metric_json)
 
         logger.info(f"{cur_task} {metric}:\n{text_metric_json}")
@@ -317,7 +325,7 @@ def compute_quality_metrics(
             whisper_normalize_text=True,
         )
         transcripts_df.to_csv(
-            (Path(output_dir) / "whisper_audio_transcriptions.tsv"),
+            (output_path / "whisper_audio_transcriptions.tsv"),
             sep="\t",
             index=False,
             encoding="utf-8",
@@ -331,27 +339,25 @@ def compute_quality_metrics(
         )
 
         with open(
-            (Path(output_dir) / f"{task.lower()}_asr_bleu_normalized.json").as_posix(),
+            (output_path / f"{task.lower()}_asr_bleu_normalized.json").as_posix(),
             "w",
         ) as f:
             f.write(asr_bleu_normalized_json)
 
-        with open((Path(output_dir) / filename).as_posix(), "w") as f:
+        with open((output_path / filename).as_posix(), "w") as f:
             f.write(text_metric_json)
 
         logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")
 
     if task == "ASR":
-        asr_error_rate, asr_error_rate_signature = compute_asr_error_rate(
+        _, asr_error_rate_signature = compute_asr_error_rate(
             hyp_text_series=df[pred_text_col_name],
             ref_text_series=df[ref_text_col_name],
             lang=tgt_lang,
             whisper_normalize_text=whisper_normalize_text_output,
        )
 
-        with open((Path(output_dir) / "asr_error_rate.json").as_posix(), "w") as f:
+        with open((output_path / "asr_error_rate.json").as_posix(), "w") as f:
             f.write(asr_error_rate_signature)
 
         logger.info(f"ASR : {asr_error_rate_signature}")
-
-    return