dataloader.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import json
  7. import logging
  8. from dataclasses import dataclass
  9. from typing import Any, Dict, Iterable, List, Optional
  10. import numpy as np
  11. import torch
  12. import torchaudio
  13. import torchaudio.compliance.kaldi as ta_kaldi
  14. from datasets import Dataset
  15. from datasets.distributed import split_dataset_by_node
  16. from fairseq2.models.nllb.tokenizer import NllbTokenizer, TextTokenEncoder
  17. from torch import Tensor
  18. from torch.nn.functional import pad as pad_tensor
  19. from torch.utils.data import DataLoader
  20. from seamless_communication.datasets.datatypes import LangPairSample
  21. from seamless_communication.models.unity.unit_tokenizer import (
  22. UnitTokenEncoder,
  23. UnitTokenizer,
  24. )
  25. logger = logging.getLogger(__name__)
  26. @dataclass
  27. class SeqsBatch:
  28. src_tokens: Optional[Tensor]
  29. src_lengths: Optional[Tensor]
  30. target_tokens: Optional[Tensor]
  31. prev_output_tokens: Optional[Tensor]
  32. target_lengths: Optional[Tensor]
  33. def __del__(self) -> None:
  34. """Explicitly delete tensors
  35. to force GPU memory cleanup"""
  36. for tensor in [
  37. self.src_tokens,
  38. self.src_lengths,
  39. self.target_tokens,
  40. self.prev_output_tokens,
  41. self.target_lengths,
  42. ]:
  43. if tensor is not None:
  44. del tensor
  45. @dataclass
  46. class MultimodalSeqsBatch:
  47. speech_to_text: SeqsBatch
  48. text_to_units: SeqsBatch
  49. def __del__(self) -> None:
  50. del self.speech_to_text
  51. del self.text_to_units
  52. @dataclass
  53. class BatchingConfig:
  54. fbank_feats_pad_idx: int = 0
  55. """The pad index to use in fbanks batching."""
  56. batch_size: int = 5
  57. rank: int = 0
  58. """The rank of this worker in the process group."""
  59. world_size: int = 1
  60. """The world size of the process group."""
  61. num_workers: int = 2
  62. """Parallelism in dataset preparation."""
  63. float_dtype: torch.dtype = torch.float16
  64. """Select between fp16/fp32 for float tensors """
  65. def worker_init_fn(worker_id):
  66. np.random.seed(np.random.get_state()[1][0] + worker_id)
  67. class UnitYDataLoader:
  68. def __init__(
  69. self,
  70. text_tokenizer: NllbTokenizer,
  71. unit_tokenizer: UnitTokenizer,
  72. dataset_manifest_path: str,
  73. batching_config: BatchingConfig,
  74. ):
  75. self.text_tokenizer = text_tokenizer
  76. self.text_encoders_per_lang: Dict[str, TextTokenEncoder] = {}
  77. self.unit_tokenizer = unit_tokenizer
  78. self.unit_encoders_per_lang: Dict[str, UnitTokenEncoder] = {}
  79. self.batching_config = batching_config
  80. self.dataset = self._load_manifest(dataset_manifest_path)
  81. def get_dataloader(self) -> DataLoader:
  82. subset = split_dataset_by_node(
  83. self.dataset,
  84. rank=self.batching_config.rank,
  85. world_size=self.batching_config.world_size,
  86. )
  87. data_loader = DataLoader(
  88. dataset=subset,
  89. batch_size=self.batching_config.batch_size,
  90. shuffle=True,
  91. num_workers=self.batching_config.num_workers,
  92. collate_fn=self._prepare_batch,
  93. worker_init_fn=worker_init_fn,
  94. )
  95. return data_loader
  96. def __iter__(self) -> Iterable[MultimodalSeqsBatch]:
  97. return self.get_dataloader().__iter__()
  98. def _get_source_fbank(self, sample: LangPairSample) -> Tensor:
  99. audio_input = torchaudio.load(sample.source.audio_local_path)[0]
  100. return ta_kaldi.fbank(audio_input, num_mel_bins=80)
  101. def _get_tokenized_target_text(self, sample: LangPairSample) -> Tensor:
  102. """Expected sequence is [<eos>, <lang_tok> , ..text tokens.., <eos>]"""
  103. target_lang = sample.target.lang
  104. if target_lang not in self.text_encoders_per_lang:
  105. self.text_encoders_per_lang[
  106. target_lang
  107. ] = self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
  108. tokens = self.text_encoders_per_lang[target_lang](sample.target.text)
  109. eos_idx = self.text_tokenizer.vocab_info.eos_idx
  110. tokens = torch.concat([tokens, torch.LongTensor([eos_idx])])
  111. return tokens
  112. def _get_tokenized_units(self, sample: LangPairSample) -> Optional[Tensor]:
  113. """Expected sequence is [<eos>, <lang_tok> , ..unit tokens.., <eos>]"""
  114. if sample.target.units is None:
  115. return None
  116. target_lang = sample.target.lang
  117. if target_lang not in self.unit_encoders_per_lang:
  118. self.unit_encoders_per_lang[
  119. target_lang
  120. ] = self.unit_tokenizer.create_encoder(lang=target_lang)
  121. tokens = self.unit_encoders_per_lang[target_lang](
  122. torch.LongTensor(sample.target.units).unsqueeze(0)
  123. )
  124. eos_idx = self.unit_tokenizer.vocab_info.eos_idx
  125. tokens = torch.concat([tokens.squeeze(0), torch.LongTensor([eos_idx])])
  126. return tokens
  127. def _batch_tensors(self, tensors: List[Tensor], pad_value: Any) -> Tensor:
  128. padding_size = max(tensor.shape[0] for tensor in tensors)
  129. dims = len(tensors[0].shape)
  130. padded_tensors = []
  131. for tensor in tensors:
  132. padding = [0] * 2 * dims
  133. padding[-1] = padding_size - tensor.shape[0]
  134. padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
  135. return torch.stack([tensor for tensor in padded_tensors], dim=0)
  136. def _prepare_batch(self, raw_samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
  137. samples = [LangPairSample.from_json(sample) for sample in raw_samples]
  138. # input speech
  139. src_tokens_list = [self._get_source_fbank(sample) for sample in samples]
  140. src_tokens = self._batch_tensors(
  141. src_tokens_list, pad_value=self.batching_config.fbank_feats_pad_idx
  142. ).to(self.batching_config.float_dtype)
  143. src_lengths = torch.LongTensor(
  144. [src_tokens.shape[0] for src_tokens in src_tokens_list]
  145. )
  146. # output text
  147. text_tokens_list = [
  148. self._get_tokenized_target_text(sample) for sample in samples
  149. ]
  150. text_pad_idx = self.text_tokenizer.vocab_info.pad_idx
  151. prev_outputs_tokens = self._batch_tensors(
  152. [tokens[:-1] for tokens in text_tokens_list], pad_value=text_pad_idx
  153. )
  154. target_tokens = self._batch_tensors(
  155. [tokens[1:] for tokens in text_tokens_list], pad_value=text_pad_idx
  156. )
  157. tokens_lengths = torch.LongTensor(
  158. [tokens.shape[0] - 1 for tokens in text_tokens_list]
  159. )
  160. # output units
  161. units_list_raw = [self._get_tokenized_units(sample) for sample in samples]
  162. if None in units_list_raw:
  163. prev_outputs_units = None
  164. target_units = None
  165. units_lengths = None
  166. else:
  167. units_list: List[Tensor] = [
  168. value for value in units_list_raw if value is not None
  169. ]
  170. units_pad_idx = self.unit_tokenizer.vocab_info.pad_idx
  171. prev_outputs_units = self._batch_tensors(
  172. [tokens[:-1] for tokens in units_list], pad_value=units_pad_idx
  173. )
  174. target_units = self._batch_tensors(
  175. [tokens[1:] for tokens in units_list], pad_value=units_pad_idx
  176. )
  177. units_lengths = torch.LongTensor(
  178. [tokens.shape[0] - 1 for tokens in units_list]
  179. )
  180. return MultimodalSeqsBatch(
  181. speech_to_text=SeqsBatch(
  182. src_tokens=src_tokens,
  183. src_lengths=src_lengths,
  184. target_tokens=target_tokens,
  185. prev_output_tokens=prev_outputs_tokens,
  186. target_lengths=tokens_lengths,
  187. ),
  188. text_to_units=SeqsBatch(
  189. src_tokens=None,
  190. src_lengths=None,
  191. target_tokens=target_units,
  192. prev_output_tokens=prev_outputs_units,
  193. target_lengths=units_lengths,
  194. ),
  195. )
  196. def _load_manifest(self, dataset_manifest_path: str) -> Dataset:
  197. with open(dataset_manifest_path) as fp_in:
  198. dataset = [json.loads(line) for line in fp_in]
  199. return Dataset.from_list(dataset)