|
@@ -4,21 +4,17 @@
|
|
|
# This source code is licensed under the license found in the
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
-from typing import Optional
|
|
|
import argparse
|
|
|
import contextlib
|
|
|
import logging
|
|
|
from argparse import Namespace
|
|
|
from pathlib import Path
|
|
|
-from tqdm import tqdm
|
|
|
+from typing import Optional
|
|
|
|
|
|
+import pandas as pd
|
|
|
import torch
|
|
|
import torchaudio
|
|
|
-from fairseq2.data import (
|
|
|
- Collater,
|
|
|
- DataPipeline,
|
|
|
- FileMapper,
|
|
|
-)
|
|
|
+from fairseq2.data import Collater, DataPipeline, FileMapper
|
|
|
from fairseq2.data.audio import (
|
|
|
AudioDecoder,
|
|
|
WaveformToFbankConverter,
|
|
@@ -28,6 +24,7 @@ from fairseq2.data.text import StrSplitter, read_text
|
|
|
from fairseq2.typing import DataType, Device
|
|
|
from sacrebleu.metrics import BLEU # type: ignore[attr-defined]
|
|
|
from torch import Tensor
|
|
|
+from tqdm import tqdm
|
|
|
|
|
|
from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
|
|
|
PretsselGenerator,
|
|
@@ -142,6 +139,12 @@ def main() -> None:
|
|
|
help="The duration factor for NAR T2U model. Expressivity model uses 1.1",
|
|
|
default=1.1,
|
|
|
)
|
|
|
+ parser.add_argument(
|
|
|
+ "--output_result_tsv",
|
|
|
+ type=bool,
|
|
|
+ help="Whether to output results in tsv format (for full-blown evaluation)",
|
|
|
+ default=True,
|
|
|
+ )
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
if torch.cuda.is_available():
|
|
@@ -273,6 +276,24 @@ def main() -> None:
|
|
|
progress_bar.close()
|
|
|
logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
|
|
|
|
|
|
+ if args.output_result_tsv:
|
|
|
+ output_tsv_file = output_path / f"generate-{args.data_file.stem}.tsv"
|
|
|
+ output_tsv = pd.read_csv(args.data_file, quoting=3, sep="\t")
|
|
|
+ text_out = []
|
|
|
+ with open(hyp_file.name) as file:
|
|
|
+ for line in file:
|
|
|
+ text_out.append(line.strip())
|
|
|
+
|
|
|
+ unit_out = []
|
|
|
+ with open(unit_file.name) as file:
|
|
|
+ for line in file:
|
|
|
+ unit_out.append(line.strip())
|
|
|
+
|
|
|
+ output_tsv["s2t_out"] = text_out
|
|
|
+ output_tsv["orig_unit"] = unit_out
|
|
|
+ output_tsv.to_csv(output_tsv_file, quoting=3, sep="\t", index=False)
|
|
|
+ logger.info(f"Output results in {output_tsv_file}")
|
|
|
+
|
|
|
if len(hyps) == len(refs):
|
|
|
logger.info(f"Calculating S2T BLEU using {args.ref_field} column")
|
|
|
if args.tgt_lang in ("cmn", "jpn", "lao", "mya", "tha"):
|