Pārlūkot izejas kodu

Add output_result_tsv flag to pretssel_inference.py to output results in a tsv for evaluation (#163)

Yilin Yang 1 gadu atpakaļ
vecāks
revīzija
65ac472a1c

+ 28 - 7
src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py

@@ -4,21 +4,17 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Optional
 import argparse
 import contextlib
 import logging
 from argparse import Namespace
 from pathlib import Path
-from tqdm import tqdm
+from typing import Optional
 
+import pandas as pd
 import torch
 import torchaudio
-from fairseq2.data import (
-    Collater,
-    DataPipeline,
-    FileMapper,
-)
+from fairseq2.data import Collater, DataPipeline, FileMapper
 from fairseq2.data.audio import (
     AudioDecoder,
     WaveformToFbankConverter,
@@ -28,6 +24,7 @@ from fairseq2.data.text import StrSplitter, read_text
 from fairseq2.typing import DataType, Device
 from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
 from torch import Tensor
+from tqdm import tqdm
 
 from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
     PretsselGenerator,
@@ -142,6 +139,12 @@ def main() -> None:
         help="The duration factor for NAR T2U model. Expressivity model uses 1.1",
         default=1.1,
     )
+    # bool() on a CLI string is True for any non-empty text (even "False"), so parse it.
+    parser.add_argument(
+        "--output_result_tsv", default=True,
+        type=lambda s: s.lower() not in {"", "0", "f", "false", "n", "no", "off"},
+        help="Whether to output results in tsv format (for full-blown evaluation)",
+    )
     args = parser.parse_args()
 
     if torch.cuda.is_available():
@@ -273,6 +276,24 @@ def main() -> None:
     progress_bar.close()
     logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
 
+    if args.output_result_tsv:
+        # Copy the input manifest and append the generated text and unit columns.
+        # quoting=3 is csv.QUOTE_NONE: fields are read and written verbatim.
+        output_tsv_file = output_path / f"generate-{args.data_file.stem}.tsv"
+        output_tsv = pd.read_csv(args.data_file, quoting=3, sep="\t")

+        # presumably one output line per manifest row; pandas raises on a length mismatch
+        with open(hyp_file.name) as file:
+            text_out = [line.strip() for line in file]

+        with open(unit_file.name) as file:
+            unit_out = [line.strip() for line in file]

+        output_tsv["s2t_out"] = text_out
+        output_tsv["orig_unit"] = unit_out
+        output_tsv.to_csv(output_tsv_file, quoting=3, sep="\t", index=False)
+        logger.info(f"Output results in {output_tsv_file}")
+
     if len(hyps) == len(refs):
         logger.info(f"Calculating S2T BLEU using {args.ref_field} column")
         if args.tgt_lang in ("cmn", "jpn", "lao", "mya", "tha"):