# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import ctypes
import logging
import os
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Union

import torch
from torch import Tensor

from fairseq2.data import (
    CollateOptionsOverride,
    Collater,
    DataPipeline,
    DataPipelineBuilder,
    FileMapper,
)
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
from fairseq2.data.text import SentencePieceEncoder, StrSplitter, read_text
from fairseq2.models.nllb.tokenizer import NllbTokenizer

from m4t_scripts.train.configs import AudioProcessingConfig, DataLoadingConfig

from seamless_communication.models.tokenizer import SPMTokenizer
from seamless_communication.models.unity import (
    UnitTokenizer,
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)

logger = logging.getLogger(__name__)


class SeqsBatch(NamedTuple):
    src_tokens: Optional[Tensor]
    src_lengths: Optional[Tensor]
    target_tokens: Tensor
    prev_output_tokens: Tensor
    target_lengths: Tensor
    prefix_tokens: Optional[Tensor]


class MultimodalSeqsBatch(NamedTuple):
    speech_to_text: SeqsBatch
    text_to_units: SeqsBatch


class UnityDataLoader:
    CPU_DEVICE = torch.device("cpu")
    MANIFEST_EXT = ".tsv"
    MANIFEST_COLUMN_SEP = "\t"
    AUDIO_COLUMN_NAME = "audio"
    TARGET_TEXT_COLUMN = "raw_tgt_text"
    TARGET_UNITS_COLUMN = "tgt_text"
    TARGET_LANG_COLUMN = "tgt_lang"
    ROOT_COLUMN = "_"
    BATCH_WIDTH_STEP = 8

    def __init__(
        self,
        config: DataLoadingConfig,
        rank: int = 0,
        world_size: int = 1,
        target_device: torch.device = CPU_DEVICE,
        float_dtype: torch.dtype = torch.float16,  # training/inference precision
    ):
        self.config = config
        self.rank = rank
        self.world_size = world_size
        self.target_device = target_device
        self.float_dtype = float_dtype
        self._set_mkl_num_threads()
        self.manifest_paths = list(self._iterate_manifest_paths())
        self.text_tokenizer = self._init_text_tokenizer()
        self.unit_tokenizer = self._init_unit_tokenizer()
        self.spm_encoder = SentencePieceEncoder(
            model=self.text_tokenizer.model, suffix_tokens=["</s>"]
        )
        self.text_prefix_tokens = self._build_text_tgt_prefixes()
        self.unit_prefix_tokens = self._build_unit_tgt_prefixes()
        if self.config.fixed_batch_size is None:
            self.tgt_text_batch_shapes = self._calculate_tgt_text_batch_shapes()
        else:
            self.tgt_text_batch_shapes = []
        self.pipeline = self._build_pipeline()

    @classmethod
    def _set_mkl_num_threads(cls):
        """Set MKL num threads to 1 so that we don't get thread explosion."""
        mkl_rt = ctypes.CDLL("libmkl_rt.so")
        mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(1)))
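
    # Note: ctypes.CDLL raises OSError if libmkl_rt.so cannot be resolved, so
    # the call above assumes an MKL-enabled environment with the library on
    # the dynamic linker's search path.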

    def _calculate_tgt_text_batch_shapes(self) -> List[Tuple[int, int]]:
        max_seq_len = self.config.max_tgt_text_tokens_per_sample
        max_tokens_per_batch = self.config.max_tgt_text_tokens_per_batch
        assert max_tokens_per_batch is not None, "max_tokens_per_batch is not set"
        max_bsz = (
            self.config.max_batch_size
            if self.config.max_batch_size is not None
            else max_tokens_per_batch
        )
        step = self.BATCH_WIDTH_STEP
        bucket_sizes = []
        for seq_len in range(step, max(step, max_seq_len) + 1, step):
            bsz = max(1, max_tokens_per_batch // seq_len)
            if bsz > max_bsz:
                continue
            bucket_sizes.append((bsz, seq_len))
        return bucket_sizes
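
    # Illustrative example (hypothetical config values): with
    # max_tgt_text_tokens_per_batch=4096, max_tgt_text_tokens_per_sample=32 and
    # BATCH_WIDTH_STEP=8, _calculate_tgt_text_batch_shapes() returns
    #   [(512, 8), (256, 16), (170, 24), (128, 32)]
    # i.e. batches of short sequences are wide and batches of long sequences
    # are narrow, keeping the token count per batch roughly constant.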

    def _build_text_tgt_prefixes(self) -> Dict[str, List[int]]:
        return {
            lang_tok: self.text_tokenizer.create_encoder(
                lang=lang_tok, mode="target"
            ).prefix_indices.tolist()  # type: ignore
            for lang_tok in self.text_tokenizer.langs
        }

    def _build_unit_tgt_prefixes(self) -> Dict[str, List[int]]:
        assert self.unit_tokenizer.vocab_info.eos_idx is not None
        return {
            lang_tok: [
                self.unit_tokenizer.vocab_info.eos_idx,
                self.unit_tokenizer.lang_to_index(lang_tok),
            ]
            for lang_tok in self.unit_tokenizer.langs
        }  # type: ignore
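
    # Both tables map a language token to its two-token decoder prefix
    # <eos> <lang_tok>: the text prefix comes from the text tokenizer's target
    # encoder, while the unit prefix is assembled from the unit vocabulary
    # directly.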

    def _init_text_tokenizer(self) -> Union[NllbTokenizer, SPMTokenizer]:
        if self.config.text_tokenization.from_model is not None:
            return load_unity_text_tokenizer(self.config.text_tokenization.from_model)
        else:
            assert self.config.text_tokenization.langtoks is not None
            assert self.config.text_tokenization.spm_path is not None
            return SPMTokenizer(
                pathname=self.config.text_tokenization.spm_path,
                langs=self.config.text_tokenization.langtoks,
            )

    def _init_unit_tokenizer(self) -> UnitTokenizer:
        if self.config.unit_tokenization.from_model is not None:
            return load_unity_unit_tokenizer(self.config.unit_tokenization.from_model)
        else:
            raise NotImplementedError("TBD")

    def _load_manifest_list_from_file(self) -> Iterator[str]:
        if self.config.manifest_list_path is not None:
            with open(self.config.manifest_list_path) as fp:
                for line in fp:
                    line = line.split("#")[0].strip()  # allow comments
                    if line:
                        yield line
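
    # The manifest list file is plain text with one manifest name per line;
    # everything after `#` is treated as a comment and blank lines are skipped.
    # A hypothetical example:
    #   train_eng_fra        # primary data
    #   # train_eng_deu      # temporarily disabled
    #   train_fra_eng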

    def _load_raw_manifest_list(self) -> List[str]:
        raw_list = []
        if self.config.manifest_list is not None:
            raw_list += self.config.manifest_list.strip().split(",")
        raw_list += list(self._load_manifest_list_from_file())
        return raw_list

    def _infer_manifest_full_path(self, manifest_name: str) -> str:
        full_path = manifest_name.strip()
        if self.config.manifest_path_prefix is not None:
            full_path = os.path.join(
                self.config.manifest_path_prefix.strip(), full_path
            )
        if not full_path.endswith(self.MANIFEST_EXT) and not os.path.exists(full_path):
            full_path += self.MANIFEST_EXT
        if not os.path.exists(full_path):
            raise FileNotFoundError(f"File not found {full_path}")
        return full_path

    def _iterate_manifest_paths(self, skip_missing_files: bool = True) -> Iterator[str]:
        """Yields full paths to the manifests described in the data config.

        Checks that each file exists. Expects *.tsv files.
        """
        raw_list = self._load_raw_manifest_list()
        for manifest_name in raw_list:
            try:
                full_path = self._infer_manifest_full_path(manifest_name=manifest_name)
            except FileNotFoundError:
                if skip_missing_files:
                    logger.warning(f"Skipping manifest {manifest_name}, file not found")
                    continue
                raise
            yield full_path

    def _read_column_names(self, manifest_path: str) -> List[str]:
        """Gets the order of columns in the manifest file.

        Also checks that all expected columns are present.
        """
        with open(manifest_path, "r") as in_fp:
            column_names = in_fp.readline().strip().split(self.MANIFEST_COLUMN_SEP)
        for column in [
            self.AUDIO_COLUMN_NAME,
            self.TARGET_TEXT_COLUMN,
            self.TARGET_UNITS_COLUMN,
            self.TARGET_LANG_COLUMN,
        ]:
            if column not in column_names:
                raise ValueError(
                    f"Column `{column}` is not present in `{manifest_path}`"
                )
        return column_names

    def _builder_from_manifest(self, manifest_path: str) -> DataPipelineBuilder:
        """Creates a data pipeline builder for the specified manifest_path file."""
        logger.debug(f"Initializing samples loader from {manifest_path}")
        # Memory-map the file and read it in text mode, skipping empty lines
        # (if any) and the header line.
        tsv_lines = (
            read_text(
                pathname=manifest_path,
                encoding="UTF-8",
                rtrim=True,
                skip_empty=True,
                memory_map=True,
            )
            .skip(1)
            .and_return()
        )
        # Assign column names:
        #   line content:         `_`
        #   source manifest path: `manifest_path`
        #   line number:          `lineno`
        line_numbers = DataPipeline.count().and_return()
        filename_const = DataPipeline.constant(manifest_path).and_return()
        pipeline = DataPipeline.zip(
            [tsv_lines, filename_const, line_numbers],
            names=[self.ROOT_COLUMN, "manifest_path", "lineno"],
            zip_to_shortest=True,
        )
        # Read every `world_size`th line, starting from the `rank`th item in the file.
        pipeline.shard(self.rank, self.world_size)
        if self.config.shuffle_window is not None:
            pipeline.shuffle(self.config.shuffle_window)
        # Split each text line into its fields.
        fields = self._read_column_names(manifest_path)
        logger.debug(f"Column names: {fields}")
        txt_splitter = StrSplitter(
            sep=self.MANIFEST_COLUMN_SEP, names=fields, indices=[], exclude=True
        )
        pipeline.map(
            txt_splitter,
            selector=self.ROOT_COLUMN,
            num_parallel_calls=self.config.num_threads,
        )
        # And return the builder for this TSV file.
        return pipeline
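
    # After this step each example is a nested dict, roughly of the form
    # (assuming the manifest has the four required columns):
    #   {
    #       "_": {"audio": "<relative audio path>", "raw_tgt_text": "...",
    #             "tgt_text": "<space-separated unit ids>", "tgt_lang": "..."},
    #       "manifest_path": "<path to this .tsv>",
    #       "lineno": <line counter>,
    #   }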

    def _get_manifest_funnel(self) -> DataPipelineBuilder:
        """Creates a joined pipeline from all manifests.

        Picks samples from per-manifest pipelines in a round-robin order.
        """
        # TODO: add the ability to upsample/downsample manifests
        logger.info(f"Aggregating data from {len(self.manifest_paths)} manifests")
        builders = [
            self._builder_from_manifest(manifest_path=path)
            for path in self.manifest_paths
        ]
        pipelines = [builder.and_return() for builder in builders]
        return DataPipeline.round_robin(pipelines=pipelines)

    def _attach_audio(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
        """Attaches audio waveforms and fbanks from the linked audio files."""
        audio_selector = f"{self.ROOT_COLUMN}.{self.AUDIO_COLUMN_NAME}"
        audio_data_selector = f"{audio_selector}.data"
        # Memory-map each audio file.
        map_file = FileMapper(self.config.audio.audio_root_dir, cached_fd_count=100)
        builder.map(
            map_file,
            selector=audio_selector,
            num_parallel_calls=self.config.num_threads,
        )
        # Decode each mmap'ed audio file using libsndfile.
        decode_audio = AudioDecoder(dtype=torch.float32)
        builder.map(
            decode_audio,
            selector=audio_data_selector,
            num_parallel_calls=self.config.num_threads,
        )
        # And convert the waveforms to log-mel filterbanks.
        convert_to_fbank = WaveformToFbankConverter(
            num_mel_bins=self.config.audio.fbanks_num_mel_bins,
            waveform_scale=self.config.audio.fbanks_waveform_scale,
            channel_last=True,  # the audio channel is the last dimension in the waveform
            standardize=self.config.audio.fbanks_standardize_audio,
            keep_waveform=False,
            device=self.CPU_DEVICE,  # avoid uncontrolled memory consumption on GPUs
            dtype=self.float_dtype,
        )
        builder.map(
            convert_to_fbank,
            selector=audio_data_selector,
            num_parallel_calls=self.config.num_threads,
        )
        return builder
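
    # After `_attach_audio`, sample["_"]["audio"]["data"] contains (at least)
    # "sample_rate" and an "fbank" tensor of shape (num_frames, num_mel_bins);
    # `_get_input_audio_seconds` and `_nans_in_fbanks` read exactly these keys.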

    def _attach_target_tokens(
        self, builder: DataPipelineBuilder
    ) -> DataPipelineBuilder:
        # Convert `raw_tgt_text` to (full) target tokenized sequences:
        #   <eos> <lang_tok> <tokens ...> <eos>
        # Lang tokens change between rows, so we can't use a static encoder.
        builder.map(
            [self.spm_encoder],
            selector=f"{self.ROOT_COLUMN}.{self.TARGET_TEXT_COLUMN}",
            num_parallel_calls=self.config.num_threads,
        )
        # Convert the `tgt_text` field into a unit tensor + EOS.
        # TODO: We should use the unit tokenizer.
        # Motivation for the current implementation:
        #   1) lang_tok can change between rows. If we want to attach
        #      lang_token_id here, we need a way to join values from two columns.
        #   2) StrToTensorConverter doesn't allow suffix tokens, and adding them
        #      later is less convenient.
        #   3) Not a computational blocker.
        convert_to_units = lambda units_str: (  # noqa: E731
            torch.LongTensor(
                [
                    int(unit_id) + 4
                    for unit_id in units_str.rstrip().bytes().decode("utf-8").split()
                ]
                + [self.unit_tokenizer.vocab_info.eos_idx]
            )
        )
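        # Note on the `+ 4` above: raw unit ids in the manifest start at 0,
        # while the first entries of the unit vocabulary are reserved for
        # special symbols, so each id is shifted past them (four reserved
        # tokens is an assumption based on UnitTokenizer's vocabulary layout).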
        builder.map(
            [convert_to_units],
            selector=f"{self.ROOT_COLUMN}.{self.TARGET_UNITS_COLUMN}",
            num_parallel_calls=self.config.num_threads,
        )
        # Prefixes for tokenized texts and speech units (<eos> <lang_tok>),
        # stacked as [text_prefix, unit_prefix] per sample.
        prefix_builder = lambda lang_tok: torch.LongTensor(  # noqa: E731
            [
                self.text_prefix_tokens[lang_tok.bytes().decode("utf8")],
                self.unit_prefix_tokens[lang_tok.bytes().decode("utf8")],
            ]
        )
        builder.map(
            [prefix_builder],
            selector=f"{self.ROOT_COLUMN}.{self.TARGET_LANG_COLUMN}",
            num_parallel_calls=self.config.num_threads,
        )
        return builder

    def _get_input_audio_seconds(self, sample: Any) -> float:
        audio_data = sample[self.ROOT_COLUMN][self.AUDIO_COLUMN_NAME]["data"]
        input_audio_sample_rate = audio_data["sample_rate"]
        num_fbanks = max(audio_data["fbank"].shape)  # not guessing the dim order
        # TODO: clarify where the '* 2' comes from
        waveform_length = num_fbanks * self.config.audio.fbanks_num_mel_bins * 2
        input_audio_seconds = waveform_length / input_audio_sample_rate
        return input_audio_seconds
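
    # A plausible reading of the `* 2` (unverified, hence the TODO above): with
    # 80 mel bins, a 10 ms frame shift and 16 kHz audio, each fbank frame spans
    # 160 = 80 * 2 samples, so the product approximates the waveform length in
    # samples before dividing by the sample rate.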

    def _is_long_sample(self, sample: Any) -> bool:
        # input audio length
        if (
            self._get_input_audio_seconds(sample)
            > self.config.max_seconds_per_input_audio
        ):
            return True
        # target text tokens
        num_tgt_text_tokens = sample[self.ROOT_COLUMN][self.TARGET_TEXT_COLUMN].shape[-1]
        if num_tgt_text_tokens > self.config.max_tgt_text_tokens_per_sample:
            return True
        # target units
        num_tgt_units = sample[self.ROOT_COLUMN][self.TARGET_UNITS_COLUMN].shape[-1]
        if num_tgt_units > self.config.max_units_per_sample:
            return True
        return False

    def _nans_in_fbanks(self, sample: Any) -> bool:
        """Tells whether the sample's fbank contains NaNs."""
        fbank = sample[self.ROOT_COLUMN][self.AUDIO_COLUMN_NAME]["data"]["fbank"]
        has_nans: bool = torch.any(torch.isnan(fbank)).item()  # type: ignore
        if has_nans:
            logger.warning("Sample fbank contains NaNs. Skipping")
        return has_nans

    def _filter_samples(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
        # Drop:
        #   - "long" samples
        #   - samples with fbanks that contain NaNs
        builder.filter(
            lambda sample: not self._is_long_sample(sample)
            and not self._nans_in_fbanks(sample)
        )
        return builder

    def _batch_samples(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
        if self.config.fixed_batch_size is not None:
            builder.bucket(bucket_size=self.config.fixed_batch_size)
        elif self.tgt_text_batch_shapes is not None:
            builder.bucket_by_length(
                self.tgt_text_batch_shapes,
                selector=f"{self.ROOT_COLUMN}.{self.TARGET_TEXT_COLUMN}",
            )
        else:
            raise ValueError("Unclear batching strategy")
        # Collate bucketed elements into a batch.
        collater = Collater(
            pad_to_multiple=1,
            overrides=[
                CollateOptionsOverride(
                    selector=f"{self.ROOT_COLUMN}.{self.AUDIO_COLUMN_NAME}.data.fbank",
                    pad_idx=self.config.fbank_feats_pad_idx,
                ),
                CollateOptionsOverride(
                    selector=f"{self.ROOT_COLUMN}.{self.TARGET_TEXT_COLUMN}",
                    pad_idx=self.text_tokenizer.vocab_info.pad_idx,
                ),
                CollateOptionsOverride(
                    selector=f"{self.ROOT_COLUMN}.{self.TARGET_UNITS_COLUMN}",
                    pad_idx=self.unit_tokenizer.vocab_info.pad_idx,
                ),
            ],
        )
        builder.map(collater, num_parallel_calls=self.config.num_threads)
        if self.config.prefech_batches is not None:
            builder.prefetch(self.config.prefech_batches)
        return builder
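
    # The Collater replaces each bucketed list of tensors with a dict holding a
    # padded "seqs" tensor and a "seq_lens" tensor; the batch-assembly helpers
    # below rely on exactly these two keys.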

    def _build_pipeline(self) -> DataPipeline:
        data = self._get_manifest_funnel()
        data = self._attach_audio(data)
        data = self._attach_target_tokens(data)
        data = self._filter_samples(data)
        batches = self._batch_samples(data)
        return batches.and_return()

    def _gen_prev_toks_target_toks_target_lens(
        self, seqs: Any, prefix_tokens: torch.Tensor, pad_idx: int, eos_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # <eos> <lang_tok> ... <eos> <pad>*
        tokens = torch.cat((prefix_tokens, seqs["seqs"]), 1)
        target_lengths = seqs["seq_lens"] + 1  # + <lang_tok>
        prev_output_tokens = torch.clone(tokens)
        # Replace the trailing <eos> with <pad> (keeping the leading one) and
        # drop the last column.
        mask = prev_output_tokens == eos_idx
        mask[:, 0] = 0
        prev_output_tokens[mask] = pad_idx
        prev_output_tokens = prev_output_tokens[:, :-1]
        target_tokens = tokens[:, 1:]
        assert torch.equal(
            torch.count_nonzero(prev_output_tokens != pad_idx, dim=1), target_lengths
        )
        assert torch.equal(
            torch.count_nonzero(target_tokens != pad_idx, dim=1), target_lengths
        )
        return prev_output_tokens, target_tokens, target_lengths
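
    # Worked example for the method above, with hypothetical ids
    # (eos=2, pad=1, lang=L) and a collated row [a, b, eos] of length 3:
    #   tokens             = [eos, L, a, b, eos]
    #   prev_output_tokens = [eos, L, a, b]    # trailing eos -> pad, last column dropped
    #   target_tokens      = [L, a, b, eos]
    #   target_lengths     = 3 + 1 = 4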

    def _get_text_to_units_batch(self, raw_batch: Any) -> SeqsBatch:
        root = raw_batch[self.ROOT_COLUMN]
        seqs = root[self.TARGET_UNITS_COLUMN]
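        # `prefix_builder` stacked [text_prefix, unit_prefix] per sample, so
        # row 1 below selects the unit prefix (row 0, the text prefix, is used
        # in `_get_speech_to_text_batch`).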
        prefix_tokens = root[self.TARGET_LANG_COLUMN][:, 1, :]
        pad_idx = self.unit_tokenizer.vocab_info.pad_idx
        eos_idx = self.unit_tokenizer.vocab_info.eos_idx
        assert pad_idx is not None
        assert eos_idx is not None
        (
            prev_output_tokens,
            target_tokens,
            target_lengths,
        ) = self._gen_prev_toks_target_toks_target_lens(
            seqs=seqs,
            prefix_tokens=prefix_tokens,
            pad_idx=pad_idx,
            eos_idx=eos_idx,
        )
        return SeqsBatch(
            src_tokens=None,
            src_lengths=None,
            target_tokens=target_tokens.to(self.target_device),
            prev_output_tokens=prev_output_tokens.to(self.target_device),
            target_lengths=target_lengths.to(self.target_device),
            prefix_tokens=prefix_tokens.to(self.target_device),
        )

    def _get_speech_src_tokens_and_lengths(
        self, raw_batch: Any
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        fbanks = raw_batch[self.ROOT_COLUMN][self.AUDIO_COLUMN_NAME]["data"]["fbank"]
        return fbanks["seqs"].to(self.float_dtype), fbanks["seq_lens"]

    def _get_speech_to_text_batch(self, raw_batch: Any) -> SeqsBatch:
        root = raw_batch[self.ROOT_COLUMN]
        seqs = root[self.TARGET_TEXT_COLUMN]
        prefix_tokens = root[self.TARGET_LANG_COLUMN][:, 0, :]
        pad_idx = self.text_tokenizer.vocab_info.pad_idx
        assert pad_idx is not None
        eos_idx = self.text_tokenizer.vocab_info.eos_idx
        assert eos_idx is not None
        (
            prev_output_tokens,
            target_tokens,
            target_lengths,
        ) = self._gen_prev_toks_target_toks_target_lens(
            seqs=seqs,
            prefix_tokens=prefix_tokens,
            pad_idx=pad_idx,
            eos_idx=eos_idx,
        )
        src_tokens, src_lengths = self._get_speech_src_tokens_and_lengths(
            raw_batch=raw_batch
        )
        return SeqsBatch(
            src_tokens=src_tokens.to(self.target_device),
            src_lengths=src_lengths.to(self.target_device),
            target_tokens=target_tokens.to(self.target_device),
            prev_output_tokens=prev_output_tokens.to(self.target_device),
            target_lengths=target_lengths.to(self.target_device),
            prefix_tokens=prefix_tokens.to(self.target_device),
        )

    def _convert_to_multimodal_seqs_batch(self, raw_batch: Any) -> MultimodalSeqsBatch:
        return MultimodalSeqsBatch(
            speech_to_text=self._get_speech_to_text_batch(raw_batch=raw_batch),
            text_to_units=self._get_text_to_units_batch(raw_batch=raw_batch),
        )

    def iterate_batches(self) -> Iterator[MultimodalSeqsBatch]:
        for raw_batch in self.pipeline:
            yield self._convert_to_multimodal_seqs_batch(raw_batch)

    def reset(self) -> None:
        self.pipeline.reset()


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format=f"%(asctime)s %(levelname)s -- %(name)s.{os.getpid()}: %(message)s",
    )
    config = DataLoadingConfig(
        audio=AudioProcessingConfig(
            audio_root_dir="/fsx-ust/data/audio_zips/",
        ),
        manifest_path_prefix="/fsx-ust/spopuri/datasets/S2ST/V1/M4T_V1_phase2/primary",
        manifest_list_path="/data/home/mavlyutov/train_manifests.txt",
        shuffle_window=1000,
        num_threads=5,
    )
    loader = UnityDataLoader(config=config, target_device=torch.device("cpu"))
    for idx, batch in enumerate(loader.iterate_batches()):
        if idx % 10 == 0:
            assert batch.speech_to_text.src_tokens is not None
            print(batch.speech_to_text.src_tokens.shape)
            logger.info(f".. pulled {idx} batches")
        if idx > 1000:
            break