소스 검색

Support not providing reference text column (#161)

* Support not providing reference text

* default to None, to avoid name conflict
Yilin Yang 1 년 전
부모
커밋
e8a2a9f74b
1개의 변경된 파일5개의 추가작업 그리고 4개의 파일을 삭제
  1. 5 4
      src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py

+ 5 - 4
src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py

@@ -215,7 +215,7 @@ def main() -> None:
         "--ref_field",
         type=str,
         help="Reference target text field to compute the BLEU score against.",
-        default="tgt_text",
+        default=None,
     )
     parser.add_argument(
         "--duration_factor",
@@ -335,7 +335,8 @@ def main() -> None:
                 )
 
             hyps += [str(s) for s in text_output]
-            refs += [str(s) for s in example[args.ref_field]]
+            if args.ref_field is not None and args.ref_field in example:
+                refs += [str(s) for s in example[args.ref_field]]
 
             for i in range(len(text_output)):
                 t = text_output[i]
@@ -357,8 +358,8 @@ def main() -> None:
     progress_bar.close()
     logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
 
-    assert len(hyps) == len(refs)
-    if len(hyps) > 0:
+    if len(hyps) == len(refs):
+        logger.info(f"Calculating S2T BLEU using {args.ref_field} column")
         if args.tgt_lang in ("cmn", "jpn", "lao", "mya", "tha"):
             tokenizer = "char"
         else: