Browse Source

Enable using `m4t_evaluate` with a Manifest JSON file (#395)

* Add capability to accept a manifest JSON file

* Fix a missing line

* Fix pipeline data fields

* Make `output_path` required
Alisamar Husain 1 year ago
parent
commit
6e047a9ae1

+ 2 - 1
.gitignore

@@ -1,5 +1,6 @@
-# JetBrains PyCharm IDE
+# Editors
 .idea/
+.vscode/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

+ 50 - 18
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -9,6 +9,7 @@ import contextlib
 import itertools
 import logging
 import subprocess
+import json
 from argparse import Namespace
 from dataclasses import dataclass
 from pathlib import Path
@@ -63,7 +64,10 @@ class EvalContext:
     """The name of the S2T UnitY model."""
 
     data_file: Path
-    """The pathname of the test TSV data file."""
+    """The pathname of the test data file, TSV or manifest JSON."""
+    
+    data_file_type: str
+    """Type of data file, TSV or manifest JSON."""
 
     audio_root_dir: Optional[Path]
     """The pathname of the directory under which
@@ -113,17 +117,36 @@ def build_data_pipeline(
     ctx: EvalContext,
     text_tokenizer: TextTokenizer,
 ) -> DataPipeline:
-    with open(ctx.data_file, "r") as f:
-        header = f.readline().strip("\n").split("\t")
-        first_example = f.readline().strip("\n").split("\t")
+    
+    if ctx.data_file_type == "TSV":
+        with open(ctx.data_file, "r") as f:
+            header = f.readline().strip("\n").split("\t")
+            first_example = f.readline().strip("\n").split("\t")
+
+        format_tsv = StrSplitter(names=header)
+        pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(format_tsv)
+        
+    elif ctx.data_file_type == "JSON":
+        def format_json(line: str):
+            example = json.loads(str(line))
+            return {
+                "src_text": example["source"]["text"],
+                "src_lang": example["source"]["lang"],
+                "audio": example["source"]["audio_local_path"],
+                "tgt_text": example["target"]["text"],
+            }
+        
+        with open(ctx.data_file, "r") as f:
+            header = list(format_json(f.readline()).keys())
+            first_example = list(format_json(f.readline()).values())
+            
+        pipeline_builder = read_text(ctx.data_file, rtrim=True).map(format_json)
+        
+    else:
+        raise NotImplementedError
 
     # TODO: This will be soon auto-tuned. Right now hand-tuned for devfair.
     n_parallel = 4
-
-    split_tsv = StrSplitter(names=header)
-
-    pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(split_tsv)
-
     if ctx.input_modality == Modality.SPEECH:
         assert ctx.audio_root_dir is not None
 
@@ -334,7 +357,10 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
         description="M4T evaluation for tasks supported by Translator."
     )
     parser.add_argument(
-        "--data_file", type=str, help="Data file (.tsv) to be evaluated."
+        "--data_file", 
+        type=str, 
+        help="Data file to be evaluated, either TSV file or manifest JSON file."
+        "Format of the manifest JSON file should be that as produced by `m4t_prepare_dataset`"
     )
 
     parser = add_inference_arguments(parser)
@@ -367,14 +393,19 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
     default_args.update(optional_args) if optional_args else default_args
     args = Namespace(**default_args)
 
-    if not args.data_file or not args.task or not args.tgt_lang:
-        raise Exception(
-            "Please provide required arguments for evaluation - data_file, task, tgt_lang"
-        )
-
-    if not Path(args.data_file).exists():
-        raise ValueError(f"Invalid data_file to be evaluated: {args.data_file}")
-
+    assert args.data_file and args.task and args.tgt_lang and args.output_path, \
+        "Please provide required arguments for evaluation - data_file, task, tgt_lang"
+        
+    assert Path(args.data_file).exists(), \
+        f"Invalid `data_file`: {args.data_file} does not exist"
+        
+    if Path(args.data_file).suffix == ".tsv":
+        data_type = "TSV"
+    elif Path(args.data_file).suffix == ".json":
+        data_type = "JSON"
+    else:
+        raise ValueError("Unable to recognize file type! Please use a data_file with either .tsv or .json extension.")
+    
     input_modality, output_modality = Translator.get_modalities_from_task_str(args.task)
 
     if input_modality == Modality.SPEECH and not Path(args.audio_root_dir).exists():
@@ -418,6 +449,7 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
         output_modality=output_modality,
         model_name=args.model_name,
         data_file=Path(args.data_file),
+        data_file_type=data_type,
         audio_root_dir=Path(args.audio_root_dir),
         target_lang=args.tgt_lang,
         source_lang=args.src_lang,

+ 10 - 1
src/seamless_communication/cli/m4t/predict/predict.py

@@ -24,7 +24,16 @@ logger = logging.getLogger(__name__)
 
 
 def add_inference_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
-    parser.add_argument("--task", type=str, help="Task type")
+    parser.add_argument(
+        "--task", 
+        type=str, 
+        choices=["ASR", "S2ST", "S2TT"],
+        help=(
+            "* `ASR` -- automatic speech recognition (transcription);"
+            "* `S2ST` -- speech to speech translation;"
+            "* `S2TT` -- speech to text translation;"
+        )
+    )
     parser.add_argument(
         "--tgt_lang", type=str, help="Target language to translate/transcribe into."
     )