Don't mandate audio_root_dir for text input eval. (#69)

Kaushik Ram Sadagopan, 1 year ago
parent commit 3b34f6defa

scripts/m4t/evaluate/README.md (+1 -1)

@@ -7,7 +7,7 @@ Evaluation can be run with the CLI, from the root directory of the repository.
 The model can be specified with `--model_name`: `seamlessM4T_v2_large` or `seamlessM4T_large` or `seamlessM4T_medium`
 
 ```bash
-m4t_evaluate <path_to_data_tsv_file> <task_name> <tgt_lang> --output_path <path_to_save_audio> --ref_field <ref_field_name> --audio_root_dir <path_to_audio_root_directory>
+m4t_evaluate <path_to_data_tsv_file> <task_name> <tgt_lang> --output_path <path_to_save_evaluation_output> --ref_field <ref_field_name> --audio_root_dir <path_to_audio_root_directory>
 ```
 
 ### S2TT
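
With this change, a text-input task such as T2TT can be evaluated without the flag. A hypothetical invocation (paths and the reference field name are placeholders):

```bash
m4t_evaluate ./flores_devtest.tsv t2tt eng --output_path ./eval_output --ref_field tgt_text
```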

scripts/m4t/evaluate/evaluate.py (+9 -2)

@@ -60,7 +60,7 @@ class EvalContext:
     data_file: Path
     """The pathname of the test TSV data file."""
 
-    audio_root_dir: Path
+    audio_root_dir: Optional[Path]
     """The pathname of the directory under which
     audio files are stored."""
 
@@ -120,6 +120,8 @@ def build_data_pipeline(
     pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(split_tsv)
 
     if ctx.input_modality == Modality.SPEECH:
+        assert ctx.audio_root_dir is not None
+
         map_file = FileMapper(root_dir=ctx.audio_root_dir, cached_fd_count=10)
 
         pipeline_builder.map(map_file, selector="audio", num_parallel_calls=n_parallel)
@@ -339,7 +341,7 @@ def main():
         "--audio_root_dir",
         type=str,
         help="Root directory for the audio filenames in the data file.",
-        required=True,
+        default="",
     )
     parser.add_argument(
         "--ref_field",
@@ -351,6 +353,11 @@ def main():
 
     input_modality, output_modality = Translator.get_modalities_from_task_str(args.task)
 
+    if input_modality == Modality.SPEECH and not Path(args.audio_root_dir).exists():
+        raise ValueError(
+            f"Invalid audio_root_dir: {args.audio_root_dir} for speech input."
+        )
+
     if torch.cuda.is_available():
         device = torch.device("cuda:0")
         dtype = torch.float16
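
Speech-input tasks still require the directory, and the new check in `main()` fails fast when the given path does not exist. A hypothetical S2TT invocation (placeholder paths):

```bash
m4t_evaluate ./fleurs_devtest.tsv s2tt eng --output_path ./eval_output --ref_field tgt_text --audio_root_dir ./audio
# Passing a non-existent --audio_root_dir for a speech-input task now fails with:
#   ValueError: Invalid audio_root_dir: <path> for speech input.
```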