|
@@ -60,7 +60,7 @@ class EvalContext:
|
|
|
data_file: Path
|
|
|
"""The pathname of the test TSV data file."""
|
|
|
|
|
|
- audio_root_dir: Path
|
|
|
+ audio_root_dir: Optional[Path]
|
|
|
"""The pathname of the directory under which
|
|
|
audio files are stored."""
|
|
|
|
|
@@ -120,6 +120,8 @@ def build_data_pipeline(
|
|
|
pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(split_tsv)
|
|
|
|
|
|
if ctx.input_modality == Modality.SPEECH:
|
|
|
+ assert ctx.audio_root_dir is not None
|
|
|
+
|
|
|
map_file = FileMapper(root_dir=ctx.audio_root_dir, cached_fd_count=10)
|
|
|
|
|
|
pipeline_builder.map(map_file, selector="audio", num_parallel_calls=n_parallel)
|
|
@@ -339,7 +341,7 @@ def main():
|
|
|
"--audio_root_dir",
|
|
|
type=str,
|
|
|
help="Root directory for the audio filenames in the data file.",
|
|
|
- required=True,
|
|
|
+ default="",
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
"--ref_field",
|
|
@@ -351,6 +353,11 @@ def main():
|
|
|
|
|
|
input_modality, output_modality = Translator.get_modalities_from_task_str(args.task)
|
|
|
|
|
|
+ if input_modality == Modality.SPEECH and not Path(args.audio_root_dir).exists():
|
|
|
+ raise ValueError(
|
|
|
+ f"Invalid audio_root_dir: {args.audio_root_dir} for speech input."
|
|
|
+ )
|
|
|
+
|
|
|
if torch.cuda.is_available():
|
|
|
device = torch.device("cuda:0")
|
|
|
dtype = torch.float16
|