Browse Source

Add `--output_result_tsv` flag to pretssel_inference.py to write results to a TSV file for evaluation (#163)

Yilin Yang 1 year ago
parent
commit
65ac472a1c

+ 28 - 7
src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py

@@ -4,21 +4,17 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
-from typing import Optional
 import argparse
 import argparse
 import contextlib
 import contextlib
 import logging
 import logging
 from argparse import Namespace
 from argparse import Namespace
 from pathlib import Path
 from pathlib import Path
-from tqdm import tqdm
+from typing import Optional
 
 
+import pandas as pd
 import torch
 import torch
 import torchaudio
 import torchaudio
-from fairseq2.data import (
-    Collater,
-    DataPipeline,
-    FileMapper,
-)
+from fairseq2.data import Collater, DataPipeline, FileMapper
 from fairseq2.data.audio import (
 from fairseq2.data.audio import (
     AudioDecoder,
     AudioDecoder,
     WaveformToFbankConverter,
     WaveformToFbankConverter,
@@ -28,6 +24,7 @@ from fairseq2.data.text import StrSplitter, read_text
 from fairseq2.typing import DataType, Device
 from fairseq2.typing import DataType, Device
 from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
 from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
 from torch import Tensor
 from torch import Tensor
+from tqdm import tqdm
 
 
 from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
 from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
     PretsselGenerator,
     PretsselGenerator,
@@ -142,6 +139,12 @@ def main() -> None:
         help="The duration factor for NAR T2U model. Expressivity model uses 1.1",
         help="The duration factor for NAR T2U model. Expressivity model uses 1.1",
         default=1.1,
         default=1.1,
     )
     )
+    parser.add_argument(
+        "--output_result_tsv",
+        type=bool,
+        help="Whether to output results in tsv format (for full-blown evaluation)",
+        default=True,
+    )
     args = parser.parse_args()
     args = parser.parse_args()
 
 
     if torch.cuda.is_available():
     if torch.cuda.is_available():
@@ -273,6 +276,24 @@ def main() -> None:
     progress_bar.close()
     progress_bar.close()
     logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
     logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
 
 
+    if args.output_result_tsv:
+        output_tsv_file = output_path / f"generate-{args.data_file.stem}.tsv"
+        output_tsv = pd.read_csv(args.data_file, quoting=3, sep="\t")
+        text_out = []
+        with open(hyp_file.name) as file:
+            for line in file:
+                text_out.append(line.strip())
+
+        unit_out = []
+        with open(unit_file.name) as file:
+            for line in file:
+                unit_out.append(line.strip())
+
+        output_tsv["s2t_out"] = text_out
+        output_tsv["orig_unit"] = unit_out
+        output_tsv.to_csv(output_tsv_file, quoting=3, sep="\t", index=False)
+        logger.info(f"Output results in {output_tsv_file}")
+
     if len(hyps) == len(refs):
     if len(hyps) == len(refs):
         logger.info(f"Calculating S2T BLEU using {args.ref_field} column")
         logger.info(f"Calculating S2T BLEU using {args.ref_field} column")
         if args.tgt_lang in ("cmn", "jpn", "lao", "mya", "tha"):
         if args.tgt_lang in ("cmn", "jpn", "lao", "mya", "tha"):