# dataloader.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
  6. import logging
  7. import os
  8. from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Union
  9. import ctypes
  10. import torch
  11. from m4t_scripts.train.configs import AudioProcessingConfig, DataLoadingConfig
  12. from torch import Tensor
  13. from fairseq2.data import (
  14. CollateOptionsOverride,
  15. Collater,
  16. DataPipeline,
  17. DataPipelineBuilder,
  18. FileMapper,
  19. )
  20. from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
  21. from fairseq2.data.text import SentencePieceEncoder, StrSplitter, read_text
  22. from fairseq2.models.nllb.tokenizer import NllbTokenizer
  23. from seamless_communication.models.tokenizer import SPMTokenizer
  24. from seamless_communication.models.unity import (
  25. UnitTokenizer,
  26. load_unity_text_tokenizer,
  27. load_unity_unit_tokenizer,
  28. )
  29. logger = logging.getLogger(__name__)
  30. class SeqsBatch(NamedTuple):
  31. src_tokens: Optional[Tensor]
  32. src_lengths: Optional[Tensor]
  33. target_tokens: Tensor
  34. prev_output_tokens: Tensor
  35. target_lengths: Tensor
  36. prefix_tokens: Optional[Tensor]
  37. class MultimodalSeqsBatch(NamedTuple):
  38. speech_to_text: SeqsBatch
  39. text_to_units: SeqsBatch
  40. class UnityDataLoader:
  41. CPU_DEVICE = torch.device("cpu")
  42. MANIFEST_EXT = ".tsv"
  43. MANIFEST_COLUMN_SEP = "\t"
  44. AUDIO_COLUMN_NAME = "audio"
  45. TARGET_TEXT_COLUMN = "raw_tgt_text"
  46. TARGET_UNITS_COLUMN = "tgt_text"
  47. TARGET_LANG_COLUMN = "tgt_lang"
  48. ROOT_COLUMN = "_"
  49. BATCH_WIDTH_STEP = 8
  50. def __init__(
  51. self,
  52. config: DataLoadingConfig,
  53. rank: int = 0,
  54. world_size: int = 1,
  55. target_device: torch.device = CPU_DEVICE,
  56. float_dtype: torch.dtype = torch.float16, # training/inference precision
  57. ):
  58. self.config = config
  59. self.rank = rank
  60. self.world_size = world_size
  61. self.target_device = target_device
  62. self.float_dtype = float_dtype
  63. self._set_mkl_num_threads()
  64. self.manifest_paths = list(self._iterate_manifest_paths())
  65. self.text_tokenizer = self._init_text_tokenizer()
  66. self.unit_tokenizer = self._init_unit_tokenizer()
  67. self.spm_encoder = SentencePieceEncoder(model=self.text_tokenizer.model, suffix_tokens=["</s>"])
  68. self.text_prefix_tokens = self._build_text_tgt_prefixes()
  69. self.unit_prefix_tokens = self._build_unit_tgt_prefixes()
  70. if self.config.fixed_batch_size is None:
  71. self.tgt_text_batch_shapes = self._calculate_tgt_text_batch_shapes()
  72. else:
  73. self.tgt_text_batch_shapes = []
  74. self.pipeline = self._build_pipeline()
  75. @classmethod
  76. def _set_mkl_num_threads(cls):
  77. """ Setting mkl num threads to 1, so that we don't get thread explosion."""
  78. mkl_rt = ctypes.CDLL('libmkl_rt.so')
  79. mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(1)))
  80. def _calculate_tgt_text_batch_shapes(self) -> List[Tuple[int, int]]:
  81. max_seq_len = self.config.max_tgt_text_tokens_per_sample
  82. max_tokens_per_batch = self.config.max_tgt_text_tokens_per_batch
  83. assert max_tokens_per_batch is not None, "max_tokens_per_batch is not set"
  84. step = self.BATCH_WIDTH_STEP
  85. bucket_sizes = []
  86. for seq_len in range(step, max(step, max_seq_len) + 1, step):
  87. bsz = max(1, max_tokens_per_batch // seq_len)
  88. bucket_sizes.append((bsz, seq_len))
  89. return bucket_sizes
  90. def _build_text_tgt_prefixes(self) -> Dict[str, List[int]]:
  91. return {
  92. lang_tok: self.text_tokenizer.create_encoder(
  93. lang=lang_tok, mode="target"
  94. ).prefix_indices.tolist() # type:ignore
  95. for lang_tok in self.text_tokenizer.langs
  96. }
  97. def _build_unit_tgt_prefixes(self) -> Dict[str, List[int]]:
  98. assert self.unit_tokenizer.vocab_info.eos_idx is not None
  99. return {
  100. lang_tok: [
  101. self.unit_tokenizer.vocab_info.eos_idx,
  102. self.unit_tokenizer.lang_to_index(lang_tok),
  103. ]
  104. for lang_tok in self.unit_tokenizer.langs
  105. } # type: ignore
  106. def _init_text_tokenizer(self) -> Union[NllbTokenizer, SPMTokenizer]:
  107. if self.config.text_tokenization.from_model is not None:
  108. return load_unity_text_tokenizer(self.config.text_tokenization.from_model)
  109. else:
  110. assert self.config.text_tokenization.langtoks is not None
  111. assert self.config.text_tokenization.spm_path is not None
  112. return SPMTokenizer(
  113. pathname=self.config.text_tokenization.spm_path, langs=self.config.text_tokenization.langtoks
  114. )
  115. def _init_unit_tokenizer(self) -> UnitTokenizer:
  116. if self.config.unit_tokenization.from_model is not None:
  117. return load_unity_unit_tokenizer(self.config.unit_tokenization.from_model)
  118. else:
  119. raise NotImplementedError("TBD")
  120. def _load_manifest_list_from_file(self) -> Iterator[str]:
  121. if self.config.manifest_list_path is not None:
  122. for line in open(self.config.manifest_list_path).readlines():
  123. line = line.split("#")[0].strip() # allow comments
  124. if line:
  125. yield line
  126. def _load_raw_manifest_list(self) -> List[str]:
  127. raw_list = []
  128. if self.config.manifest_list is not None:
  129. raw_list += self.config.manifest_list.strip().split(",")
  130. raw_list += list(self._load_manifest_list_from_file())
  131. return raw_list
  132. def _infer_manifest_full_path(self, manifest_name: str) -> str:
  133. full_path = manifest_name.strip()
  134. if self.config.manifest_path_prefix is not None:
  135. full_path = os.path.join(self.config.manifest_path_prefix.strip(), full_path)
  136. if not full_path.endswith(self.MANIFEST_EXT) and not os.path.exists(full_path):
  137. full_path += self.MANIFEST_EXT
  138. if not os.path.exists(full_path):
  139. raise FileNotFoundError(f"File not found {full_path}")
  140. return full_path
  141. def _iterate_manifest_paths(self, skip_missing_files: bool = True) -> Iterator[str]:
  142. """Yields full paths to manifests described in the data config.
  143. Check that each file exist.
  144. Expects *.tsv files"""
  145. raw_list = self._load_raw_manifest_list()
  146. for manifest_name in raw_list:
  147. try:
  148. full_path = self._infer_manifest_full_path(manifest_name=manifest_name)
  149. except FileNotFoundError:
  150. if skip_missing_files:
  151. logger.warning(f"Skipping manifest {manifest_name}, file not found")
  152. continue
  153. raise
  154. yield full_path
  155. def _read_column_names(self, manifest_path: str) -> List[str]:
  156. """Gets the order of columns in the manifest file.
  157. Also checks that expected columns are present."""
  158. with open(manifest_path, "r") as in_fp:
  159. column_names = in_fp.readline().strip().split("\t")
  160. for column in [
  161. self.AUDIO_COLUMN_NAME,
  162. self.TARGET_TEXT_COLUMN,
  163. self.TARGET_UNITS_COLUMN,
  164. self.TARGET_LANG_COLUMN,
  165. ]:
  166. if column not in column_names:
  167. raise ValueError(f"Column `{column}` is not present in `{manifest_path}` ")
  168. return column_names
  169. def _builder_from_manifest(self, manifest_path: str) -> DataPipelineBuilder:
  170. """Creates a data pipeline builder for the specified manifest_path file."""
  171. logger.debug(f"Initialiazing samples loader from {manifest_path}")
  172. # Memory map file and read it in text mode (skip empty lines if any).
  173. # Skip header.
  174. tsv_lines = (
  175. read_text(
  176. pathname=manifest_path,
  177. encoding="UTF-8",
  178. rtrim=True,
  179. skip_empty=True,
  180. memory_map=True,
  181. )
  182. .skip(1)
  183. .and_return()
  184. )
  185. # Assing column names:
  186. # line content: `_`
  187. # source manifest path: `manifest_path`
  188. # line number: `lineno`
  189. line_numbers = DataPipeline.count().and_return()
  190. filename_const = DataPipeline.constant(manifest_path).and_return()
  191. pipeline = DataPipeline.zip(
  192. [tsv_lines, filename_const, line_numbers],
  193. names=[self.ROOT_COLUMN, "manifest_path", "lineno"],
  194. zip_to_shortest=True,
  195. )
  196. # Read every `world_size`th line starting from `rank`th item in the file.
  197. pipeline.shard(self.rank, self.world_size)
  198. if self.config.shuffle_window is not None:
  199. pipeline.shuffle(self.config.shuffle_window)
  200. # Split each text line into its fields.
  201. fields = self._read_column_names(manifest_path)
  202. logger.debug(f"Column names: {fields}")
  203. txt_splitter = StrSplitter(sep=self.MANIFEST_COLUMN_SEP, names=fields, indices=[], exclude=True)
  204. pipeline.map(
  205. txt_splitter,
  206. selector=self.ROOT_COLUMN,
  207. num_parallel_calls=self.config.num_threads,
  208. )
  209. # And, create the pipeline for the TSV file.
  210. return pipeline
  211. def _get_manifest_funnel(self) -> DataPipelineBuilder:
  212. """Creates a joined pipeline from all manifests.
  213. Picks samples from per-manifest pipelines in a round-robin order"""
  214. # TODO: add the ability to upsample/downsample manifests
  215. logger.info(f"Aggregating data from {len(self.manifest_paths)} manifests")
  216. builders = [self._builder_from_manifest(manifest_path=path) for path in self.manifest_paths]
  217. pipelines = [builder.and_return() for builder in builders]
  218. return DataPipeline.round_robin(pipelines=pipelines)
  219. def _attach_audio(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
  220. """Attaches audio waveforms and fbanks from linked autio files"""
  221. audio_selector = f"{self.ROOT_COLUMN}.{self.AUDIO_COLUMN_NAME}"
  222. audio_data_selector = f"{audio_selector}.data"
  223. # Memory map each `audio_file`
  224. map_file = FileMapper(self.config.audio.audio_root_dir, cached_fd_count=100)
  225. builder.map(
  226. map_file,
  227. selector=audio_selector,
  228. num_parallel_calls=self.config.num_threads,
  229. )
  230. # Decode each mmap'ed audio file using libsndfile.
  231. decode_audio = AudioDecoder(dtype=torch.float32)
  232. builder.map(
  233. decode_audio,
  234. selector=audio_data_selector,
  235. num_parallel_calls=self.config.num_threads,
  236. )
  237. # And, convert from waveform to log-mel filterbank
  238. convert_to_fbank = WaveformToFbankConverter(
  239. num_mel_bins=self.config.audio.fbanks_num_mel_bins,
  240. waveform_scale=self.config.audio.fbanks_waveform_scale,
  241. channel_last=True, # audio channel is the last dimension in the waveform
  242. standardize=self.config.audio.fbanks_standardize_audio,
  243. keep_waveform=False,
  244. device=self.target_device,
  245. dtype=self.float_dtype,
  246. )
  247. builder.map(
  248. convert_to_fbank,
  249. selector=audio_data_selector,
  250. num_parallel_calls=self.config.num_threads,
  251. )
  252. return builder
  253. def _attach_target_tokens(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
  254. # Convert `raw_tgt_text` to (full) target tokenized sequences:
  255. # <eos> <lang_tok> <tokens .. > <eos>
  256. # Lang tokens change between rows, so can't use static encoder
  257. builder.map(
  258. [self.spm_encoder],
  259. selector=f"{self.ROOT_COLUMN}.{self.TARGET_TEXT_COLUMN}",
  260. num_parallel_calls=self.config.num_threads,
  261. )
  262. # Convert the `tgt_text` field into a unit tensor + EOS
  263. # TODO: We should use unit tokenizer.
  264. # Motivation for the current implementation:
  265. # 1) lang_tok can change between rows.
  266. # If we want to attach lang_token_id here, we need a way to join values from two columns
  267. # 2) StrToTensorConverter doesn't allow suffix tokens. Adding it later is less covenient.
  268. # 3) Not a computational blocker
  269. convert_to_units = lambda units_str: ( # noqa: E731
  270. torch.LongTensor(
  271. [int(unit_id) + 4 for unit_id in units_str.rstrip().bytes().decode("utf-8").split()]
  272. + [self.unit_tokenizer.vocab_info.eos_idx]
  273. )
  274. )
  275. builder.map(
  276. [convert_to_units],
  277. selector=f"{self.ROOT_COLUMN}.{self.TARGET_UNITS_COLUMN}",
  278. num_parallel_calls=self.config.num_threads,
  279. )
  280. # prefixes for tokenized texts and speech units (<eos> <lang_tok>)
  281. prefix_builder = lambda lang_tok: torch.LongTensor( # noqa: E731
  282. [
  283. self.text_prefix_tokens[lang_tok.bytes().decode("utf8")],
  284. self.unit_prefix_tokens[lang_tok.bytes().decode("utf8")],
  285. ]
  286. )
  287. builder.map(
  288. [prefix_builder],
  289. selector=f"{self.ROOT_COLUMN}.{self.TARGET_LANG_COLUMN}",
  290. num_parallel_calls=self.config.num_threads,
  291. )
  292. return builder
  293. def _get_input_audio_seconds(self, sample: Any) -> float:
  294. audio_data = sample[self.ROOT_COLUMN][self.AUDIO_COLUMN_NAME]["data"]
  295. input_audio_sample_rate = audio_data["sample_rate"]
  296. num_fbanks = max(audio_data["fbank"].shape) # not guessing the dim order
  297. # TODO: clarify where '* 2' comes from
  298. waveform_length = num_fbanks * self.config.audio.fbanks_num_mel_bins * 2
  299. input_audio_seconds = waveform_length / input_audio_sample_rate
  300. return input_audio_seconds
  301. def _is_long_sample(self, sample: Any) -> bool:
  302. # input audio length
  303. if self._get_input_audio_seconds(sample) > self.config.max_seconds_per_input_audio:
  304. return True
  305. # target text tokens
  306. num_tgt_text_tokens = sample[self.ROOT_COLUMN][self.TARGET_TEXT_COLUMN].shape[-1]
  307. if num_tgt_text_tokens > self.config.max_tgt_text_tokens_per_sample:
  308. return True
  309. # target units
  310. num_tgt_units = sample[self.ROOT_COLUMN][self.TARGET_UNITS_COLUMN].shape[-1] # target units
  311. if num_tgt_units > self.config.max_units_per_sample:
  312. return True
  313. return False
  314. def _filter_samples(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
  315. # Drop long samples
  316. builder.filter(lambda sample: not self._is_long_sample(sample))
  317. return builder
  318. def _batch_samples(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
  319. if self.config.fixed_batch_size is not None:
  320. builder.bucket(bucket_size=self.config.fixed_batch_size)
  321. elif self.tgt_text_batch_shapes is not None:
  322. builder.bucket_by_length(
  323. self.tgt_text_batch_shapes,
  324. selector=f"{self.ROOT_COLUMN}.{self.TARGET_TEXT_COLUMN}",
  325. )
  326. else:
  327. raise ValueError("Unclear batching strategy")
  328. # Collate bucketed elements into a batch.
  329. collater = Collater(
  330. pad_to_multiple=1,
  331. overrides=[
  332. CollateOptionsOverride(
  333. selector=f"{self.ROOT_COLUMN}.{self.AUDIO_COLUMN_NAME}.data.fbank",
  334. pad_idx=self.config.fbank_feats_pad_idx,
  335. ),
  336. CollateOptionsOverride(
  337. selector=f"{self.ROOT_COLUMN}.{self.TARGET_TEXT_COLUMN}",
  338. pad_idx=self.text_tokenizer.vocab_info.pad_idx,
  339. ),
  340. CollateOptionsOverride(
  341. selector=f"{self.ROOT_COLUMN}.{self.TARGET_UNITS_COLUMN}",
  342. pad_idx=self.unit_tokenizer.vocab_info.pad_idx,
  343. ),
  344. ],
  345. )
  346. builder.map(collater, num_parallel_calls=self.config.num_threads)
  347. if self.config.prefech_batches is not None:
  348. builder.prefetch(self.config.prefech_batches)
  349. return builder
  350. def _build_pipeline(self) -> DataPipeline:
  351. data = self._get_manifest_funnel()
  352. data = self._attach_audio(data)
  353. data = self._attach_target_tokens(data)
  354. data = self._filter_samples(data)
  355. batches = self._batch_samples(data)
  356. return batches.and_return()
  357. def _gen_prev_toks_target_toks_target_lens(
  358. self, seqs: Any, prefix_tokens: torch.Tensor, pad_idx: int, eos_idx: int
  359. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  360. # <eos> <lang_tok> ... <eos> <pad>*
  361. tokens = torch.cat((prefix_tokens, seqs["seqs"]), 1)
  362. target_lengths = seqs["seq_lens"] + 1 # + <leng_tok>
  363. prev_output_tokens = torch.clone(tokens)
  364. # replace last <eos> with <pad> and remove last column
  365. mask = prev_output_tokens == eos_idx
  366. mask[:, 0] = 0
  367. prev_output_tokens[mask] = pad_idx
  368. prev_output_tokens = prev_output_tokens[:, :-1]
  369. target_tokens = tokens[:, 1:]
  370. assert torch.equal(torch.count_nonzero(prev_output_tokens != pad_idx, dim=1), target_lengths)
  371. assert torch.equal(torch.count_nonzero(target_tokens != pad_idx, dim=1), target_lengths)
  372. return prev_output_tokens, target_tokens, target_lengths
  373. def _get_text_to_units_batch(self, raw_batch: Any) -> SeqsBatch:
  374. root = raw_batch[self.ROOT_COLUMN]
  375. seqs = root[self.TARGET_UNITS_COLUMN]
  376. prefix_tokens = root[self.TARGET_LANG_COLUMN][:, 1, :]
  377. pad_idx = self.unit_tokenizer.vocab_info.pad_idx
  378. eos_idx = self.unit_tokenizer.vocab_info.eos_idx
  379. assert pad_idx is not None
  380. assert eos_idx is not None
  381. (
  382. prev_output_tokens,
  383. target_tokens,
  384. target_lengths,
  385. ) = self._gen_prev_toks_target_toks_target_lens(
  386. seqs=seqs,
  387. prefix_tokens=prefix_tokens,
  388. pad_idx=pad_idx,
  389. eos_idx=eos_idx,
  390. )
  391. return SeqsBatch(
  392. src_tokens=None,
  393. src_lengths=None,
  394. target_tokens=target_tokens.to(self.target_device),
  395. prev_output_tokens=prev_output_tokens.to(self.target_device),
  396. target_lengths=target_lengths.to(self.target_device),
  397. prefix_tokens=prefix_tokens.to(self.target_device),
  398. )
  399. def _get_speech_src_tokens_and_lengths(self, raw_batch: Any) -> Tuple[torch.Tensor, torch.Tensor]:
  400. fbanks = raw_batch[self.ROOT_COLUMN][self.AUDIO_COLUMN_NAME]["data"]["fbank"]
  401. return fbanks["seqs"].to(self.float_dtype), fbanks["seq_lens"]
  402. def _get_speech_to_text_batch(self, raw_batch: Any) -> SeqsBatch:
  403. root = raw_batch[self.ROOT_COLUMN]
  404. seqs = root[self.TARGET_TEXT_COLUMN]
  405. prefix_tokens = root[self.TARGET_LANG_COLUMN][:, 0, :]
  406. pad_idx = self.text_tokenizer.vocab_info.pad_idx
  407. assert pad_idx is not None
  408. eos_idx = self.text_tokenizer.vocab_info.eos_idx
  409. assert eos_idx is not None
  410. (
  411. prev_output_tokens,
  412. target_tokens,
  413. target_lengths,
  414. ) = self._gen_prev_toks_target_toks_target_lens(
  415. seqs=seqs,
  416. prefix_tokens=prefix_tokens,
  417. pad_idx=pad_idx,
  418. eos_idx=eos_idx,
  419. )
  420. src_tokens, src_lengths = self._get_speech_src_tokens_and_lengths(raw_batch=raw_batch)
  421. return SeqsBatch(
  422. src_tokens=src_tokens.to(self.target_device),
  423. src_lengths=src_lengths.to(self.target_device),
  424. target_tokens=target_tokens.to(self.target_device),
  425. prev_output_tokens=prev_output_tokens.to(self.target_device),
  426. target_lengths=target_lengths.to(self.target_device),
  427. prefix_tokens=prefix_tokens.to(self.target_device),
  428. )
  429. def _convert_to_mulitmodal_seqs_batch(self, raw_batch: Any) -> MultimodalSeqsBatch:
  430. return MultimodalSeqsBatch(
  431. speech_to_text=self._get_speech_to_text_batch(raw_batch=raw_batch),
  432. text_to_units=self._get_text_to_units_batch(raw_batch=raw_batch),
  433. )
  434. def iterate_batches(self) -> Iterator[MultimodalSeqsBatch]:
  435. for raw_batch in self.pipeline:
  436. yield self._convert_to_mulitmodal_seqs_batch(raw_batch)
  437. def reset(self) -> None:
  438. self.pipeline.reset()
  439. if __name__ == "__main__":
  440. logging.basicConfig(
  441. level=logging.INFO,
  442. format=f"%(asctime)s %(levelname)s -- %(name)s.{os.getpid()}: %(message)s",
  443. )
  444. config = DataLoadingConfig(
  445. audio=AudioProcessingConfig(
  446. audio_root_dir="/fsx-ust/data/audio_zips/",
  447. ),
  448. manifest_path_prefix="/fsx-ust/spopuri/datasets/S2ST/V1/M4T_V1_phase2/primary",
  449. manifest_list_path="/data/home/mavlyutov/train_manifests.txt",
  450. shuffle_window=1000,
  451. num_threads=5,
  452. )
  453. loader = UnityDataLoader(config=config, target_device=torch.device("cpu"))
  454. for idx, batch in enumerate(loader.iterate_batches()):
  455. if idx % 10 == 0:
  456. assert batch.speech_to_text.src_tokens is not None
  457. print(batch.speech_to_text.src_tokens.shape)
  458. logger.info(f".. pulled {idx} batches")
  459. if idx > 1000:
  460. break