Bläddra i källkod

Fix bug in eval script for S2ST, T2ST tasks. (#102)

Kaushik Ram Sadagopan 1 år sedan
förälder
incheckning
e3c40244e1
2 ändrade filer med 9 tillägg och 10 borttagningar
  1. 1 0
      setup.py
  2. 8 10
      src/seamless_communication/cli/eval_utils/compute_metrics.py

+ 1 - 0
setup.py

@@ -24,6 +24,7 @@ setup(
         "datasets",
         "fairseq2==0.2.*",
         "librosa",
+        "openai-whisper",
         "soundfile",
         "torchaudio",
         "tqdm",

+ 8 - 10
src/seamless_communication/cli/eval_utils/compute_metrics.py

@@ -5,21 +5,21 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from pathlib import Path
-from typing import Tuple, Union
-
 import pandas as pd
 import whisper
+
 from fairseq2.typing import Device
 from jiwer import cer, wer
+from pathlib import Path
 from sacrebleu.metrics.base import Score, Signature
 from sacrebleu.metrics.bleu import BLEU
 from sacrebleu.metrics.chrf import CHRF
+from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 from tqdm import tqdm
+from typing import Tuple, Union
 from whisper import Whisper
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 
-from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 
 logging.basicConfig(
     level=logging.INFO,
@@ -305,7 +305,8 @@ def compute_quality_metrics(
                 else "s2tt_bleu.json"
             )
             cur_task = "S2TT"
-        with open((output_path / filename).as_posix(), "w") as f:
+
+        with open(output_path / filename, "w") as f:
             f.write(text_metric_json)
 
         logger.info(f"{cur_task} {metric}:\n{text_metric_json}")
@@ -338,14 +339,11 @@ def compute_quality_metrics(
         )
 
         with open(
-            (output_path / f"{task.lower()}_asr_bleu_normalized.json").as_posix(),
+            output_path / f"{task.lower()}_asr_bleu_normalized.json",
             "w",
         ) as f:
             f.write(asr_bleu_normalized_json)
 
-        with open((output_path / filename).as_posix(), "w") as f:
-            f.write(text_metric_json)
-
         logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")
 
     if task == "ASR":
@@ -356,7 +354,7 @@ def compute_quality_metrics(
             whisper_normalize_text=whisper_normalize_text_output,
         )
 
-        with open((output_path / "asr_error_rate.json").as_posix(), "w") as f:
+        with open(output_path / "asr_error_rate.json", "w") as f:
             f.write(asr_error_rate_signature)
 
         logger.info(f"ASR : {asr_error_rate_signature}")