瀏覽代碼

Adding seamless quality scorer (#167)

* streaming cli improvements

* Streaming Evaluate CLI

* bump simuleval version requirement

* rebase and update

* revert simuleval version bump

* Remove old scripts

* revert simuleval dependency version change

* mypy issue in pretssel_vocoder

* setting cli defaults

* Logging vocoder load

* Implement new seamless quality scorer

* Prevent expressive option in s2tt

* change max_len_a to 1

* remove old quality scorer

* Add license

* Fix class method typing

---------

Co-authored-by: ibanesh <3632454+ibanesh@users.noreply.github.com>
Xutai Ma 1 年之前
父節點
當前提交
7c64fef0d4

+ 9 - 5
src/seamless_communication/cli/eval_utils/compute_metrics.py

@@ -16,7 +16,7 @@ from sacrebleu.metrics.bleu import BLEU
 from sacrebleu.metrics.chrf import CHRF
 from sacrebleu.metrics.chrf import CHRF
 from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 from tqdm import tqdm
 from tqdm import tqdm
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 from whisper import Whisper
 from whisper import Whisper
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 
 
@@ -257,9 +257,9 @@ def compute_quality_metrics(
     whisper_model_name: str = "large",
     whisper_model_name: str = "large",
     whisper_normalize_text_output: bool = False,
     whisper_normalize_text_output: bool = False,
     ref_text_col_name: str = "ref_tgt_text",
     ref_text_col_name: str = "ref_tgt_text",
-    pred_text_col_name: str = "pred_tgt_text",
+    pred_text_col_name: Optional[str] = "pred_tgt_text",
     pred_audio_col_name: str = "pred_tgt_audio",
     pred_audio_col_name: str = "pred_tgt_audio",
-) -> None:
+) -> str:
     """Wraps asr and s2t bleu functions to call it with TSV manifest composed on expressivity side
     """Wraps asr and s2t bleu functions to call it with TSV manifest composed on expressivity side
     Args:
     Args:
         output_manifest_tsv_path (Path): output manifest which has "ref_text", "hypo_audio", "s2t_out" column names
         output_manifest_tsv_path (Path): output manifest which has "ref_text", "hypo_audio", "s2t_out" column names
@@ -337,9 +337,10 @@ def compute_quality_metrics(
         asr_bleu_normalized_json = asr_bleu_normalized.format(
         asr_bleu_normalized_json = asr_bleu_normalized.format(
             signature=asr_bleu_normalized_signature.format(), is_json=True
             signature=asr_bleu_normalized_signature.format(), is_json=True
         )
         )
+        filename = f"{task.lower()}_asr_bleu_normalized.json"
 
 
         with open(
         with open(
-            output_path / f"{task.lower()}_asr_bleu_normalized.json",
+            output_path / filename,
             "w",
             "w",
         ) as f:
         ) as f:
             f.write(asr_bleu_normalized_json)
             f.write(asr_bleu_normalized_json)
@@ -353,8 +354,11 @@ def compute_quality_metrics(
             lang=tgt_lang,
             lang=tgt_lang,
             whisper_normalize_text=whisper_normalize_text_output,
             whisper_normalize_text=whisper_normalize_text_output,
         )
         )
+        filename = "asr_error_rate.json"
 
 
-        with open(output_path / "asr_error_rate.json", "w") as f:
+        with open(output_path / filename, "w") as f:
             f.write(asr_error_rate_signature)
             f.write(asr_error_rate_signature)
 
 
         logger.info(f"ASR : {asr_error_rate_signature}")
         logger.info(f"ASR : {asr_error_rate_signature}")
+
+    return filename

+ 10 - 22
src/seamless_communication/cli/streaming/evaluate.py

@@ -7,14 +7,13 @@
 import argparse
 import argparse
 
 
 from fairseq2.assets import asset_store, download_manager
 from fairseq2.assets import asset_store, download_manager
-from seamless_communication.cli.eval_utils import get_tokenizer
-from seamless_communication.cli.streaming.scorers.seamless_whisper_asr_bleu import (
-    SeamlessWhisperASRSacreBLEUScorer as SeamlessWhisperASRSacreBLEUScorer,
-)
 from seamless_communication.streaming.agents.mma_m4t_s2st import (
 from seamless_communication.streaming.agents.mma_m4t_s2st import (
     MonotonicM4TS2STAgent,
     MonotonicM4TS2STAgent,
     SeamlessS2STAgent,
     SeamlessS2STAgent,
 )
 )
+from seamless_communication.cli.streaming.scorers.seamless_quality_scorer import (
+    SeamlessQualityScorer,
+)
 from seamless_communication.streaming.agents.mma_m4t_s2t import MonotonicM4TS2TAgent
 from seamless_communication.streaming.agents.mma_m4t_s2t import MonotonicM4TS2TAgent
 from simuleval.evaluator import build_evaluator
 from simuleval.evaluator import build_evaluator
 from simuleval.utils.agent import EVALUATION_SYSTEM_LIST, build_system_args
 from simuleval.utils.agent import EVALUATION_SYSTEM_LIST, build_system_args
@@ -63,24 +62,22 @@ def main() -> None:
         model_configs.update(dict(fp16=True))
         model_configs.update(dict(fp16=True))
 
 
     EVALUATION_SYSTEM_LIST.clear()
     EVALUATION_SYSTEM_LIST.clear()
+    eval_configs = dict(quality_metrics="SEAMLESS_QUALITY_SCORER")
     if args.task == "s2st":
     if args.task == "s2st":
         model_configs.update(
         model_configs.update(
             dict(
             dict(
                 min_unit_chunk_size=50,
                 min_unit_chunk_size=50,
             )
             )
         )
         )
-        eval_configs = dict(
-            quality_metrics="SEAMLESS_WHISPER_ASR_BLEU",
-            latency_metrics="StartOffset EndOffset",
-            whisper_model_size="large-v2",
-            normalize_asr_bleu_references=True,
-        )
+        eval_configs["latency_metrics"] = "StartOffset EndOffset"
+
         if args.expressive:
         if args.expressive:
             EVALUATION_SYSTEM_LIST.append(SeamlessS2STAgent)
             EVALUATION_SYSTEM_LIST.append(SeamlessS2STAgent)
             model_configs.update(dict(vocoder_name="vocoder_pretssel"))
             model_configs.update(dict(vocoder_name="vocoder_pretssel"))
         else:
         else:
             EVALUATION_SYSTEM_LIST.append(MonotonicM4TS2STAgent)
             EVALUATION_SYSTEM_LIST.append(MonotonicM4TS2STAgent)
     elif args.task == "s2tt":
     elif args.task == "s2tt":
+        assert args.expressive is False, "S2TT inference cannot be expressive."
         EVALUATION_SYSTEM_LIST.append(MonotonicM4TS2TAgent)
         EVALUATION_SYSTEM_LIST.append(MonotonicM4TS2TAgent)
         parser.add_argument(
         parser.add_argument(
             "--unity-model-name",
             "--unity-model-name",
@@ -88,24 +85,15 @@ def main() -> None:
             help="Unity model name.",
             help="Unity model name.",
             default="seamless_streaming_unity",
             default="seamless_streaming_unity",
         )
         )
-        parser.add_argument(
-            "--tgt-lang",
-            default="eng",
-            type=str,
-            help="Target language to translate/transcribe into.",
-        )
         args, _ = parser.parse_known_args()
         args, _ = parser.parse_known_args()
         asset_card = asset_store.retrieve_card(name=args.unity_model_name)
         asset_card = asset_store.retrieve_card(name=args.unity_model_name)
         tokenizer_uri = asset_card.field("tokenizer").as_uri()
         tokenizer_uri = asset_card.field("tokenizer").as_uri()
         tokenizer_path = download_manager.download_tokenizer(
         tokenizer_path = download_manager.download_tokenizer(
             tokenizer_uri, asset_card.name, force=False, progress=True
             tokenizer_uri, asset_card.name, force=False, progress=True
         )
         )
-        eval_configs = dict(
-            sacrebleu_tokenizer=get_tokenizer(args.tgt_lang),
-            eval_latency_unit="spm",
-            eval_latency_spm_model=tokenizer_path,
-            latency_metrics="AL LAAL",
-        )
+        eval_configs["latency_metrics"] = "AL LAAL"
+        eval_configs["eval_latency_unit"] = "spm"
+        eval_configs["eval_latency_spm_model"] = tokenizer_path
 
 
     base_config = dict(
     base_config = dict(
         dataloader="fairseq2_s2tt",
         dataloader="fairseq2_s2tt",

+ 146 - 0
src/seamless_communication/cli/streaming/scorers/seamless_quality_scorer.py

@@ -0,0 +1,146 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+import pandas
+from fairseq2.typing import Device
+from pathlib import Path
+from typing import Optional
+import json
+from argparse import ArgumentParser, Namespace
+from typing import Dict
+
+from simuleval.evaluator.scorers.quality_scorer import (
+    register_quality_scorer,
+    QualityScorer,
+)
+
+from simuleval.evaluator.instance import LogInstance
+
+from seamless_communication.cli.eval_utils import (
+    compute_quality_metrics,
+)
+
+
+@register_quality_scorer("SEAMLESS_QUALITY_SCORER")
+class SeamlessQualityScorer(QualityScorer):  # type: ignore
+    def __init__(
+        self,
+        tgt_lang: str,
+        task: str,
+        output_dir: str,
+        device: Device = "cuda:0",
+        whisper_model_name: str = "large",
+        whisper_normalize_text_output: Optional[bool] = None,
+        ref_text_col_name: str = "ref_tgt_text",
+        pred_text_col_name: str = "pred_tgt_text",
+        pred_audio_col_name: str = "pred_tgt_audio",
+    ) -> None:
+        super().__init__()
+        self.tgt_lang = tgt_lang
+        self.task = task.upper()
+        self.device = device
+        self.output_dir = Path(output_dir)
+        self.whisper_model_name = whisper_model_name
+        self.whisper_normalize_text_output = whisper_normalize_text_output
+        if self.whisper_normalize_text_output is None:
+            self.whisper_normalize_text_output = (
+                False if self.task in ["S2TT", "S2ST", "T2TT"] else True
+            )
+        self.ref_text_col_name = ref_text_col_name
+        self.pred_text_col_name = pred_text_col_name
+        self.pred_audio_col_name = pred_audio_col_name
+
+    def __call__(self, instances: Dict[int, LogInstance]) -> float:
+        references = [ins.reference for ins in instances.values()]
+        df = pandas.DataFrame({self.ref_text_col_name: references})
+        if self.task in ["ASR", "S2TT", "T2TT"]:
+            predictions = [ins.prediction for ins in instances.values()]
+            df[self.pred_text_col_name] = predictions
+        else:
+            predictions = [ins.prediction for ins in instances.values()]
+            df[self.pred_audio_col_name] = predictions
+
+        df.to_csv(
+            self.output_dir / "results.tsv",
+            sep="\t",
+            quoting=3,
+            encoding="utf-8",
+        )
+        filename = compute_quality_metrics(
+            self.output_dir / "results.tsv",
+            self.output_dir,
+            self.tgt_lang,
+            self.task,
+            self.device,
+            self.whisper_model_name,
+            self.whisper_normalize_text_output,
+            self.ref_text_col_name,
+            self.pred_text_col_name if self.task in ["ASR", "S2TT", "T2TT"] else None,
+            self.pred_audio_col_name,
+        )
+
+        with open(self.output_dir / filename, "r") as f:
+            corpus_metric_score = json.load(f)["score"]
+
+        return corpus_metric_score  # type: ignore[no-any-return]
+
+    @staticmethod
+    def add_args(parser: ArgumentParser) -> None:
+        try:
+            parser.add_argument(
+                "--task", type=str, help="Task to evaluate", required=True
+            )
+            parser.add_argument(
+                "--tgt-lang",
+                type=str,
+                help="Target language to translate/transcribe into.",
+                required=True,
+            )
+        except:
+            pass
+
+        parser.add_argument(
+            "--whisper-model-name", type=str, help="Whisper model name", default="large"
+        )
+        parser.add_argument(
+            "--whisper-normalize-text-output",
+            action="store_true",
+            help="Normalize text output",
+            default=None,
+        )
+        parser.add_argument(
+            "--ref-text-col-name",
+            type=str,
+            help="Reference text column name",
+            default="ref_tgt_text",
+        )
+        parser.add_argument(
+            "--pred-text-col-name",
+            type=str,
+            help="Prediction text column name",
+            default="pred_tgt_text",
+        )
+        parser.add_argument(
+            "--pred-audio-col-name",
+            type=str,
+            help="Prediction audio column name",
+            default="pred_tgt_audio",
+        )
+
+    @classmethod
+    def from_args(cls, args: Namespace) -> SeamlessQualityScorer:
+        return cls(
+            tgt_lang=args.tgt_lang,
+            task=args.task,
+            output_dir=args.output,
+            device=getattr(args, "device", "cpu"),
+            whisper_model_name=args.whisper_model_name,
+            whisper_normalize_text_output=args.whisper_normalize_text_output,
+            ref_text_col_name=args.ref_text_col_name,
+            pred_text_col_name=args.pred_text_col_name,
+            pred_audio_col_name=args.pred_audio_col_name,
+        )

+ 0 - 84
src/seamless_communication/cli/streaming/scorers/seamless_whisper_asr_bleu.py

@@ -1,84 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-from __future__ import annotations
-
-from argparse import ArgumentParser, Namespace
-from typing import Dict, List
-
-from sacrebleu.metrics.bleu import BLEU
-from seamless_communication.cli.eval_utils import LANG3_LANG2, get_tokenizer
-from simuleval.evaluator.instance import LogInstance
-from simuleval.evaluator.scorers.quality_scorer import (
-    WhisperASRSacreBLEUScorer,
-    register_quality_scorer,
-)
-from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
-
-
-def normalize_text_whisper(sentences: List[str], lang: str) -> List[str]:
-    if lang in ["en", "eng"]:
-        normalizer = EnglishTextNormalizer()
-    else:
-        normalizer = BasicTextNormalizer()
-    normalized_sentences = []
-    for text in sentences:
-        normalized_sentences.append(normalizer(text))
-    return normalized_sentences
-
-
-@register_quality_scorer("SEAMLESS_WHISPER_ASR_BLEU")
-class SeamlessWhisperASRSacreBLEUScorer(WhisperASRSacreBLEUScorer):  # type: ignore
-    def __init__(
-        self,
-        tokenizer: str = "13a",
-        target_lang: str = "en",
-        model_size: str = "base",
-        lowercase: bool = False,
-        remove_punctuations: bool = False,
-        normalize_asr_bleu_references: bool = False,
-    ) -> None:
-        super().__init__()
-        self.tokenizer = tokenizer
-        self.target_lang = target_lang
-        self.model_size = model_size
-        self.lowercase = lowercase
-        self.remove_punctuations = remove_punctuations
-        self.normalize_asr_bleu_references = normalize_asr_bleu_references
-
-    def __call__(self, instances: Dict[int, LogInstance]) -> float:
-        transcripts = self.asr_transcribe(instances)
-        references = [[ins.reference for ins in instances.values()]]
-
-        if self.normalize_asr_bleu_references:
-            transcripts = normalize_text_whisper(transcripts, self.target_lang)
-            references = [normalize_text_whisper(references[0], self.target_lang)]
-
-        score = (
-            BLEU(tokenize=self.tokenizer).corpus_score(transcripts, references).score
-        )
-        return score  # type: ignore[no-any-return]
-
-    @staticmethod
-    def add_args(parser: ArgumentParser) -> None:
-        WhisperASRSacreBLEUScorer.add_args(parser)
-        parser.add_argument(
-            "--normalize-asr-bleu-references",
-            action="store_true",
-            help="Normalize asr transcript and reference",
-        )
-
-    @classmethod
-    def from_args(cls, args: Namespace) -> SeamlessWhisperASRSacreBLEUScorer:
-        sacrebleu_tokenizer = get_tokenizer(args.tgt_lang)
-        tgt_lang_2ltr = LANG3_LANG2[args.tgt_lang]
-        return cls(
-            sacrebleu_tokenizer,
-            tgt_lang_2ltr,
-            args.whisper_model_size,
-            args.transcript_lowercase,
-            args.transcript_non_punctuation,
-            args.normalize_asr_bleu_references,
-        )