Procházet zdrojové kódy

fix cmn ASR eval in SeamlessExpressive

hygong před 1 rokem
rodič
revize
50e403542d

+ 149 - 4
src/seamless_communication/cli/expressivity/evaluate/run_asr_bleu.py

@@ -4,12 +4,158 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
+from pathlib import Path
+from typing import Optional, Tuple, Union
+import logging
+import pandas as pd
+import whisper
+from whisper import Whisper
+from sacrebleu.metrics.base import Score, Signature
 from fire import Fire
 from seamless_communication.cli.eval_utils.compute_metrics import (
-    compute_quality_metrics,
+    init_whisper_model,
+    compute_corpus_metric_score,
+    transcribe_series,
 )
 from fairseq2.typing import Device
-from pathlib import Path
+
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+def compute_asr_bleu(
+    audio_paths_series: pd.Series,
+    ref_text_series: pd.Series,
+    lang: str,
+    asr_model: Whisper,
+    whisper_normalize_text: bool = True,
+    beam_size: int = 1,
+    temperature: float = 0.0,
+    return_transcriptions: bool = True,
+) -> Tuple[Score, Signature, pd.DataFrame]:
+    if lang == "cmn":
+        try:
+            import chinese_converter
+        except ImportError:
+            raise ImportError(
+                "Please install chinese_converter: pip install chinese_converter"
+            )
+    audio_transcriptions = transcribe_series(
+        audio_paths_series,
+        asr_model,
+        audio_lang=lang,
+        beam_size=beam_size,
+        temperature=temperature,
+    )
+    if lang == "cmn":
+        audio_transcriptions = pd.Series([chinese_converter.to_simplified(text) for text in audio_transcriptions])
+        ref_text_series = pd.Series([chinese_converter.to_simplified(text) for text in ref_text_series])
+    asr_bleu, asr_bleu_signature = compute_corpus_metric_score(
+        audio_transcriptions, ref_text_series, lang, whisper_normalize_text
+    )
+    asr_bleu_signature.info["whisper_asr_beam_size"] = beam_size
+    asr_bleu_signature.info["whisper_asr_temperature"] = temperature
+    asr_bleu_signature.info["whisper_asr_language"] = lang
+
+    transcript_df = None
+    if return_transcriptions:
+        transcript_df = pd.concat(
+            [
+                audio_paths_series,
+                audio_transcriptions,
+                ref_text_series,
+            ],
+            axis=1,
+            keys=["audio", "transcript", "reference"],
+        )
+    return asr_bleu, asr_bleu_signature, transcript_df
+
+
+def compute_s2st_quality_metrics(
+    output_manifest_tsv_path: Path,
+    output_path: Path,
+    tgt_lang: str,
+    device: Device,
+    whisper_model_name: str = "large",
+    whisper_normalize_text_output: bool = False,
+    ref_text_col_name: str = "ref_tgt_text",
+    pred_text_col_name: Optional[str] = "pred_tgt_text",
+    pred_audio_col_name: str = "pred_tgt_audio",
+) -> str:
+ 
+    df = pd.read_csv(
+        output_manifest_tsv_path, sep="\t", quoting=3, encoding="utf-8", escapechar="\\"
+    )
+
+    if not output_path.exists():
+        output_path.mkdir(parents=True, exist_ok=True)
+
+    metric = "bleu"
+
+    text_metric, text_metric_signature = compute_corpus_metric_score(
+        hyp_text_series=df[pred_text_col_name],
+        ref_text_series=df[ref_text_col_name],
+        lang=tgt_lang,
+        whisper_normalize_text=whisper_normalize_text_output,
+        metric=metric,
+    )
+    text_metric_json = text_metric.format(
+        signature=text_metric_signature.format(), is_json=True
+    )
+
+    filename = (
+        "s2tt_bleu_normalized.json"
+        if whisper_normalize_text_output
+        else "s2tt_bleu.json"
+    )
+    cur_task = "S2TT"
+
+    with open(output_path / filename, "w") as f:
+        f.write(text_metric_json)
+
+    logger.info(f"{cur_task} {metric}:\n{text_metric_json}")
+
+    whisper_model = init_whisper_model(device, whisper_model_name)
+    (
+        asr_bleu_normalized,
+        asr_bleu_normalized_signature,
+        transcripts_df,
+    ) = compute_asr_bleu(
+        audio_paths_series=df[pred_audio_col_name],
+        ref_text_series=df[ref_text_col_name],
+        lang=tgt_lang,
+        asr_model=whisper_model,
+        whisper_normalize_text=True,
+    )
+    transcripts_df.to_csv(
+        (output_path / "whisper_audio_transcriptions.tsv"),
+        sep="\t",
+        index=False,
+        encoding="utf-8",
+        escapechar="\\",
+    )
+
+    asr_bleu_normalized_signature.info["whisper_asr_model"] = whisper_model_name
+
+    asr_bleu_normalized_json = asr_bleu_normalized.format(
+        signature=asr_bleu_normalized_signature.format(), is_json=True
+    )
+    filename = "s2st_asr_bleu_normalized.json"
+
+    with open(
+        output_path / filename,
+        "w",
+    ) as f:
+        f.write(asr_bleu_normalized_json)
+
+    logger.info(f"S2ST ASR Normalized BLEU:\n{asr_bleu_normalized_json}")
+
+    return filename
 
 
 def run_asr_bleu_expressive_model(
@@ -17,11 +163,10 @@ def run_asr_bleu_expressive_model(
     generate_tsv_filename: str,
     tgt_lang: str,
 ):
-    compute_quality_metrics(
+    compute_s2st_quality_metrics(
         f"{generation_dir_path}/{generate_tsv_filename}",
         Path(generation_dir_path),
         tgt_lang,
-        "S2ST",
         device=Device("cuda"),
         ref_text_col_name="tgt_text",
         pred_text_col_name="s2t_out",