|
@@ -215,7 +215,7 @@ def main() -> None:
|
|
|
"--ref_field",
|
|
|
type=str,
|
|
|
help="Reference target text field to compute the BLEU score against.",
|
|
|
- default="tgt_text",
|
|
|
+ default=None,
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
"--duration_factor",
|
|
@@ -335,7 +335,8 @@ def main() -> None:
|
|
|
)
|
|
|
|
|
|
hyps += [str(s) for s in text_output]
|
|
|
- refs += [str(s) for s in example[args.ref_field]]
|
|
|
+ if args.ref_field is not None and args.ref_field in example:
|
|
|
+ refs += [str(s) for s in example[args.ref_field]]
|
|
|
|
|
|
for i in range(len(text_output)):
|
|
|
t = text_output[i]
|
|
@@ -357,8 +358,8 @@ def main() -> None:
|
|
|
progress_bar.close()
|
|
|
logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
|
|
|
|
|
|
- assert len(hyps) == len(refs)
|
|
|
- if len(hyps) > 0:
|
|
|
+ if len(hyps) == len(refs):
|
|
|
+ logger.info(f"Calculating S2T BLEU using {args.ref_field} column")
|
|
|
if args.tgt_lang in ("cmn", "jpn", "lao", "mya", "tha"):
|
|
|
tokenizer = "char"
|
|
|
else:
|