dataset.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. import os
  2. import math
  3. import json
  4. import numpy as np
  5. import torch
  6. from typing import List, Union
  7. from abc import ABC, abstractmethod
  8. from scipy.linalg import block_diag
  9. from itertools import accumulate
  10. from bisect import bisect_right
  11. from SwissArmyTransformer import get_tokenizer
  12. from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig, LanguageModelTaskConfig
  13. from .utils import get_tokenized_input
  14. def pad_batch(tokens, position_ids, attention_mask, max_seq_length):
  15. pad_length = max_seq_length - len(tokens)
  16. attention_mask = np.pad(
  17. attention_mask,
  18. pad_width=((0, pad_length),),
  19. mode="constant",
  20. constant_values=0,
  21. )
  22. tokens = np.concatenate((tokens, np.zeros(pad_length, dtype=np.int64)))
  23. position_ids = np.concatenate((position_ids, position_ids[..., -1:].repeat(pad_length, -1)), axis=-1)
  24. return tokens, position_ids, attention_mask
  25. class EvaluationDataset(torch.utils.data.Dataset, ABC):
  26. """
  27. Jsonlines of {
  28. "text": context
  29. "choices": [choice_id1,...], if not None, len(target) == 1
  30. "label": If generation task -1, else [0, len(choices))
  31. }
  32. If [MASK] not in context, will append [MASK] after text
  33. """
  34. def __init__(self, path: Union[str, List[str]], config: BaseConfig):
  35. self.path = path if isinstance(path, list) else [path]
  36. self.config = config
  37. self.max_seq_length = self.config.max_seq_length
  38. self.dtype = np.int64
  39. self.tokenizer = get_tokenizer()
  40. self.mask_id = self.tokenizer.get_command("[MASK]")
  41. self.gmask_id = self.tokenizer.get_command("[gMASK]")
  42. self.data = []
  43. for p in self.path:
  44. self.process_single_file(p)
  45. @property
  46. def has_collate_fn(self) -> bool:
  47. return False
  48. @staticmethod
  49. def collate_fn(self, samples):
  50. return None
  51. def process_single_file(self, path):
  52. with open(os.path.join(path), "r", encoding="utf-8") as file:
  53. for line in file:
  54. item = json.loads(line)
  55. self.data.append(self.process_single_item(item))
  56. @abstractmethod
  57. def process_single_item(self, item) -> dict:
  58. pass
  59. def __len__(self):
  60. return len(self.data)
  61. class GenerationTaskDataset(EvaluationDataset):
  62. config: GenerationTaskConfig
  63. def process_single_item(self, item):
  64. text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
  65. if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
  66. text_length = self.config.max_seq_length - self.config.max_gen_length - 2
  67. text = text[len(text) - text_length : len(text)]
  68. return {"text": text, "targets": targets}
  69. @property
  70. def has_collate_fn(self) -> bool:
  71. return True
  72. @staticmethod
  73. def collate_fn(samples):
  74. TILE = 32
  75. length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE
  76. token_batch, position_id_batch, attention_mask_batch = [], [], []
  77. context_length_batch, target_position_id_batch = [], []
  78. for sample in samples:
  79. token, position_id, attention_mask = pad_batch(
  80. sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
  81. )
  82. token_batch.append(token)
  83. position_id_batch.append(position_id)
  84. attention_mask_batch.append(attention_mask)
  85. context_length_batch.append(sample["context_length"])
  86. target_position_id_batch.append(sample["target_position_id"])
  87. return {
  88. "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
  89. "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
  90. "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
  91. "context_length": torch.tensor(context_length_batch, dtype=torch.int64),
  92. "target_position_ids": torch.tensor(np.array(target_position_id_batch), dtype=torch.int64),
  93. }
  94. @staticmethod
  95. def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
  96. tokenizer = get_tokenizer()
  97. sop_id = tokenizer.get_command("sop")
  98. mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
  99. token = np.array(text, dtype=np.int64)
  100. blank_filling = mask_id in text
  101. if blank_filling:
  102. assert not unidirectional, "Unidirectional attention doesn't support blank filling"
  103. assert not use_task_mask, "Unidirectional attention doesn't support task mask"
  104. mask_position = text.index(mask_id)
  105. token = np.concatenate((token, [sop_id]))
  106. else:
  107. mask_position = len(token)
  108. if unidirectional:
  109. token = np.concatenate(([mask_id, sop_id], token))
  110. else:
  111. token = np.concatenate((token, [mask_id, sop_id]))
  112. context_length = len(token)
  113. position_id = np.arange(0, context_length, dtype=np.int64)
  114. target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
  115. if not use_task_mask:
  116. position_id[context_length - 1 :] = mask_position
  117. target_position_id[:] = mask_position
  118. attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
  119. if not unidirectional:
  120. attention_mask[: context_length - 1, : context_length - 1] = 1
  121. item = {
  122. "token": token,
  123. "position_id": position_id,
  124. "target_position_id": target_position_id,
  125. "attention_mask": attention_mask,
  126. "context_length": context_length,
  127. }
  128. return item
  129. def __getitem__(self, idx):
  130. item = self.data[idx]
  131. return self.build_generation_sample(
  132. item["text"],
  133. max_gen_length=self.config.max_gen_length,
  134. use_task_mask=self.config.use_task_mask,
  135. unidirectional=self.config.unidirectional,
  136. )
  137. class SmallGenerationTaskDataset(GenerationTaskDataset):
  138. config: GenerationTaskConfig
  139. @staticmethod
  140. def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
  141. tokenizer = get_tokenizer()
  142. sop_id = tokenizer.get_command("sop")
  143. mask_id = tokenizer.get_command("[gMASK]").Id if use_task_mask else tokenizer.get_command("[MASK]").Id
  144. cls_id = tokenizer.get_command("ENC")
  145. eos_id = tokenizer.get_command("eos")
  146. token = np.array(text, dtype=np.int64)
  147. blank_filling = mask_id in text
  148. if blank_filling:
  149. assert not unidirectional, "Unidirectional attention doesn't support blank filling"
  150. assert not use_task_mask, "Unidirectional attention doesn't support task mask"
  151. mask_position = text.index(mask_id) + 1
  152. context_length = len(token) + 2
  153. token = np.concatenate(([cls_id], token, [eos_id, sop_id]))
  154. else:
  155. if unidirectional:
  156. mask_position = 1
  157. context_length = 3
  158. token = np.concatenate(([cls_id, mask_id, eos_id, sop_id], token))
  159. else:
  160. mask_position = len(token) + 1
  161. context_length = len(token) + 3
  162. token = np.concatenate(([cls_id], token, [mask_id, eos_id, sop_id]))
  163. prefix_length = len(token) - context_length
  164. position_id = [list(range(context_length)) + [mask_position] * prefix_length,
  165. [0] * context_length + list(range(1, prefix_length + 1))]
  166. position_id = np.array(position_id, dtype=np.int64)
  167. target_position_id = [[mask_position] * max_gen_length,
  168. list(range(prefix_length + 1, prefix_length + max_gen_length + 1))]
  169. target_position_id = np.array(target_position_id, dtype=np.int64)
  170. attention_mask = np.tril(np.ones((len(token), len(token)), dtype=np.int64))
  171. if not unidirectional:
  172. attention_mask[: len(token) - 1, : len(token) - 1] = 1
  173. item = {
  174. "token": token,
  175. "position_id": position_id,
  176. "target_position_id": target_position_id,
  177. "attention_mask": attention_mask,
  178. "context_length": context_length,
  179. }
  180. return item
  181. class MultiChoiceTaskDataset(EvaluationDataset):
  182. config: MultiChoiceTaskConfig
  183. def __init__(self, path, config: MultiChoiceTaskConfig):
  184. self.is_single_token = True # set to False later in process_single_item func
  185. super().__init__(path, config)
  186. @property
  187. def has_collate_fn(self) -> bool:
  188. return True
  189. @staticmethod
  190. def collate_fn(samples):
  191. TILE = 32
  192. length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE
  193. token_batch, position_id_batch, attention_mask_batch = [], [], []
  194. choices_batch, choice_target_ids_batch = [], []
  195. is_single_token = True
  196. for sample in samples:
  197. token, position_id, attention_mask = pad_batch(
  198. sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
  199. )
  200. token_batch.append(token)
  201. position_id_batch.append(position_id)
  202. attention_mask_batch.append(attention_mask)
  203. choices_batch.append(sample["choices"])
  204. choice_target_ids_batch.append(sample["choice_target_ids"])
  205. if isinstance(sample["choice_target_ids"], list):
  206. is_single_token = False
  207. return {
  208. "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
  209. "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
  210. "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
  211. "choices": choices_batch,
  212. "choice_target_ids": choice_target_ids_batch,
  213. "is_single_token": is_single_token,
  214. }
  215. def process_single_item(self, item):
  216. text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]
  217. tgt_seq_length = sum([len(choice) for choice in choices])
  218. if tgt_seq_length == len(choices):
  219. # For single token, we only insert one [sop]
  220. tgt_seq_length = 1
  221. assert tgt_seq_length < self.config.max_seq_length
  222. if len(text) + tgt_seq_length + 2 > self.config.max_seq_length:
  223. text_length = self.config.max_seq_length - tgt_seq_length - 2
  224. text = text[len(text) - text_length : len(text)]
  225. assert not (
  226. self.mask_id in text and self.config.use_multitask_encoding
  227. ), "Unified multitask encoding don't support blank filling"
  228. if tgt_seq_length != 1:
  229. self.is_single_token = False
  230. return {
  231. "text": text,
  232. "choices": choices,
  233. "label": label,
  234. }
  235. @staticmethod
  236. def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False):
  237. tokenizer = get_tokenizer()
  238. sop_id = tokenizer.get_command("sop")
  239. mask_id = tokenizer.get_command("[MASK]")
  240. token = np.array(text, dtype=np.int64)
  241. target = np.array(text, dtype=np.int64)
  242. position_id = np.arange(len(text), dtype=np.int64)
  243. choice_target_id = []
  244. blank_filling = mask_id in text
  245. if not blank_filling:
  246. mask_position = len(token)
  247. token = np.concatenate((token, [mask_id]))
  248. target = np.concatenate((target, [mask_id]))
  249. position_id = np.concatenate((position_id, [mask_position]))
  250. else:
  251. mask_position = text.index(mask_id)
  252. division = len(token)
  253. attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
  254. for choice in choices:
  255. position_id = np.concatenate(
  256. (
  257. position_id,
  258. [mask_position] * len(choice)
  259. if blank_filling or not unified_multitask_encoding
  260. else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
  261. )
  262. )
  263. choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
  264. attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
  265. token = np.concatenate((token, [sop_id], choice[:-1]))
  266. target = np.concatenate((target, choice))
  267. if is_single_token:
  268. break
  269. attention_mask = block_diag(*attention_mask)
  270. attention_mask[: len(token), :division] = 1
  271. if is_single_token:
  272. choices = np.array(choices, dtype=np.int64).squeeze().tolist()
  273. item = {
  274. "token": token,
  275. "position_id": position_id,
  276. "attention_mask": attention_mask,
  277. "choices": choices,
  278. "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
  279. }
  280. return item
  281. def __getitem__(self, idx):
  282. item = self.data[idx]
  283. return self.build_multiple_choice_sample(
  284. item["text"],
  285. item["choices"],
  286. is_single_token=self.is_single_token,
  287. unified_multitask_encoding=self.config.use_multitask_encoding,
  288. )
  289. class LanguageModelTaskDataset(EvaluationDataset):
  290. config: LanguageModelTaskConfig
  291. def process_single_file(self, path):
  292. with open(os.path.join(path), "r", encoding="utf-8") as file:
  293. raw_text = file.read()
  294. tokens = self.tokenizer.tokenize(raw_text)
  295. self.data.append(
  296. {
  297. "raw_text": tokens,
  298. "num_original_tokens": len(raw_text.strip().split(" ")),
  299. "num_sequences": max(
  300. math.ceil(
  301. max(len(tokens) - (self.config.max_seq_length - 1), 0) / self.config.generation_length
  302. )
  303. + 1,
  304. 1,
  305. ),
  306. }
  307. )
  308. def process_single_item(self, item):
  309. pass
  310. def __len__(self):
  311. return self.data[0]["num_sequences"]
  312. def __getitem__(self, idx):
  313. start_idx = idx * self.config.generation_length
  314. end_idx = start_idx + self.config.max_seq_length - 1 # for additional [gMASK]
  315. tokens = self.data[0]["raw_text"][start_idx:end_idx]
  316. mask_id = self.gmask_id if self.config.use_task_mask else self.mask_id
  317. sop_id = self.tokenizer.get_command("sop")
  318. if idx == 0 or self.config.unidirectional:
  319. prompt, text = tokens[:1], tokens[1:]
  320. else:
  321. prompt_length = self.config.max_seq_length - 1 - self.config.generation_length
  322. prompt, text = tokens[:prompt_length], tokens[prompt_length:]
  323. seq_length = len(prompt) + len(text) + 1
  324. attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
  325. attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1
  326. return {
  327. "tokens": np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64),
  328. "targets": np.array(prompt + [mask_id] + text, dtype=np.int64),
  329. "position_ids": np.arange(0, seq_length, dtype=np.int64),
  330. "attention_mask": attention_mask < 0.5,
  331. "loss_masks": np.array([0] * (len(prompt) + 1) + [1] * len(text), dtype=np.int64),
  332. }