|
@@ -5,21 +5,21 @@
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import logging
|
|
|
-from pathlib import Path
|
|
|
-from typing import Tuple, Union
|
|
|
-
|
|
|
import pandas as pd
|
|
|
import whisper
|
|
|
+
|
|
|
from fairseq2.typing import Device
|
|
|
from jiwer import cer, wer
|
|
|
+from pathlib import Path
|
|
|
from sacrebleu.metrics.base import Score, Signature
|
|
|
from sacrebleu.metrics.bleu import BLEU
|
|
|
from sacrebleu.metrics.chrf import CHRF
|
|
|
+from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
|
|
|
from tqdm import tqdm
|
|
|
+from typing import Tuple, Union
|
|
|
from whisper import Whisper
|
|
|
from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
|
|
|
|
|
|
-from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
|
|
|
|
|
|
logging.basicConfig(
|
|
|
level=logging.INFO,
|
|
@@ -305,7 +305,8 @@ def compute_quality_metrics(
|
|
|
else "s2tt_bleu.json"
|
|
|
)
|
|
|
cur_task = "S2TT"
|
|
|
- with open((output_path / filename).as_posix(), "w") as f:
|
|
|
+
|
|
|
+ with open(output_path / filename, "w") as f:
|
|
|
f.write(text_metric_json)
|
|
|
|
|
|
logger.info(f"{cur_task} {metric}:\n{text_metric_json}")
|
|
@@ -338,14 +339,11 @@ def compute_quality_metrics(
|
|
|
)
|
|
|
|
|
|
with open(
|
|
|
- (output_path / f"{task.lower()}_asr_bleu_normalized.json").as_posix(),
|
|
|
+ output_path / f"{task.lower()}_asr_bleu_normalized.json",
|
|
|
"w",
|
|
|
) as f:
|
|
|
f.write(asr_bleu_normalized_json)
|
|
|
|
|
|
- with open((output_path / filename).as_posix(), "w") as f:
|
|
|
- f.write(text_metric_json)
|
|
|
-
|
|
|
logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")
|
|
|
|
|
|
if task == "ASR":
|
|
@@ -356,7 +354,7 @@ def compute_quality_metrics(
|
|
|
whisper_normalize_text=whisper_normalize_text_output,
|
|
|
)
|
|
|
|
|
|
- with open((output_path / "asr_error_rate.json").as_posix(), "w") as f:
|
|
|
+ with open(output_path / "asr_error_rate.json", "w") as f:
|
|
|
f.write(asr_error_rate_signature)
|
|
|
|
|
|
logger.info(f"ASR : {asr_error_rate_signature}")
|