@@ -8,7 +8,7 @@
 import json
 import logging
 from dataclasses import dataclass
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple

 import numpy as np
 import torch
@@ -100,6 +100,7 @@ class UnitYDataLoader:
         unit_tokenizer: UnitTokenizer,
         dataset_manifest_path: str,
         batching_config: BatchingConfig,
+        max_src_tokens_per_batch: int = 100000
     ):
         self.text_tokenizer = text_tokenizer
         self.text_encoders_per_lang: Dict[str, TextTokenEncoder] = {}
@@ -115,6 +116,7 @@ class UnitYDataLoader:
             "dtype": self.batching_config.float_dtype,
         }
         self.dataset = self._load_manifest(dataset_manifest_path)
+        self.max_src_tokens_per_batch = max_src_tokens_per_batch

     def get_dataloader(self) -> DataLoader[SeqsBatch]:
         subset = split_dataset_by_node(
@@ -156,9 +158,9 @@ class UnitYDataLoader:
         """Expected sequence is [<eos>, <lang_tok> , ..text tokens.., <eos>]"""
         target_lang = sample.target.lang
         if target_lang not in self.text_encoders_per_lang:
-            self.text_encoders_per_lang[
-                target_lang
-            ] = self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
+            self.text_encoders_per_lang[target_lang] = (
+                self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
+            )
         tokens = self.text_encoders_per_lang[target_lang](sample.target.text)
         eos_idx = self.text_tokenizer.vocab_info.eos_idx
         tokens = torch.concat([tokens, torch.LongTensor([eos_idx])])
@@ -170,9 +172,9 @@ class UnitYDataLoader:
             return None
         target_lang = sample.target.lang
         if target_lang not in self.unit_encoders_per_lang:
-            self.unit_encoders_per_lang[
-                target_lang
-            ] = self.unit_tokenizer.create_encoder(lang=target_lang)
+            self.unit_encoders_per_lang[target_lang] = (
+                self.unit_tokenizer.create_encoder(lang=target_lang)
+            )
         tokens = self.unit_encoders_per_lang[target_lang](
             torch.LongTensor(sample.target.units).unsqueeze(0)
         )
@@ -195,20 +197,41 @@ class UnitYDataLoader:
         length_s: float = max(wav.shape) / sample_rate
         return length_s > self.batching_config.max_audio_length_sec

+    def _drop_overflow_samples(
+        self, samples_with_fbanks: List[Tuple[LangPairSample, torch.Tensor]]
+    ) -> List[Tuple[LangPairSample, torch.Tensor]]:
+        # sort by fbank length, longest first, so the head element bounds
+        # the padded width of the whole batch
+        samples_with_fbanks = sorted(
+            samples_with_fbanks, key=lambda sb: -sb[1].shape[0]
+        )
+        bwd = samples_with_fbanks[0][1].shape[0]
+        # budget // longest fbank, but always keep at least one sample
+        max_samples_for_batch = max(1, self.max_src_tokens_per_batch // bwd)
+        if max_samples_for_batch < len(samples_with_fbanks):
+            samples_with_fbanks = samples_with_fbanks[:max_samples_for_batch]
+        return samples_with_fbanks
+
     def _prepare_batch(self, raw_samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
         samples = [LangPairSample.from_json(sample) for sample in raw_samples]
         # input speech
         # - filter long audio samples
-        filtered_samples = [sample for sample in samples if not self._is_long_src_audio(sample)]
-        samples = filtered_samples if filtered_samples else [samples[0]]  # keep at least one sample
-        src_tokens_list = [self._get_source_fbank(sample) for sample in samples]
+        filtered_samples = [
+            sample for sample in samples if not self._is_long_src_audio(sample)
+        ]
+        samples = (
+            filtered_samples if filtered_samples else [samples[0]]
+        )  # keep at least one sample
+        with_fbanks = [(sample, self._get_source_fbank(sample)) for sample in samples]
         # - filter NaNs in fbanks
-        with_nans = [fbank.isnan().any().item() for fbank in src_tokens_list]
-        samples = [sample for sample, skip in zip(samples, with_nans) if not skip]
-        assert len(samples) > 0
-        src_tokens_list = [
-            src_toks for src_toks, skip in zip(src_tokens_list, with_nans) if not skip
+        filtered = [
+            (sample, fbank)
+            for sample, fbank in with_fbanks
+            if not fbank.isnan().any().item()
         ]
+        filtered = self._drop_overflow_samples(filtered)
+
+        samples = [sample for sample, _ in filtered]
+        src_tokens_list = [src_tokens for _, src_tokens in filtered]
+        assert len(samples) > 0
         src_tokens = self._batch_tensors(
             src_tokens_list, pad_value=self.batching_config.fbank_feats_pad_idx
         ).to(self.batching_config.float_dtype)
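
For reference, below is a minimal, self-contained sketch of the token-budget capping that the new _drop_overflow_samples method performs. The random fbanks and the sample_* names are hypothetical stand-ins for illustration, not part of the patch:

from typing import Any, List, Tuple

import torch


def drop_overflow_samples(
    samples_with_fbanks: List[Tuple[Any, torch.Tensor]],
    max_src_tokens_per_batch: int,
) -> List[Tuple[Any, torch.Tensor]]:
    # After padding, every fbank in the batch costs as many rows as the
    # longest one, so sort longest-first and let the head set the width.
    samples_with_fbanks = sorted(samples_with_fbanks, key=lambda sb: -sb[1].shape[0])
    longest = samples_with_fbanks[0][1].shape[0]
    # Keep at least one sample even if it alone exceeds the budget.
    max_samples = max(1, max_src_tokens_per_batch // longest)
    return samples_with_fbanks[:max_samples]


if __name__ == "__main__":
    fbanks = [torch.randn(n, 80) for n in (1200, 300, 900, 2500)]
    batch = [(f"sample_{i}", fb) for i, fb in enumerate(fbanks)]
    kept = drop_overflow_samples(batch, max_src_tokens_per_batch=5000)
    # The longest fbank has 2500 frames, so 5000 // 2500 = 2 samples fit:
    # the padded batch is 2 x 2500 = 5000 rows, exactly on budget.
    print([(name, fb.shape[0]) for name, fb in kept])

Because padding stretches every sample to the length of the longest one, a single floor division gives the cap; note that the truncation keeps the longest samples and drops the shorter ones.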