# dataset.py

import os
import math
import json

import numpy as np
import torch

from typing import List, Union
from abc import ABC, abstractmethod
from scipy.linalg import block_diag
from itertools import accumulate
from bisect import bisect_right

from SwissArmyTransformer import get_tokenizer

from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig, LanguageModelTaskConfig
from .utils import get_tokenized_input


def pad_batch(tokens, position_ids, attention_mask, max_seq_length):
    pad_length = max_seq_length - len(tokens)
    attention_mask = np.pad(
        attention_mask,
        pad_width=((0, pad_length),),
        mode="constant",
        constant_values=0,
    )
    tokens = np.concatenate((tokens, np.zeros(pad_length, dtype=np.int64)))
    position_ids = np.concatenate((position_ids, position_ids[..., -1:].repeat(pad_length, -1)), axis=-1)
    return tokens, position_ids, attention_mask
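

# Illustrative note (not part of the original source): pad_batch right-pads one
# sample to max_seq_length. Tokens are padded with 0, position ids repeat their
# last value along the final axis, and the square attention mask is zero-padded
# on every axis (np.pad broadcasts the single (0, pad_length) pair to all axes).
# A minimal shape sketch, assuming a context of length 5 padded to 8:
#
#     toks, pos, mask = pad_batch(
#         np.arange(5), np.arange(5), np.ones((5, 5), dtype=np.int64), 8
#     )
#     # toks.shape == (8,), pos.shape == (8,), mask.shape == (8, 8)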


class EvaluationDataset(torch.utils.data.Dataset, ABC):
    """
    Jsonlines of {
        "text": context
        "choices": [choice_id1,...], if not None, len(target) == 1
        "label": If generation task -1, else [0, len(choices))
    }
    If [MASK] not in context, will append [MASK] after text
    """

    def __init__(self, path: Union[str, List[str]], config: BaseConfig):
        self.path = path if isinstance(path, list) else [path]
        self.config = config
        self.max_seq_length = self.config.max_seq_length
        self.dtype = np.int64

        self.tokenizer = get_tokenizer()
        self.mask_id = self.tokenizer.get_command("[MASK]")
        self.gmask_id = self.tokenizer.get_command("[gMASK]")

        self.data = []
        for p in self.path:
            self.process_single_file(p)

    @property
    def has_collate_fn(self) -> bool:
        return False

    @staticmethod
    def collate_fn(samples):
        return None

    def process_single_file(self, path):
        with open(os.path.join(path), "r", encoding="utf-8") as file:
            for line in file:
                item = json.loads(line)
                self.data.append(self.process_single_item(item))

    @abstractmethod
    def process_single_item(self, item) -> dict:
        pass

    def __len__(self):
        return len(self.data)
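

# Illustrative data-format note (an assumption, not from the original source): the
# concrete subclasses below read fields through get_tokenized_input, e.g. "inputs"
# and "targets" for generation tasks or "inputs", "choices" and "label" for
# multiple-choice tasks, so one jsonlines record presumably looks roughly like
#     {"inputs": [101, 205, ...], "choices": [[301], [302]], "label": 0}
# with pre-tokenized id lists (or raw text when the config's no_tokenized is set).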


class GenerationTaskDataset(EvaluationDataset):
    config: GenerationTaskConfig

    def process_single_item(self, item):
        text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
        if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
            text_length = self.config.max_seq_length - self.config.max_gen_length - 2
            text = text[len(text) - text_length : len(text)]
        return {"text": text, "targets": targets}

    @property
    def has_collate_fn(self) -> bool:
        return True

    @staticmethod
    def collate_fn(samples):
        TILE = 32
        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        context_length_batch, target_position_id_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = pad_batch(
                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            context_length_batch.append(sample["context_length"])
            target_position_id_batch.append(sample["target_position_id"])

        return {
            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
            "context_length": torch.tensor(context_length_batch, dtype=torch.int64),
            "target_position_ids": torch.tensor(np.array(target_position_id_batch), dtype=torch.int64),
        }

    @staticmethod
    def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
        tokenizer = get_tokenizer()
        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)

        blank_filling = mask_id in text
        if blank_filling:
            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
            assert not use_task_mask, "Unidirectional attention doesn't support task mask"
            mask_position = text.index(mask_id)
            token = np.concatenate((token, [sop_id]))
        else:
            mask_position = len(token)
            if unidirectional:
                token = np.concatenate(([mask_id, sop_id], token))
            else:
                token = np.concatenate((token, [mask_id, sop_id]))
        context_length = len(token)

        position_id = np.arange(0, context_length, dtype=np.int64)
        target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
        if not use_task_mask:
            position_id[context_length - 1 :] = mask_position
            target_position_id[:] = mask_position

        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
        if not unidirectional:
            attention_mask[: context_length - 1, : context_length - 1] = 1

        item = {
            "token": token,
            "position_id": position_id,
            "target_position_id": target_position_id,
            "attention_mask": attention_mask,
            "context_length": context_length,
        }
        return item

    def __getitem__(self, idx):
        item = self.data[idx]
        return self.build_generation_sample(
            item["text"],
            max_gen_length=self.config.max_gen_length,
            use_task_mask=self.config.use_task_mask,
            unidirectional=self.config.unidirectional,
        )
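

# Layout sketch for build_generation_sample (derived from the code above, for
# orientation only). With use_task_mask=True and unidirectional=True a context
# t0 .. tn becomes
#     token       = [gMASK, sop, t0, ..., tn]
#     position_id = 0, 1, ..., context_length - 1
#     attention   = lower-triangular over the whole context
# In the bidirectional [MASK] case the mask/sop pair is appended instead, the
# trailing sop and all generated target positions are pinned to mask_position,
# and everything except the final sop attends bidirectionally.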


class SmallGenerationTaskDataset(GenerationTaskDataset):
    def process_single_item(self, item):
        text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
        if len(text) + self.config.max_gen_length + 3 > self.config.max_seq_length:
            text_length = self.config.max_seq_length - self.config.max_gen_length - 3
            text = text[len(text) - text_length : len(text)]
        return {"text": text, "targets": targets}

    @staticmethod
    def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
        tokenizer = get_tokenizer()
        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
        cls_id = tokenizer.get_command("ENC")
        eos_id = tokenizer.get_command("eos")

        token = np.array(text, dtype=np.int64)

        blank_filling = mask_id in text
        if blank_filling:
            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
            assert not use_task_mask, "Unidirectional attention doesn't support task mask"
            mask_position = text.index(mask_id) + 1
            context_length = len(token) + 2
            token = np.concatenate(([cls_id], token, [eos_id, sop_id]))
        else:
            if unidirectional:
                mask_position = 1
                context_length = 3
                token = np.concatenate(([cls_id, mask_id, eos_id, sop_id], token))
            else:
                mask_position = len(token) + 1
                context_length = len(token) + 3
                token = np.concatenate(([cls_id], token, [mask_id, eos_id, sop_id]))
        prefix_length = len(token) - context_length

        position_id = [
            list(range(context_length)) + [mask_position] * prefix_length,
            [0] * context_length + list(range(1, prefix_length + 1)),
        ]
        position_id = np.array(position_id, dtype=np.int64)

        target_position_id = [
            [mask_position] * max_gen_length,
            list(range(prefix_length + 1, prefix_length + max_gen_length + 1)),
        ]
        target_position_id = np.array(target_position_id, dtype=np.int64)

        attention_mask = np.tril(np.ones((len(token), len(token)), dtype=np.int64))
        if not unidirectional:
            attention_mask[: len(token) - 1, : len(token) - 1] = 1

        item = {
            "token": token,
            "position_id": position_id,
            "target_position_id": target_position_id,
            "attention_mask": attention_mask,
            "context_length": context_length,
        }
        return item
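

# Descriptive note on the "Small" variant above (not part of the original source):
# it wraps the context in [ENC] ... [eos] and emits two position-id rows in GLM
# style. The first row holds absolute positions over the context and pins
# everything placed after it to mask_position; the second row is zero over the
# context and counts 1, 2, ... within the span that follows [sop].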


class MultiChoiceTaskDataset(EvaluationDataset):
    config: MultiChoiceTaskConfig

    def __init__(self, path, config: MultiChoiceTaskConfig):
        self.is_single_token = True  # set to False later in process_single_item func
        super().__init__(path, config)

    @property
    def has_collate_fn(self) -> bool:
        return True

    @staticmethod
    def num_special_tokens():
        return 2

    @staticmethod
    def collate_fn(samples):
        TILE = 32
        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        choices_batch, choice_target_ids_batch = [], []
        is_single_token = True

        for sample in samples:
            token, position_id, attention_mask = pad_batch(
                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            choices_batch.append(sample["choices"])
            choice_target_ids_batch.append(sample["choice_target_ids"])
            if isinstance(sample["choice_target_ids"], list):
                is_single_token = False

        return {
            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
            "choices": choices_batch,
            "choice_target_ids": choice_target_ids_batch,
            "is_single_token": is_single_token,
        }

    def process_single_item(self, item):
        text = get_tokenized_input(item, "inputs", no_tokenized=self.config.no_tokenized)
        choices = get_tokenized_input(item, "choices", no_tokenized=self.config.no_tokenized)
        label = item["label"]

        tgt_seq_length = sum([len(choice) for choice in choices])
        if tgt_seq_length == len(choices):
            # For single token, we only insert one [sop]
            tgt_seq_length = 1

        assert tgt_seq_length < self.config.max_seq_length
        if len(text) + tgt_seq_length + self.num_special_tokens() > self.config.max_seq_length:
            text_length = self.config.max_seq_length - tgt_seq_length - self.num_special_tokens()
            text = text[len(text) - text_length : len(text)]

        assert not (
            self.mask_id in text and self.config.use_multitask_encoding
        ), "Unified multitask encoding doesn't support blank filling"

        if tgt_seq_length != 1:
            self.is_single_token = False

        return {
            "text": text,
            "choices": choices,
            "label": label,
        }

    @staticmethod
    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False):
        tokenizer = get_tokenizer()
        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        target = np.array(text, dtype=np.int64)
        position_id = np.arange(len(text), dtype=np.int64)
        choice_target_id = []

        blank_filling = mask_id in text
        if not blank_filling:
            mask_position = len(token)
            token = np.concatenate((token, [mask_id]))
            target = np.concatenate((target, [mask_id]))
            position_id = np.concatenate((position_id, [mask_position]))
        else:
            mask_position = text.index(mask_id)

        division = len(token)
        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]

        for choice in choices:
            position_id = np.concatenate(
                (
                    position_id,
                    [mask_position] * len(choice)
                    if blank_filling or not unified_multitask_encoding
                    else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
                )
            )
            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
            token = np.concatenate((token, [sop_id], choice[:-1]))
            target = np.concatenate((target, choice))

            if is_single_token:
                break

        attention_mask = block_diag(*attention_mask)
        attention_mask[: len(token), :division] = 1

        if is_single_token:
            choices = np.array(choices, dtype=np.int64).squeeze().tolist()

        item = {
            "token": token,
            "position_id": position_id,
            "attention_mask": attention_mask,
            "choices": choices,
            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
        }
        return item

    def __getitem__(self, idx):
        item = self.data[idx]
        return self.build_multiple_choice_sample(
            item["text"],
            item["choices"],
            is_single_token=self.is_single_token,
            unified_multitask_encoding=self.config.use_multitask_encoding,
        )
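

# Packing sketch for build_multiple_choice_sample (derived from the code above):
# each choice is appended as [sop] + choice[:-1] behind the shared context, the
# per-choice masks are lower-triangular blocks combined via scipy's block_diag,
# and the first `division` columns (context plus [MASK]) are then set to 1 so all
# choice tokens attend to the full context bidirectionally. For single-token
# choices only one [sop] is appended and choices / choice_target_ids collapse to
# flat values.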


class SmallMultiChoiceTaskDataset(MultiChoiceTaskDataset):
    @staticmethod
    def num_special_tokens():
        return 3

    @staticmethod
    def build_multiple_choice_sample(
        text, choices, is_single_token, unified_multitask_encoding=False, unidirectional=False
    ):
        tokenizer = get_tokenizer()
        cls_id = tokenizer.get_command("ENC")
        eos_id = tokenizer.get_command("eos")
        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[MASK]") if not unidirectional else tokenizer.get_command("[gMASK]")

        blank_filling = mask_id in text
        text_length = len(text)
        last_token = text[-1]

        if unidirectional:
            assert not blank_filling
            text = [cls_id, mask_id, eos_id, sop_id] + text[:-1]
            position_id = np.array(list(range(3)) + [1] * text_length, dtype=np.int64)
            block_position_id = np.array([0] * 3 + list(range(1, text_length + 1)), dtype=np.int64)
        else:
            if not blank_filling:
                text = text + [mask_id]
            text = [cls_id] + text + [eos_id]
            position_id = np.arange(len(text), dtype=np.int64)
            block_position_id = np.zeros(len(text), dtype=np.int64)

        token = np.array(text, dtype=np.int64)
        target = np.array(text, dtype=np.int64)
        mask_position = text.index(mask_id)
        choice_target_id = []
        division = len(token)

        if unidirectional:
            attention_mask = [np.tril(np.ones((len(token), len(token)), dtype=np.int64))]
            attention_mask[0][:3, :3] = 1
        else:
            attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]

        for choice in choices:
            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
            position_id = np.concatenate((position_id, [mask_position] * len(choice)))
            if unidirectional:
                block_position_id = np.concatenate(
                    (block_position_id, range(1 + text_length, 1 + text_length + len(choice)))
                )
                token = np.concatenate((token, [last_token], choice[:-1]))
            else:
                block_position_id = np.concatenate((block_position_id, range(1, 1 + len(choice))))
                token = np.concatenate((token, [sop_id], choice[:-1]))
            target = np.concatenate((target, choice))

            if is_single_token:
                break

        attention_mask = block_diag(*attention_mask)
        if unidirectional:
            attention_mask[division:, :division] = 1
        else:
            attention_mask[: len(token), :division] = 1

        if is_single_token:
            choices = np.array(choices, dtype=np.int64).squeeze().tolist()

        position_id = np.stack((position_id, block_position_id), axis=0)

        item = {
            "token": token,
            "position_id": position_id,
            "attention_mask": attention_mask,
            "choices": choices,
            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
        }
        return item

    def __getitem__(self, idx):
        item = self.data[idx]
        return self.build_multiple_choice_sample(
            item["text"],
            item["choices"],
            is_single_token=self.is_single_token,
            unified_multitask_encoding=self.config.use_multitask_encoding,
            unidirectional=self.config.unidirectional,
        )


class LanguageModelTaskDataset(EvaluationDataset):
    config: LanguageModelTaskConfig

    def process_single_file(self, path):
        with open(os.path.join(path), "r", encoding="utf-8") as file:
            raw_text = file.read()
            tokens = self.tokenizer.tokenize(raw_text)
            self.data.append(
                {
                    "raw_text": tokens,
                    "num_original_tokens": len(raw_text.strip().split(" ")),
                    "num_sequences": max(
                        math.ceil(
                            max(len(tokens) - (self.config.max_seq_length - 1), 0) / self.config.generation_length
                        )
                        + 1,
                        1,
                    ),
                }
            )

    def process_single_item(self, item):
        pass

    def __len__(self):
        return self.data[0]["num_sequences"]

    def __getitem__(self, idx):
        start_idx = idx * self.config.generation_length
        end_idx = start_idx + self.config.max_seq_length - 1  # for additional [gMASK]
        tokens = self.data[0]["raw_text"][start_idx:end_idx]

        mask_id = self.gmask_id if self.config.use_task_mask else self.mask_id
        sop_id = self.tokenizer.get_command("sop")

        if idx == 0 or self.config.unidirectional:
            prompt, text = tokens[:1], tokens[1:]
        else:
            prompt_length = self.config.max_seq_length - 1 - self.config.generation_length
            prompt, text = tokens[:prompt_length], tokens[prompt_length:]

        seq_length = len(prompt) + len(text) + 1
        attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
        attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1

        return {
            "tokens": np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64),
            "targets": np.array(prompt + [mask_id] + text, dtype=np.int64),
            "position_ids": np.arange(0, seq_length, dtype=np.int64),
            "attention_mask": attention_mask < 0.5,
            "loss_masks": np.array([0] * (len(prompt) + 1) + [1] * len(text), dtype=np.int64),
        }
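

# A minimal usage sketch (not part of the original module; the jsonl path and
# batch size are hypothetical, and how the config objects are constructed is an
# assumption): any of these datasets plugs into a standard torch DataLoader
# through its collate_fn.
#
#     from torch.utils.data import DataLoader
#
#     dataset = GenerationTaskDataset("data/task.jsonl", config)  # config: GenerationTaskConfig
#     loader = DataLoader(
#         dataset,
#         batch_size=8,
#         collate_fn=dataset.collate_fn if dataset.has_collate_fn else None,
#     )
#     for batch in loader:
#         tokens, position_ids = batch["tokens"], batch["position_ids"]
#         ...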