Forráskód Böngészése

Enabling slurm option in streaming evaluation cli (#189)

* Enabling slurm option in streaming evaluation cli

* Follow up changes based on #188

* update readme and help menu

* Addressing comments
Abinesh Ramakrishnan 1 éve
szülő
commit
e446d5ca56

+ 1 - 1
src/seamless_communication/cli/streaming/README.md

@@ -7,7 +7,7 @@ Evaluation can be run with the `streaming_evaluate` CLI.
 
 We use the `seamless_streaming_unity` for loading the speech encoder and T2U models, and `seamless_streaming_monotonic_decoder` for loading the text decoder for streaming evaluation. This is already set as defaults for the `streaming_evaluate` CLI, but can be overridden using the `--unity-model-name` and  `--monotonic-decoder-model-name` args if required.
 
-Note that the numbers in the paper use single precision floating point format (fp32) for evaluation by setting `--dtype fp32`.
+Note that the numbers in our paper use single precision floating point format (fp32) for evaluation by setting `--dtype fp32`. Also note that the results from running these evaluations might be slightly different from the results reported in our paper (which will be updated soon with the new results).
 
 ### S2TT:
 Set the task to `s2tt` for evaluating the speech-to-text translation part of the SeamlessStreaming model.

+ 5 - 15
src/seamless_communication/cli/streaming/evaluate.py

@@ -18,8 +18,7 @@ from seamless_communication.streaming.agents.seamless_streaming_s2st import (
 from seamless_communication.streaming.agents.seamless_streaming_s2t import (
     SeamlessStreamingS2TAgent,
 )
-from simuleval.evaluator import build_evaluator
-from simuleval.utils.agent import EVALUATION_SYSTEM_LIST, build_system_args
+from simuleval.cli import evaluate
 
 logging.basicConfig(
     level=logging.INFO,
@@ -63,19 +62,18 @@ def main() -> None:
         max_len_b=100,
     )
 
-    EVALUATION_SYSTEM_LIST.clear()
     eval_configs = dict(quality_metrics="SEAMLESS_QUALITY_SCORER")
     if args.task == "s2st":
         model_configs["min_unit_chunk_size"] = 50
         eval_configs["latency_metrics"] = "StartOffset EndOffset"
 
         if args.expressive:
-            EVALUATION_SYSTEM_LIST.append(SeamlessS2STAgent)
+            agent_class = SeamlessS2STAgent
         else:
-            EVALUATION_SYSTEM_LIST.append(SeamlessStreamingS2STAgent)
+            agent_class = SeamlessStreamingS2STAgent
     elif args.task in ["s2tt", "asr"]:
         assert args.expressive is False, "S2TT inference cannot be expressive."
-        EVALUATION_SYSTEM_LIST.append(SeamlessStreamingS2TAgent)
+        agent_class = SeamlessStreamingS2TAgent
         parser.add_argument(
             "--unity-model-name",
             type=str,
@@ -97,15 +95,7 @@ def main() -> None:
         dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
     )
 
-    system, args = build_system_args(
-        {**base_config, **model_configs, **eval_configs}, parser
-    )
-
-    if args.fp16:
-        logger.warn("--fp16 arg will be ignorned, use --dtype instead")
-
-    evaluator = build_evaluator(args)
-    evaluator(system)
+    evaluate(agent_class, {**base_config, **model_configs, **eval_configs}, parser)
 
 
 if __name__ == "__main__":

+ 1 - 1
src/seamless_communication/streaming/agents/pretssel_vocoder.py

@@ -132,7 +132,7 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):  # type: ign
         parser.add_argument(
             "--vocoder-name",
             type=str,
-            help="Vocoder name.",
+            help="Vocoder name - vocoder_pretssel or vocoder_pretssel_16khz",
             default="vocoder_pretssel",
         )
         parser.add_argument(