@@ -65,10 +65,7 @@ def build_data_pipeline(

     n_parallel = 4

-    split_tsv = StrSplitter(
-        names=["id", "audio"],
-        indices=[header.index("id"), header.index(args.audio_field)],
-    )
+    split_tsv = StrSplitter(names=header)

     pipeline_builder = read_text(args.data_file, rtrim=True).skip(1).map(split_tsv)

@@ -76,7 +73,7 @@ def build_data_pipeline(

     map_file = FileMapper(root_dir=args.audio_root_dir, cached_fd_count=10)

-    pipeline_builder.map(map_file, selector="audio", num_parallel_calls=n_parallel)
+    pipeline_builder.map(map_file, selector=args.audio_field, num_parallel_calls=n_parallel)

     decode_audio = AudioDecoder(dtype=torch.float32, device=device)

@@ -98,7 +95,7 @@ def build_data_pipeline(

     pipeline_builder.map(
         [decode_audio, convert_to_fbank, normalize_fbank],
-        selector="audio.data",
+        selector=f"{args.audio_field}.data",
         num_parallel_calls=n_parallel,
     )

@@ -227,7 +224,7 @@ def main() -> None:
     sample_id = 0
     for example in pipeline:
         valid_sequences: Optional[Tensor] = None
-        src = example["audio"]["data"]["fbank"]
+        src = example[args.audio_field]["data"]["fbank"]
         # Skip corrupted audio tensors.
         valid_sequences = ~torch.any(
             torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
@@ -241,7 +238,7 @@ def main() -> None:

         # Skip performing inference when the input is entirely corrupted.
         if src["seqs"].numel() > 0:
-            prosody_encoder_input = example["audio"]["data"]["gcmvn_fbank"]
+            prosody_encoder_input = example[args.audio_field]["data"]["gcmvn_fbank"]
             text_output, unit_output = translator.predict(
                 src,
                 args.task,
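
Taken together, these hunks replace the hard-coded "audio" column with a CLI-selectable `args.audio_field`, both when building the fairseq2 data pipeline and when reading examples back out in `main()`. Below is a minimal sketch of the resulting pattern; the import paths, the `build_pipeline` helper, and the `audio_field` default are illustrative assumptions and may differ from the actual script and fairseq2 version:

```python
from fairseq2.data import FileMapper
from fairseq2.data.text import StrSplitter, read_text


def build_pipeline(data_file: str, audio_root_dir: str, audio_field: str = "audio"):
    # Read the TSV header so every column keeps its original name.
    with open(data_file, "r") as f:
        header = f.readline().rstrip("\n").split("\t")

    # Split each row into a dict keyed by the header names, instead of
    # hard-coding ["id", "audio"]; the audio column is then addressed by name.
    split_tsv = StrSplitter(names=header)

    # Skip the header row, then split the remaining rows.
    builder = read_text(data_file, rtrim=True).skip(1).map(split_tsv)

    # The selector picks the configurable column; nested fields produced by
    # later steps are addressed with a dotted path such as f"{audio_field}.data".
    map_file = FileMapper(root_dir=audio_root_dir, cached_fd_count=10)
    builder.map(map_file, selector=audio_field, num_parallel_calls=4)

    return builder.and_return()
```

Downstream code then indexes each example with the same field name, e.g. `example[audio_field]["data"]["fbank"]`, which is what the `main()` hunks above switch to.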