@@ -76,7 +76,9 @@ class UnityDataLoader:
         self.manifest_paths = list(self._iterate_manifest_paths())
         self.text_tokenizer = self._init_text_tokenizer()
         self.unit_tokenizer = self._init_unit_tokenizer()
-        self.spm_encoder = SentencePieceEncoder(model=self.text_tokenizer.model, suffix_tokens=["</s>"])
+        self.spm_encoder = SentencePieceEncoder(
+            model=self.text_tokenizer.model, suffix_tokens=["</s>"]
+        )
         self.text_prefix_tokens = self._build_text_tgt_prefixes()
         self.unit_prefix_tokens = self._build_unit_tgt_prefixes()
         if self.config.fixed_batch_size is None:
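The `suffix_tokens=["</s>"]` argument makes the encoder append EOS to every encoded target sequence. A rough standalone equivalent with the plain `sentencepiece` package (hypothetical model path; this is a sketch, not the fairseq2 `SentencePieceEncoder` API):

```python
import sentencepiece as spm

# Hypothetical model path; stands in for self.text_tokenizer.model
sp = spm.SentencePieceProcessor(model_file="spm.model")

def encode_with_eos(text: str) -> list:
    # Token ids for the text with EOS ("</s>") appended, mirroring
    # suffix_tokens=["</s>"] in the loader's encoder
    return sp.encode(text, out_type=int) + [sp.eos_id()]
```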
@@ -88,18 +90,25 @@ class UnityDataLoader:
 
     @classmethod
     def _set_mkl_num_threads(cls):
-        """ Setting mkl num threads to 1, so that we don't get thread explosion."""
-        mkl_rt = ctypes.CDLL('libmkl_rt.so')
+        """Setting mkl num threads to 1, so that we don't get thread explosion."""
+        mkl_rt = ctypes.CDLL("libmkl_rt.so")
         mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(1)))
 
     def _calculate_tgt_text_batch_shapes(self) -> List[Tuple[int, int]]:
         max_seq_len = self.config.max_tgt_text_tokens_per_sample
         max_tokens_per_batch = self.config.max_tgt_text_tokens_per_batch
         assert max_tokens_per_batch is not None, "max_tokens_per_batch is not set"
+        max_bsz = (
+            self.config.max_batch_size
+            if self.config.max_batch_size is not None
+            else max_tokens_per_batch
+        )
         step = self.BATCH_WIDTH_STEP
         bucket_sizes = []
         for seq_len in range(step, max(step, max_seq_len) + 1, step):
             bsz = max(1, max_tokens_per_batch // seq_len)
+            if bsz > max_bsz:
+                continue
             bucket_sizes.append((bsz, seq_len))
         return bucket_sizes
 
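The new `max_bsz` guard caps how many rows a bucket of short sequences may hold. A minimal standalone sketch of the bucketing rule, with hypothetical values in place of the config fields and `BATCH_WIDTH_STEP`:

```python
from typing import List, Optional, Tuple

def calculate_batch_shapes(
    max_seq_len: int,
    max_tokens_per_batch: int,
    max_batch_size: Optional[int] = None,
    step: int = 8,  # stand-in for BATCH_WIDTH_STEP
) -> List[Tuple[int, int]]:
    """Enumerate (batch_size, seq_len) buckets under a token budget."""
    max_bsz = max_batch_size if max_batch_size is not None else max_tokens_per_batch
    bucket_sizes = []
    for seq_len in range(step, max(step, max_seq_len) + 1, step):
        bsz = max(1, max_tokens_per_batch // seq_len)
        if bsz > max_bsz:
            # short sequences would otherwise produce huge batches
            continue
        bucket_sizes.append((bsz, seq_len))
    return bucket_sizes

# e.g. a 1000-token budget, sequences up to 32 tokens, at most 64 rows per batch
print(calculate_batch_shapes(32, 1000, max_batch_size=64))
# [(62, 16), (41, 24), (31, 32)] — the 8-token bucket is skipped since 125 > 64
```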
@@ -128,7 +137,8 @@ class UnityDataLoader:
         assert self.config.text_tokenization.langtoks is not None
         assert self.config.text_tokenization.spm_path is not None
         return SPMTokenizer(
-            pathname=self.config.text_tokenization.spm_path, langs=self.config.text_tokenization.langtoks
+            pathname=self.config.text_tokenization.spm_path,
+            langs=self.config.text_tokenization.langtoks,
         )
 
     def _init_unit_tokenizer(self) -> UnitTokenizer:
@@ -154,7 +164,9 @@ class UnityDataLoader:
     def _infer_manifest_full_path(self, manifest_name: str) -> str:
         full_path = manifest_name.strip()
         if self.config.manifest_path_prefix is not None:
-            full_path = os.path.join(self.config.manifest_path_prefix.strip(), full_path)
+            full_path = os.path.join(
+                self.config.manifest_path_prefix.strip(), full_path
+            )
         if not full_path.endswith(self.MANIFEST_EXT) and not os.path.exists(full_path):
             full_path += self.MANIFEST_EXT
         if not os.path.exists(full_path):
@@ -188,7 +200,9 @@ class UnityDataLoader:
             self.TARGET_LANG_COLUMN,
         ]:
             if column not in column_names:
-                raise ValueError(f"Column `{column}` is not present in `{manifest_path}` ")
+                raise ValueError(
+                    f"Column `{column}` is not present in `{manifest_path}` "
+                )
         return column_names
 
     def _builder_from_manifest(self, manifest_path: str) -> DataPipelineBuilder:
@@ -230,7 +244,9 @@ class UnityDataLoader:
         # Split each text line into its fields.
         fields = self._read_column_names(manifest_path)
         logger.debug(f"Column names: {fields}")
-        txt_splitter = StrSplitter(sep=self.MANIFEST_COLUMN_SEP, names=fields, indices=[], exclude=True)
+        txt_splitter = StrSplitter(
+            sep=self.MANIFEST_COLUMN_SEP, names=fields, indices=[], exclude=True
+        )
         pipeline.map(
             txt_splitter,
             selector=self.ROOT_COLUMN,
@@ -244,7 +260,10 @@ class UnityDataLoader:
         Picks samples from per-manifest pipelines in a round-robin order"""
         # TODO: add the ability to upsample/downsample manifests
         logger.info(f"Aggregating data from {len(self.manifest_paths)} manifests")
-        builders = [self._builder_from_manifest(manifest_path=path) for path in self.manifest_paths]
+        builders = [
+            self._builder_from_manifest(manifest_path=path)
+            for path in self.manifest_paths
+        ]
         pipelines = [builder.and_return() for builder in builders]
         return DataPipeline.round_robin(pipelines=pipelines)
 
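`DataPipeline.round_robin` interleaves the per-manifest pipelines one sample at a time. A simplified pure-Python sketch of the sampling order (fairseq2's actual exhaustion semantics may differ):

```python
from itertools import cycle
from typing import Iterable, Iterator, List

def round_robin(pipelines: List[Iterable]) -> Iterator:
    """Yield one sample from each pipeline in turn (simplified:
    stops when the shortest pipeline is exhausted)."""
    iterators = [iter(p) for p in pipelines]
    for it in cycle(iterators):
        try:
            yield next(it)
        except StopIteration:
            return

manifests = [["a1", "a2"], ["b1", "b2"], ["c1", "c2"]]
print(list(round_robin(manifests)))  # ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
```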
@@ -276,7 +295,7 @@ class UnityDataLoader:
             channel_last=True,  # audio channel is the last dimension in the waveform
             standardize=self.config.audio.fbanks_standardize_audio,
             keep_waveform=False,
-            device=self.target_device,
+            device=self.CPU_DEVICE,  # avoid uncontrolled memory cons on GPUs
             dtype=self.float_dtype,
         )
         builder.map(
@@ -286,7 +305,9 @@ class UnityDataLoader:
         )
         return builder
 
-    def _attach_target_tokens(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
+    def _attach_target_tokens(
+        self, builder: DataPipelineBuilder
+    ) -> DataPipelineBuilder:
         # Convert `raw_tgt_text` to (full) target tokenized sequences:
         # <eos> <lang_tok> <tokens .. > <eos>
         # Lang tokens change between rows, so can't use static encoder
@@ -305,7 +326,10 @@ class UnityDataLoader:
         # 3) Not a computational blocker
         convert_to_units = lambda units_str: (  # noqa: E731
             torch.LongTensor(
-                [int(unit_id) + 4 for unit_id in units_str.rstrip().bytes().decode("utf-8").split()]
+                [
+                    int(unit_id) + 4
+                    for unit_id in units_str.rstrip().bytes().decode("utf-8").split()
+                ]
                 + [self.unit_tokenizer.vocab_info.eos_idx]
             )
         )
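`convert_to_units` parses the whitespace-separated unit ids from the manifest, shifts them past what is presumably a block of reserved special-symbol indices (the `+ 4`), and appends the unit EOS. A standalone sketch with placeholder constants (in the pipeline the value arrives as a raw buffer, hence the `.bytes().decode("utf-8")` step):

```python
import torch

# Hypothetical special-symbol layout; the real values come from
# self.unit_tokenizer.vocab_info in the loader.
UNIT_OFFSET = 4   # first 4 indices assumed reserved for special symbols
EOS_IDX = 2       # placeholder for vocab_info.eos_idx

def convert_to_units(units_str: str) -> torch.Tensor:
    """Parse a unit-id string, shift ids past the reserved block, append EOS."""
    unit_ids = [int(unit_id) + UNIT_OFFSET for unit_id in units_str.rstrip().split()]
    return torch.LongTensor(unit_ids + [EOS_IDX])

print(convert_to_units("17 5 5 321 "))  # tensor([ 21,   9,   9, 325,   2])
```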
@@ -340,23 +364,43 @@ class UnityDataLoader:
 
     def _is_long_sample(self, sample: Any) -> bool:
         # input audio length
-        if self._get_input_audio_seconds(sample) > self.config.max_seconds_per_input_audio:
+        if (
+            self._get_input_audio_seconds(sample)
+            > self.config.max_seconds_per_input_audio
+        ):
             return True
 
         # target text tokens
-        num_tgt_text_tokens = sample[self.ROOT_COLUMN][self.TARGET_TEXT_COLUMN].shape[-1]
+        num_tgt_text_tokens = sample[self.ROOT_COLUMN][self.TARGET_TEXT_COLUMN].shape[
+            -1
+        ]
         if num_tgt_text_tokens > self.config.max_tgt_text_tokens_per_sample:
             return True
 
         # target units
-        num_tgt_units = sample[self.ROOT_COLUMN][self.TARGET_UNITS_COLUMN].shape[-1]  # target units
+        num_tgt_units = sample[self.ROOT_COLUMN][self.TARGET_UNITS_COLUMN].shape[
+            -1
+        ]  # target units
         if num_tgt_units > self.config.max_units_per_sample:
             return True
         return False
 
+    def _nans_in_fbanks(self, sample: Any) -> bool:
+        """Tells if NaNs present in fbank"""
+        fbank = sample[self.ROOT_COLUMN][self.AUDIO_COLUMN_NAME]["data"]["fbank"]
+        has_nans: bool = torch.any(torch.isnan(fbank)).item()  # type: ignore
+        if has_nans:
+            logger.warning("Sample fbank contains NaNs. Skipping")
+        return has_nans
+
     def _filter_samples(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
-        # Drop long samples
-        builder.filter(lambda sample: not self._is_long_sample(sample))
+        # Drop:
+        # - "long" samples
+        # - samples with fbanks that contain NaNs
+        builder.filter(
+            lambda sample: not self._is_long_sample(sample)
+            and not self._nans_in_fbanks(sample)
+        )
         return builder
 
     def _batch_samples(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
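The filter now also drops samples whose fbanks contain NaNs. A toy illustration of the predicate on a stand-in sample layout (hypothetical key names in place of the loader's column constants):

```python
import math
import torch

# Toy stand-ins for self.ROOT_COLUMN / self.AUDIO_COLUMN_NAME
ROOT, AUDIO, FBANK = "root", "audio", "fbank"

def nans_in_fbanks(sample: dict) -> bool:
    """True if any NaN is present in the sample's fbank tensor."""
    fbank = sample[ROOT][AUDIO]["data"][FBANK]
    return bool(torch.any(torch.isnan(fbank)).item())

good = {ROOT: {AUDIO: {"data": {FBANK: torch.zeros(3, 80)}}}}
bad = {ROOT: {AUDIO: {"data": {FBANK: torch.tensor([[0.0, math.nan]])}}}}
kept = [s for s in [good, bad] if not nans_in_fbanks(s)]
print(len(kept))  # 1 — the NaN sample is dropped
```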
@@ -415,8 +459,12 @@ class UnityDataLoader:
         prev_output_tokens = prev_output_tokens[:, :-1]
 
         target_tokens = tokens[:, 1:]
-        assert torch.equal(torch.count_nonzero(prev_output_tokens != pad_idx, dim=1), target_lengths)
-        assert torch.equal(torch.count_nonzero(target_tokens != pad_idx, dim=1), target_lengths)
+        assert torch.equal(
+            torch.count_nonzero(prev_output_tokens != pad_idx, dim=1), target_lengths
+        )
+        assert torch.equal(
+            torch.count_nonzero(target_tokens != pad_idx, dim=1), target_lengths
+        )
         return prev_output_tokens, target_tokens, target_lengths
 
     def _get_text_to_units_batch(self, raw_batch: Any) -> SeqsBatch:
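The asserts check that the teacher-forcing shift preserves the non-pad token counts. A toy, unpadded illustration (in the real batcher `prev_output_tokens` is prepared upstream so the invariant also holds for padded rows):

```python
import torch

pad_idx = 1
# Toy single row without padding: <eos>=2 <lang>=5 t1=10 t2=11 <eos>=2
tokens = torch.tensor([[2, 5, 10, 11, 2]])
target_lengths = torch.tensor([4])

prev_output_tokens = tokens[:, :-1]  # decoder input: <eos> <lang> t1 t2
target_tokens = tokens[:, 1:]        # ground truth:  <lang> t1 t2 <eos>

# The loader's asserts: both views hold `target_lengths` non-pad tokens per row
assert torch.equal(
    torch.count_nonzero(prev_output_tokens != pad_idx, dim=1), target_lengths
)
assert torch.equal(
    torch.count_nonzero(target_tokens != pad_idx, dim=1), target_lengths
)
```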
@@ -448,7 +496,9 @@ class UnityDataLoader:
             prefix_tokens=prefix_tokens.to(self.target_device),
         )
 
-    def _get_speech_src_tokens_and_lengths(self, raw_batch: Any) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _get_speech_src_tokens_and_lengths(
+        self, raw_batch: Any
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         fbanks = raw_batch[self.ROOT_COLUMN][self.AUDIO_COLUMN_NAME]["data"]["fbank"]
         return fbanks["seqs"].to(self.float_dtype), fbanks["seq_lens"]
 
@@ -471,7 +521,9 @@ class UnityDataLoader:
             pad_idx=pad_idx,
             eos_idx=eos_idx,
         )
-        src_tokens, src_lengths = self._get_speech_src_tokens_and_lengths(raw_batch=raw_batch)
+        src_tokens, src_lengths = self._get_speech_src_tokens_and_lengths(
+            raw_batch=raw_batch
+        )
 
         return SeqsBatch(
             src_tokens=src_tokens.to(self.target_device),