ソースを参照

Adding seamless quality scorer (#167)

* streaming cli improvements

* Streaming Evaluate CLI

* bump simuleval version requirement

* rebase and update

* revert simuleval version bump

* Remove old scripts

* revert simuleval dependency version change

* mypy issue in pretssel_vocoder

* setting cli defaults

* Logging vocoder load

* Implement new seamless quality scorer

* Prevent expressive option in s2tt

* change max_len_a to 1

* remove old quality scorer

* Add license

* Fix class method typing

---------

Co-authored-by: ibanesh <3632454+ibanesh@users.noreply.github.com>
Xutai Ma 1 年間 前
コミット
7c64fef0d4

+ 9 - 5
src/seamless_communication/cli/eval_utils/compute_metrics.py

@@ -16,7 +16,7 @@ from sacrebleu.metrics.bleu import BLEU
 from sacrebleu.metrics.chrf import CHRF
 from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 from tqdm import tqdm
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 from whisper import Whisper
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 
@@ -257,9 +257,9 @@ def compute_quality_metrics(
     whisper_model_name: str = "large",
     whisper_normalize_text_output: bool = False,
     ref_text_col_name: str = "ref_tgt_text",
-    pred_text_col_name: str = "pred_tgt_text",
+    pred_text_col_name: Optional[str] = "pred_tgt_text",
     pred_audio_col_name: str = "pred_tgt_audio",
-) -> None:
+) -> str:
     """Wraps asr and s2t bleu functions to call it with TSV manifest composed on expressivity side
     Args:
         output_manifest_tsv_path (Path): output manifest which has "ref_text", "hypo_audio", "s2t_out" column names
@@ -337,9 +337,10 @@ def compute_quality_metrics(
         asr_bleu_normalized_json = asr_bleu_normalized.format(
             signature=asr_bleu_normalized_signature.format(), is_json=True
         )
+        filename = f"{task.lower()}_asr_bleu_normalized.json"
 
         with open(
-            output_path / f"{task.lower()}_asr_bleu_normalized.json",
+            output_path / filename,
             "w",
         ) as f:
             f.write(asr_bleu_normalized_json)
@@ -353,8 +354,11 @@ def compute_quality_metrics(
             lang=tgt_lang,
             whisper_normalize_text=whisper_normalize_text_output,
         )
+        filename = "asr_error_rate.json"
 
-        with open(output_path / "asr_error_rate.json", "w") as f:
+        with open(output_path / filename, "w") as f:
             f.write(asr_error_rate_signature)
 
         logger.info(f"ASR : {asr_error_rate_signature}")
+
+    return filename

+ 10 - 22
src/seamless_communication/cli/streaming/evaluate.py

@@ -7,14 +7,13 @@
 import argparse
 
 from fairseq2.assets import asset_store, download_manager
-from seamless_communication.cli.eval_utils import get_tokenizer
-from seamless_communication.cli.streaming.scorers.seamless_whisper_asr_bleu import (
-    SeamlessWhisperASRSacreBLEUScorer as SeamlessWhisperASRSacreBLEUScorer,
-)
 from seamless_communication.streaming.agents.mma_m4t_s2st import (
     MonotonicM4TS2STAgent,
     SeamlessS2STAgent,
 )
+from seamless_communication.cli.streaming.scorers.seamless_quality_scorer import (
+    SeamlessQualityScorer,
+)
 from seamless_communication.streaming.agents.mma_m4t_s2t import MonotonicM4TS2TAgent
 from simuleval.evaluator import build_evaluator
 from simuleval.utils.agent import EVALUATION_SYSTEM_LIST, build_system_args
@@ -63,24 +62,22 @@ def main() -> None:
         model_configs.update(dict(fp16=True))
 
     EVALUATION_SYSTEM_LIST.clear()
+    eval_configs = dict(quality_metrics="SEAMLESS_QUALITY_SCORER")
     if args.task == "s2st":
         model_configs.update(
             dict(
                 min_unit_chunk_size=50,
             )
         )
-        eval_configs = dict(
-            quality_metrics="SEAMLESS_WHISPER_ASR_BLEU",
-            latency_metrics="StartOffset EndOffset",
-            whisper_model_size="large-v2",
-            normalize_asr_bleu_references=True,
-        )
+        eval_configs["latency_metrics"] = "StartOffset EndOffset"
+
         if args.expressive:
             EVALUATION_SYSTEM_LIST.append(SeamlessS2STAgent)
             model_configs.update(dict(vocoder_name="vocoder_pretssel"))
         else:
             EVALUATION_SYSTEM_LIST.append(MonotonicM4TS2STAgent)
     elif args.task == "s2tt":
+        assert args.expressive is False, "S2TT inference cannot be expressive."
         EVALUATION_SYSTEM_LIST.append(MonotonicM4TS2TAgent)
         parser.add_argument(
             "--unity-model-name",
@@ -88,24 +85,15 @@ def main() -> None:
             help="Unity model name.",
             default="seamless_streaming_unity",
         )
-        parser.add_argument(
-            "--tgt-lang",
-            default="eng",
-            type=str,
-            help="Target language to translate/transcribe into.",
-        )
         args, _ = parser.parse_known_args()
         asset_card = asset_store.retrieve_card(name=args.unity_model_name)
         tokenizer_uri = asset_card.field("tokenizer").as_uri()
         tokenizer_path = download_manager.download_tokenizer(
             tokenizer_uri, asset_card.name, force=False, progress=True
         )
-        eval_configs = dict(
-            sacrebleu_tokenizer=get_tokenizer(args.tgt_lang),
-            eval_latency_unit="spm",
-            eval_latency_spm_model=tokenizer_path,
-            latency_metrics="AL LAAL",
-        )
+        eval_configs["latency_metrics"] = "AL LAAL"
+        eval_configs["eval_latency_unit"] = "spm"
+        eval_configs["eval_latency_spm_model"] = tokenizer_path
 
     base_config = dict(
         dataloader="fairseq2_s2tt",

+ 146 - 0
src/seamless_communication/cli/streaming/scorers/seamless_quality_scorer.py

@@ -0,0 +1,146 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+import pandas
+from fairseq2.typing import Device
+from pathlib import Path
+from typing import Optional
+import json
+from argparse import ArgumentParser, Namespace
+from typing import Dict
+
+from simuleval.evaluator.scorers.quality_scorer import (
+    register_quality_scorer,
+    QualityScorer,
+)
+
+from simuleval.evaluator.instance import LogInstance
+
+from seamless_communication.cli.eval_utils import (
+    compute_quality_metrics,
+)
+
+
+@register_quality_scorer("SEAMLESS_QUALITY_SCORER")
+class SeamlessQualityScorer(QualityScorer):  # type: ignore
+    def __init__(
+        self,
+        tgt_lang: str,
+        task: str,
+        output_dir: str,
+        device: Device = "cuda:0",
+        whisper_model_name: str = "large",
+        whisper_normalize_text_output: Optional[bool] = None,
+        ref_text_col_name: str = "ref_tgt_text",
+        pred_text_col_name: str = "pred_tgt_text",
+        pred_audio_col_name: str = "pred_tgt_audio",
+    ) -> None:
+        super().__init__()
+        self.tgt_lang = tgt_lang
+        self.task = task.upper()
+        self.device = device
+        self.output_dir = Path(output_dir)
+        self.whisper_model_name = whisper_model_name
+        self.whisper_normalize_text_output = whisper_normalize_text_output
+        if self.whisper_normalize_text_output is None:
+            self.whisper_normalize_text_output = (
+                False if self.task in ["S2TT", "S2ST", "T2TT"] else True
+            )
+        self.ref_text_col_name = ref_text_col_name
+        self.pred_text_col_name = pred_text_col_name
+        self.pred_audio_col_name = pred_audio_col_name
+
+    def __call__(self, instances: Dict[int, LogInstance]) -> float:
+        references = [ins.reference for ins in instances.values()]
+        df = pandas.DataFrame({self.ref_text_col_name: references})
+        if self.task in ["ASR", "S2TT", "T2TT"]:
+            predictions = [ins.prediction for ins in instances.values()]
+            df[self.pred_text_col_name] = predictions
+        else:
+            predictions = [ins.prediction for ins in instances.values()]
+            df[self.pred_audio_col_name] = predictions
+
+        df.to_csv(
+            self.output_dir / "results.tsv",
+            sep="\t",
+            quoting=3,
+            encoding="utf-8",
+        )
+        filename = compute_quality_metrics(
+            self.output_dir / "results.tsv",
+            self.output_dir,
+            self.tgt_lang,
+            self.task,
+            self.device,
+            self.whisper_model_name,
+            self.whisper_normalize_text_output,
+            self.ref_text_col_name,
+            self.pred_text_col_name if self.task in ["ASR", "S2TT", "T2TT"] else None,
+            self.pred_audio_col_name,
+        )
+
+        with open(self.output_dir / filename, "r") as f:
+            corpus_metric_score = json.load(f)["score"]
+
+        return corpus_metric_score  # type: ignore[no-any-return]
+
+    @staticmethod
+    def add_args(parser: ArgumentParser) -> None:
+        try:
+            parser.add_argument(
+                "--task", type=str, help="Task to evaluate", required=True
+            )
+            parser.add_argument(
+                "--tgt-lang",
+                type=str,
+                help="Target language to translate/transcribe into.",
+                required=True,
+            )
+        except:
+            pass
+
+        parser.add_argument(
+            "--whisper-model-name", type=str, help="Whisper model name", default="large"
+        )
+        parser.add_argument(
+            "--whisper-normalize-text-output",
+            action="store_true",
+            help="Normalize text output",
+            default=None,
+        )
+        parser.add_argument(
+            "--ref-text-col-name",
+            type=str,
+            help="Reference text column name",
+            default="ref_tgt_text",
+        )
+        parser.add_argument(
+            "--pred-text-col-name",
+            type=str,
+            help="Prediction text column name",
+            default="pred_tgt_text",
+        )
+        parser.add_argument(
+            "--pred-audio-col-name",
+            type=str,
+            help="Prediction audio column name",
+            default="pred_tgt_audio",
+        )
+
+    @classmethod
+    def from_args(cls, args: Namespace) -> SeamlessQualityScorer:
+        return cls(
+            tgt_lang=args.tgt_lang,
+            task=args.task,
+            output_dir=args.output,
+            device=getattr(args, "device", "cpu"),
+            whisper_model_name=args.whisper_model_name,
+            whisper_normalize_text_output=args.whisper_normalize_text_output,
+            ref_text_col_name=args.ref_text_col_name,
+            pred_text_col_name=args.pred_text_col_name,
+            pred_audio_col_name=args.pred_audio_col_name,
+        )

+ 0 - 84
src/seamless_communication/cli/streaming/scorers/seamless_whisper_asr_bleu.py

@@ -1,84 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-from __future__ import annotations
-
-from argparse import ArgumentParser, Namespace
-from typing import Dict, List
-
-from sacrebleu.metrics.bleu import BLEU
-from seamless_communication.cli.eval_utils import LANG3_LANG2, get_tokenizer
-from simuleval.evaluator.instance import LogInstance
-from simuleval.evaluator.scorers.quality_scorer import (
-    WhisperASRSacreBLEUScorer,
-    register_quality_scorer,
-)
-from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
-
-
-def normalize_text_whisper(sentences: List[str], lang: str) -> List[str]:
-    if lang in ["en", "eng"]:
-        normalizer = EnglishTextNormalizer()
-    else:
-        normalizer = BasicTextNormalizer()
-    normalized_sentences = []
-    for text in sentences:
-        normalized_sentences.append(normalizer(text))
-    return normalized_sentences
-
-
-@register_quality_scorer("SEAMLESS_WHISPER_ASR_BLEU")
-class SeamlessWhisperASRSacreBLEUScorer(WhisperASRSacreBLEUScorer):  # type: ignore
-    def __init__(
-        self,
-        tokenizer: str = "13a",
-        target_lang: str = "en",
-        model_size: str = "base",
-        lowercase: bool = False,
-        remove_punctuations: bool = False,
-        normalize_asr_bleu_references: bool = False,
-    ) -> None:
-        super().__init__()
-        self.tokenizer = tokenizer
-        self.target_lang = target_lang
-        self.model_size = model_size
-        self.lowercase = lowercase
-        self.remove_punctuations = remove_punctuations
-        self.normalize_asr_bleu_references = normalize_asr_bleu_references
-
-    def __call__(self, instances: Dict[int, LogInstance]) -> float:
-        transcripts = self.asr_transcribe(instances)
-        references = [[ins.reference for ins in instances.values()]]
-
-        if self.normalize_asr_bleu_references:
-            transcripts = normalize_text_whisper(transcripts, self.target_lang)
-            references = [normalize_text_whisper(references[0], self.target_lang)]
-
-        score = (
-            BLEU(tokenize=self.tokenizer).corpus_score(transcripts, references).score
-        )
-        return score  # type: ignore[no-any-return]
-
-    @staticmethod
-    def add_args(parser: ArgumentParser) -> None:
-        WhisperASRSacreBLEUScorer.add_args(parser)
-        parser.add_argument(
-            "--normalize-asr-bleu-references",
-            action="store_true",
-            help="Normalize asr transcript and reference",
-        )
-
-    @classmethod
-    def from_args(cls, args: Namespace) -> SeamlessWhisperASRSacreBLEUScorer:
-        sacrebleu_tokenizer = get_tokenizer(args.tgt_lang)
-        tgt_lang_2ltr = LANG3_LANG2[args.tgt_lang]
-        return cls(
-            sacrebleu_tokenizer,
-            tgt_lang_2ltr,
-            args.whisper_model_size,
-            args.transcript_lowercase,
-            args.transcript_non_punctuation,
-            args.normalize_asr_bleu_references,
-        )