Sfoglia il codice sorgente

Re-do audio_field again (#220)

Co-authored-by: Peng-Jen Chen <pipibjc@devfair0209.h2.fair>
pipibjc 1 anno fa
parent
commit
4e78282fde

+ 5 - 8
src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py

@@ -65,10 +65,7 @@ def build_data_pipeline(
 
     n_parallel = 4
 
-    split_tsv = StrSplitter(
-        names=["id", "audio"],
-        indices=[header.index("id"), header.index(args.audio_field)],
-    )
+    split_tsv = StrSplitter(names=header)
 
     pipeline_builder = read_text(args.data_file, rtrim=True).skip(1).map(split_tsv)
 
@@ -76,7 +73,7 @@ def build_data_pipeline(
 
     map_file = FileMapper(root_dir=args.audio_root_dir, cached_fd_count=10)
 
-    pipeline_builder.map(map_file, selector="audio", num_parallel_calls=n_parallel)
+    pipeline_builder.map(map_file, selector=args.audio_field, num_parallel_calls=n_parallel)
 
     decode_audio = AudioDecoder(dtype=torch.float32, device=device)
 
@@ -98,7 +95,7 @@ def build_data_pipeline(
 
     pipeline_builder.map(
         [decode_audio, convert_to_fbank, normalize_fbank],
-        selector="audio.data",
+        selector=f"{args.audio_field}.data",
         num_parallel_calls=n_parallel,
     )
 
@@ -227,7 +224,7 @@ def main() -> None:
         sample_id = 0
         for example in pipeline:
             valid_sequences: Optional[Tensor] = None
-            src = example["audio"]["data"]["fbank"]
+            src = example[args.audio_field]["data"]["fbank"]
             # Skip corrupted audio tensors.
             valid_sequences = ~torch.any(
                 torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
@@ -241,7 +238,7 @@ def main() -> None:
 
             # Skip performing inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                prosody_encoder_input = example["audio"]["data"]["gcmvn_fbank"]
+                prosody_encoder_input = example[args.audio_field]["data"]["gcmvn_fbank"]
                 text_output, unit_output = translator.predict(
                     src,
                     args.task,