|
@@ -9,6 +9,7 @@ import contextlib
|
|
|
import itertools
|
|
|
import logging
|
|
|
import subprocess
|
|
|
+import json
|
|
|
from argparse import Namespace
|
|
|
from dataclasses import dataclass
|
|
|
from pathlib import Path
|
|
@@ -63,7 +64,10 @@ class EvalContext:
|
|
|
"""The name of the S2T UnitY model."""
|
|
|
|
|
|
data_file: Path
|
|
|
- """The pathname of the test TSV data file."""
|
|
|
+ """The pathname of the test data file, TSV or manifest JSON."""
|
|
|
+
|
|
|
+ data_file_type: str
|
|
|
+ """Type of data file, TSV or manifest JSON."""
|
|
|
|
|
|
audio_root_dir: Optional[Path]
|
|
|
"""The pathname of the directory under which
|
|
@@ -113,17 +117,36 @@ def build_data_pipeline(
|
|
|
ctx: EvalContext,
|
|
|
text_tokenizer: TextTokenizer,
|
|
|
) -> DataPipeline:
|
|
|
- with open(ctx.data_file, "r") as f:
|
|
|
- header = f.readline().strip("\n").split("\t")
|
|
|
- first_example = f.readline().strip("\n").split("\t")
|
|
|
+
|
|
|
+ if ctx.data_file_type == "TSV":
|
|
|
+ with open(ctx.data_file, "r") as f:
|
|
|
+ header = f.readline().strip("\n").split("\t")
|
|
|
+ first_example = f.readline().strip("\n").split("\t")
|
|
|
+
|
|
|
+ format_tsv = StrSplitter(names=header)
|
|
|
+ pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(format_tsv)
|
|
|
+
|
|
|
+ elif ctx.data_file_type == "JSON":
|
|
|
+ def format_json(line: str):
|
|
|
+ example = json.loads(str(line))
|
|
|
+ return {
|
|
|
+ "src_text": example["source"]["text"],
|
|
|
+ "src_lang": example["source"]["lang"],
|
|
|
+ "audio": example["source"]["audio_local_path"],
|
|
|
+ "tgt_text": example["target"]["text"],
|
|
|
+ }
|
|
|
+
|
|
|
+ with open(ctx.data_file, "r") as f:
|
|
|
+ header = list(format_json(f.readline()).keys())
|
|
|
+ first_example = list(format_json(f.readline()).values())
|
|
|
+
|
|
|
+ pipeline_builder = read_text(ctx.data_file, rtrim=True).map(format_json)
|
|
|
+
|
|
|
+ else:
|
|
|
+ raise NotImplementedError
|
|
|
|
|
|
# TODO: This will be soon auto-tuned. Right now hand-tuned for devfair.
|
|
|
n_parallel = 4
|
|
|
-
|
|
|
- split_tsv = StrSplitter(names=header)
|
|
|
-
|
|
|
- pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(split_tsv)
|
|
|
-
|
|
|
if ctx.input_modality == Modality.SPEECH:
|
|
|
assert ctx.audio_root_dir is not None
|
|
|
|
|
@@ -334,7 +357,10 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
|
|
|
description="M4T evaluation for tasks supported by Translator."
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
- "--data_file", type=str, help="Data file (.tsv) to be evaluated."
|
|
|
+ "--data_file",
|
|
|
+ type=str,
|
|
|
+ help="Data file to be evaluated, either TSV file or manifest JSON file."
|
|
|
+ "Format of the manifest JSON file should be that as produced by `m4t_prepare_dataset`"
|
|
|
)
|
|
|
|
|
|
parser = add_inference_arguments(parser)
|
|
@@ -367,14 +393,19 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
|
|
|
default_args.update(optional_args) if optional_args else default_args
|
|
|
args = Namespace(**default_args)
|
|
|
|
|
|
- if not args.data_file or not args.task or not args.tgt_lang:
|
|
|
- raise Exception(
|
|
|
- "Please provide required arguments for evaluation - data_file, task, tgt_lang"
|
|
|
- )
|
|
|
-
|
|
|
- if not Path(args.data_file).exists():
|
|
|
- raise ValueError(f"Invalid data_file to be evaluated: {args.data_file}")
|
|
|
-
|
|
|
+ assert args.data_file and args.task and args.tgt_lang and args.output_path, \
|
|
|
+ "Please provide required arguments for evaluation - data_file, task, tgt_lang"
|
|
|
+
|
|
|
+ assert Path(args.data_file).exists(), \
|
|
|
+ f"Invalid `data_file`: {args.data_file} does not exist"
|
|
|
+
|
|
|
+ if Path(args.data_file).suffix == ".tsv":
|
|
|
+ data_type = "TSV"
|
|
|
+ elif Path(args.data_file).suffix == ".json":
|
|
|
+ data_type = "JSON"
|
|
|
+ else:
|
|
|
+ raise ValueError("Unable to recognize file type! Please use a data_file with either .tsv or .json extension.")
|
|
|
+
|
|
|
input_modality, output_modality = Translator.get_modalities_from_task_str(args.task)
|
|
|
|
|
|
if input_modality == Modality.SPEECH and not Path(args.audio_root_dir).exists():
|
|
@@ -418,6 +449,7 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
|
|
|
output_modality=output_modality,
|
|
|
model_name=args.model_name,
|
|
|
data_file=Path(args.data_file),
|
|
|
+ data_file_type=data_type,
|
|
|
audio_root_dir=Path(args.audio_root_dir),
|
|
|
target_lang=args.tgt_lang,
|
|
|
source_lang=args.src_lang,
|