|
@@ -18,8 +18,7 @@ from seamless_communication.streaming.agents.seamless_streaming_s2st import (
|
|
from seamless_communication.streaming.agents.seamless_streaming_s2t import (
|
|
from seamless_communication.streaming.agents.seamless_streaming_s2t import (
|
|
SeamlessStreamingS2TAgent,
|
|
SeamlessStreamingS2TAgent,
|
|
)
|
|
)
|
|
-from simuleval.evaluator import build_evaluator
|
|
|
|
-from simuleval.utils.agent import EVALUATION_SYSTEM_LIST, build_system_args
|
|
|
|
|
|
+from simuleval.cli import evaluate
|
|
|
|
|
|
logging.basicConfig(
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
level=logging.INFO,
|
|
@@ -63,19 +62,18 @@ def main() -> None:
|
|
max_len_b=100,
|
|
max_len_b=100,
|
|
)
|
|
)
|
|
|
|
|
|
- EVALUATION_SYSTEM_LIST.clear()
|
|
|
|
eval_configs = dict(quality_metrics="SEAMLESS_QUALITY_SCORER")
|
|
eval_configs = dict(quality_metrics="SEAMLESS_QUALITY_SCORER")
|
|
if args.task == "s2st":
|
|
if args.task == "s2st":
|
|
model_configs["min_unit_chunk_size"] = 50
|
|
model_configs["min_unit_chunk_size"] = 50
|
|
eval_configs["latency_metrics"] = "StartOffset EndOffset"
|
|
eval_configs["latency_metrics"] = "StartOffset EndOffset"
|
|
|
|
|
|
if args.expressive:
|
|
if args.expressive:
|
|
- EVALUATION_SYSTEM_LIST.append(SeamlessS2STAgent)
|
|
|
|
|
|
+ agent_class = SeamlessS2STAgent
|
|
else:
|
|
else:
|
|
- EVALUATION_SYSTEM_LIST.append(SeamlessStreamingS2STAgent)
|
|
|
|
|
|
+ agent_class = SeamlessStreamingS2STAgent
|
|
elif args.task in ["s2tt", "asr"]:
|
|
elif args.task in ["s2tt", "asr"]:
|
|
assert args.expressive is False, "S2TT inference cannot be expressive."
|
|
assert args.expressive is False, "S2TT inference cannot be expressive."
|
|
- EVALUATION_SYSTEM_LIST.append(SeamlessStreamingS2TAgent)
|
|
|
|
|
|
+ agent_class = SeamlessStreamingS2TAgent
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
"--unity-model-name",
|
|
"--unity-model-name",
|
|
type=str,
|
|
type=str,
|
|
@@ -97,15 +95,7 @@ def main() -> None:
|
|
dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
|
|
dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
|
|
)
|
|
)
|
|
|
|
|
|
- system, args = build_system_args(
|
|
|
|
- {**base_config, **model_configs, **eval_configs}, parser
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- if args.fp16:
|
|
|
|
- logger.warn("--fp16 arg will be ignorned, use --dtype instead")
|
|
|
|
-
|
|
|
|
- evaluator = build_evaluator(args)
|
|
|
|
- evaluator(system)
|
|
|
|
|
|
+ evaluate(agent_class, {**base_config, **model_configs, **eval_configs}, parser)
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|