# dataset.py

import os
import math
import json

import numpy as np
import torch

from typing import List, Union
from abc import ABC, abstractmethod
from scipy.linalg import block_diag
from itertools import accumulate
from bisect import bisect_right

from SwissArmyTransformer import get_tokenizer
from SwissArmyTransformer.mpu import get_model_parallel_rank

from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig, LanguageModelTaskConfig
from .utils import get_tokenized_input


def pad_batch(tokens, position_ids, attention_mask, max_seq_length):
    attention_mask = np.pad(
        attention_mask,
        pad_width=((0, max_seq_length - len(tokens)),),
        mode="constant",
        constant_values=0,
    )
    tokens = np.concatenate((tokens, np.zeros(max_seq_length - len(tokens), dtype=np.int64)))
    position_ids = np.concatenate((position_ids, np.zeros(max_seq_length - len(position_ids), dtype=np.int64)))
    return tokens, position_ids, attention_mask
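
# Illustration of pad_batch (not executed; the sizes are a made-up toy example):
# with max_seq_length = 8 and a sample of 5 tokens,
#   tokens         (5,)    -> (8,)    zero-padded on the right
#   position_ids   (5,)    -> (8,)    zero-padded on the right
#   attention_mask (5, 5)  -> (8, 8)  zero-padded on both axes
#     (np.pad with the single pad_width tuple (0, 3) applies it to every axis),
# so padded positions can neither attend nor be attended to.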


class EvaluationDataset(torch.utils.data.Dataset, ABC):
    """
    Jsonlines of {
        "text": context
        "choices": [choice_id1, ...]; if not None, len(target) == 1
        "label": -1 for generation tasks, otherwise an index in [0, len(choices))
    }
    If [MASK] is not in the context, a [MASK] token will be appended after the text.
    """

    def __init__(self, path: Union[str, List[str]], config: BaseConfig):
        self.path = path if isinstance(path, list) else [path]
        self.config = config
        self.max_seq_length = self.config.max_seq_length
        self.dtype = np.int64

        self.tokenizer = get_tokenizer()
        self.mask_id = self.tokenizer.get_command("[MASK]")
        self.gmask_id = self.tokenizer.get_command("[gMASK]")

        self.data = []
        for p in self.path:
            self.process_single_file(p)

    @property
    def has_collate_fn(self) -> bool:
        return False

    def collate_fn(self, samples):
        return None

    def process_single_file(self, path):
        with open(os.path.join(path), "r", encoding="utf-8") as file:
            for line in file:
                item = json.loads(line)
                self.data.append(self.process_single_item(item))

    @abstractmethod
    def process_single_item(self, item) -> dict:
        pass

    def __len__(self):
        return len(self.data)


class GenerationTaskDataset(EvaluationDataset):
    config: GenerationTaskConfig

    def process_single_item(self, item):
        text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
        if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
            text_length = self.config.max_seq_length - self.config.max_gen_length - 2
            text = text[len(text) - text_length : len(text)]
        return {"text": text, "targets": targets}

    @property
    def has_collate_fn(self) -> bool:
        return True

    def collate_fn(self, samples):
        TILE = 32
        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        context_length_batch, target_position_id_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = pad_batch(
                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            context_length_batch.append(sample["context_length"])
            target_position_id_batch.append(sample["target_position_id"])

        return {
            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
            "context_length": torch.tensor(context_length_batch, dtype=torch.int64),
            "target_position_ids": torch.tensor(np.array(target_position_id_batch), dtype=torch.int64),
        }

    @staticmethod
    def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
        tokenizer = get_tokenizer()

        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)

        blank_filling = mask_id in text
        if blank_filling:
            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
            assert not use_task_mask, "Blank filling doesn't support task mask"
            mask_position = text.index(mask_id)
            token = np.concatenate((token, [sop_id]))
        else:
            mask_position = len(token)
            if unidirectional:
                token = np.concatenate(([mask_id, sop_id], token))
            else:
                token = np.concatenate((token, [mask_id, sop_id]))
        context_length = len(token)

        position_id = np.arange(0, context_length, dtype=np.int64)
        target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
        if not use_task_mask:
            position_id[context_length - 1 :] = mask_position
            target_position_id[:] = mask_position

        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
        if not unidirectional:
            attention_mask[: context_length - 1, : context_length - 1] = 1

        item = {
            "token": token,
            "position_id": position_id,
            "target_position_id": target_position_id,
            "attention_mask": attention_mask,
            "context_length": context_length,
        }
        return item
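
    # Context layout produced above (for reference; derived from the code):
    #   bidirectional (unidirectional=False): [text..., [MASK]/[gMASK], sop]
    #       text and mask attend to each other fully; the trailing sop sees them
    #       but is not visible to them.
    #   unidirectional (gMASK prompting):     [[gMASK], sop, text...]
    #       plain lower-triangular (causal) attention over the whole context.
    # With [MASK] (use_task_mask=False), the sop and every generated position reuse
    # the mask position as their position id; with [gMASK] the ids keep counting up.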

    def __getitem__(self, idx):
        item = self.data[idx]
        sample = self.build_generation_sample(
            item["text"],
            max_gen_length=self.config.max_gen_length,
            use_task_mask=self.config.use_task_mask,
            unidirectional=self.config.unidirectional,
        )
        sample["targets"] = [np.array(target, dtype=self.dtype) for target in item["targets"]]
        return sample


class MultiChoiceTaskDataset(EvaluationDataset):
    config: MultiChoiceTaskConfig

    def __init__(self, path, config: MultiChoiceTaskConfig):
        self.is_single_token = True  # set to False later in process_single_item func
        super().__init__(path, config)

    @property
    def has_collate_fn(self) -> bool:
        return True

    def collate_fn(self, samples):
        TILE = 32
        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        choices_batch, choice_target_ids_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = pad_batch(
                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            choices_batch.append(sample["choices"])
            choice_target_ids_batch.append(sample["choice_target_ids"])

        return {
            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
            "choices": choices_batch,
            "choice_target_ids": choice_target_ids_batch,
            "is_single_token": self.is_single_token,
        }

    def process_single_item(self, item):
        text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]

        tgt_seq_length = sum([len(choice) for choice in choices])
        if tgt_seq_length == len(choices):
            # For single token, we only insert one [sop]
            tgt_seq_length = 1

        assert tgt_seq_length < self.config.max_seq_length
        if len(text) + tgt_seq_length + 2 > self.config.max_seq_length:
            text_length = self.config.max_seq_length - tgt_seq_length - 2
            text = text[len(text) - text_length : len(text)]

        assert not (
            self.mask_id in text and self.config.use_multitask_encoding
        ), "Unified multitask encoding doesn't support blank filling"

        if tgt_seq_length != 1:
            self.is_single_token = False

        return {
            "text": text,
            "choices": choices,
            "label": label,
        }

    @staticmethod
    def build_multiple_choice_sample(
        text, choices, is_single_token, unified_multitask_encoding=False, use_task_mask=False
    ):
        tokenizer = get_tokenizer()

        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        target = np.array(text, dtype=np.int64)
        position_id = np.arange(len(text), dtype=np.int64)
        choice_target_id = []

        blank_filling = mask_id in text
        if not blank_filling:
            mask_position = len(token)
            token = np.concatenate((token, [mask_id]))
            target = np.concatenate((target, [mask_id]))
            position_id = np.concatenate((position_id, [mask_position]))
        else:
            mask_position = text.index(mask_id)

        division = len(token)
        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]

        for choice in choices:
            if len(choice) == 0:
                if get_model_parallel_rank() == 0:
                    print("Empty choice found")
                choice = [0]
            if not use_task_mask:
                position_id = np.concatenate(
                    (
                        position_id,
                        [mask_position] * len(choice)
                        if blank_filling or not unified_multitask_encoding
                        else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
                    )
                )
            else:
                position_id = np.concatenate(
                    (
                        position_id,
                        np.arange(division, division + len(choice), dtype=np.int64),
                    )
                )

            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
            token = np.concatenate((token, [sop_id], choice[:-1]))
            target = np.concatenate((target, choice))

            if is_single_token:
                break

        attention_mask = block_diag(*attention_mask)
        attention_mask[: len(token), :division] = 1

        if is_single_token:
            choices = np.array(choices, dtype=np.int64).squeeze().tolist()

        item = {
            "token": token,
            "position_id": position_id,
            "attention_mask": attention_mask,
            "choices": choices,
            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
        }
        return item
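
    # Sequence layout produced above (for reference; derived from the code):
    #   [text..., [MASK]]  [sop, choice_1[:-1]]  [sop, choice_2[:-1]]  ...
    # Every position can see the context block (columns up to `division`), while the
    # choice blocks are causal and block-diagonal, so candidate choices never attend
    # to one another.  choice_target_ids index the positions whose logits score each
    # choice token; for single-token choices only one shared [sop] slot is built.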

    def __getitem__(self, idx):
        item = self.data[idx]
        sample = self.build_multiple_choice_sample(
            item["text"],
            item["choices"],
            is_single_token=self.is_single_token,
            unified_multitask_encoding=self.config.use_multitask_encoding,
            use_task_mask=self.config.use_task_mask,
        )
        sample["label"] = item["label"]
        return sample


class LanguageModelTaskDataset(EvaluationDataset):
    config: LanguageModelTaskConfig
    left_weights: List[int]
    weights: List[int]

    def process_single_file(self, path):
        with open(os.path.join(path), "r", encoding="utf-8") as file:
            raw_text = file.read()
            tokens = self.tokenizer.tokenize(raw_text)
            self.data.append(
                {
                    "raw_text": tokens,
                    "num_original_tokens": len(raw_text.strip().split(" ")),
                    "num_sequences": max(
                        math.ceil(
                            max(len(tokens) - (self.config.max_seq_length - 1), 0) / self.config.generation_length
                        )
                        + 1,
                        1,
                    ),
                }
            )
        # Cumulative sequence counts over every document loaded so far, so that a flat
        # index can be mapped back to (document, offset) in __getitem__.
        self.weights = list(accumulate(d["num_sequences"] for d in self.data))
        self.left_weights = [0] + self.weights[:-1]

    def process_single_item(self, item):
        pass

    def __len__(self):
        # Total number of evaluation windows across all loaded documents.
        return self.weights[-1]

    def __getitem__(self, idx):
        document_idx = bisect_right(self.weights, idx)
        idx = idx - self.left_weights[document_idx]

        start_idx = idx * self.config.generation_length
        end_idx = start_idx + self.config.max_seq_length - 1  # for the additional [gMASK]
        tokens = self.data[document_idx]["raw_text"][start_idx:end_idx]

        mask_id = self.gmask_id if self.config.use_task_mask else self.mask_id
        sop_id = self.tokenizer.get_command("sop")

        if idx == 0 or self.config.unidirectional:
            prompt, text = [], tokens
        else:
            prompt_length = self.config.max_seq_length - 1 - self.config.generation_length
            prompt, text = tokens[:prompt_length], tokens[prompt_length:]

        seq_length = len(prompt) + len(text) + 1
        attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
        attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1

        return {
            "tokens": np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64),
            "targets": np.array(prompt + [mask_id] + text, dtype=np.int64),
            "position_ids": np.arange(0, seq_length, dtype=np.int64),
            "attention_mask": attention_mask < 0.5,
            "loss_masks": np.array([0] * (len(prompt) + 1) + [1] * len(text), dtype=np.int64),
        }
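

# ---------------------------------------------------------------------------
# Usage sketch (comments only, nothing below executes).  The evaluation pipeline
# is expected to initialize the tokenizer and build the task config before a
# dataset is constructed; the config instance, file path, and batch size below
# are illustrative assumptions, not values defined in this module.
#
#   dataset = GenerationTaskDataset("data/task.jsonl", config)  # config: GenerationTaskConfig
#   loader = torch.utils.data.DataLoader(
#       dataset,
#       batch_size=4,
#       collate_fn=dataset.collate_fn if dataset.has_collate_fn else None,
#   )
#   for batch in loader:
#       # batch["tokens"], batch["position_ids"]: int64, [batch, padded_length]
#       # batch["attention_mask"]: bool, True marks positions that must NOT be attended
#       # batch["context_length"], batch["target_position_ids"] drive generation
#       ...
# ---------------------------------------------------------------------------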