# dataset.py
import os
import math
import json

import numpy as np
import torch

from typing import List, Union
from abc import ABC, abstractmethod
from scipy.linalg import block_diag
from itertools import accumulate
from bisect import bisect_right

from SwissArmyTransformer import get_tokenizer

from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig, LanguageModelTaskConfig
from .utils import get_tokenized_input


def pad_batch(tokens, position_ids, attention_mask, max_seq_length):
    # Right-pad a single sample to max_seq_length: zero-fill tokens and
    # position_ids, and zero-pad the square attention mask on both axes.
    attention_mask = np.pad(
        attention_mask,
        pad_width=((0, max_seq_length - len(tokens)),),
        mode="constant",
        constant_values=0,
    )
    tokens = np.concatenate((tokens, np.zeros(max_seq_length - len(tokens), dtype=np.int64)))
    position_ids = np.concatenate((position_ids, np.zeros(max_seq_length - len(position_ids), dtype=np.int64)))
    return tokens, position_ids, attention_mask
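
# A minimal sketch (illustrative, not from the original file) of what pad_batch
# does: padding a length-3 sample to max_seq_length=8 keeps the first three
# entries of `tokens` and `position_ids`, zero-fills the rest, and pads the 2-D
# attention mask on both axes because the single pad_width pair broadcasts to
# every dimension.
#
#   tokens, position_ids, attention_mask = pad_batch(
#       np.array([5, 6, 7]), np.arange(3), np.ones((3, 3), dtype=np.int64), 8
#   )
#   # tokens.shape == (8,), position_ids.shape == (8,), attention_mask.shape == (8, 8)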


class EvaluationDataset(torch.utils.data.Dataset, ABC):
    """
    Expects jsonlines (or a JSON list) of records shaped like {
        "text": context,
        "choices": [choice_id1, ...]; if not None, len(targets) == 1,
        "label": -1 for generation tasks, otherwise an index in [0, len(choices)),
    }
    If [MASK] is not in the context, a [MASK] token is appended after the text.
    """

    def __init__(self, path: Union[str, List[str]], config: BaseConfig):
        self.path = path if isinstance(path, list) else [path]
        self.config = config
        self.max_seq_length = self.config.max_seq_length
        self.dtype = np.int64

        self.tokenizer = get_tokenizer()
        self.mask_id = self.tokenizer.get_command("[MASK]")
        self.gmask_id = self.tokenizer.get_command("[gMASK]")

        self.data = []
        for p in self.path:
            self.process_single_file(p)

    @property
    def has_collate_fn(self) -> bool:
        return False

    def collate_fn(self, samples):
        return None

    def process_single_file(self, path):
        # Non-jsonl files are first tried as a single JSON list; on a decode
        # error we fall back to reading the file line by line as jsonlines.
        if not path.endswith("jsonl"):
            try:
                with open(path, "r", encoding="utf-8") as file:
                    dataset = json.load(file)
                    for item in dataset:
                        self.data.extend(self.process_single_item(item))
                    return
            except json.decoder.JSONDecodeError:
                pass
        with open(path, "r", encoding="utf-8") as file:
            for line in file:
                item = json.loads(line)
                self.data.extend(self.process_single_item(item))

    @abstractmethod
    def process_single_item(self, item, **kwargs) -> List[dict]:
        pass

    def __len__(self):
        return len(self.data)


class GenerationTaskDataset(EvaluationDataset):
    config: GenerationTaskConfig

    def process_single_item(self, item, **kwargs):
        text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
        # Truncate the context from the left so that context + generated tokens +
        # the [MASK]/sop overhead still fit in max_seq_length.
        if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
            text_length = self.config.max_seq_length - self.config.max_gen_length - 2
            text = text[len(text) - text_length : len(text)]
        return [{"text": text, "targets": targets, **kwargs}]

    @property
    def has_collate_fn(self) -> bool:
        return True

    def collate_fn(self, samples):
        TILE = 32
        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        context_length_batch, target_position_id_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = pad_batch(
                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            context_length_batch.append(sample["context_length"])
            target_position_id_batch.append(sample["target_position_id"])
        return {
            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
            # `< 0.5` flips the padded 0/1 mask into a boolean tensor that is True
            # exactly where attention is not allowed.
            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
            "context_length": torch.tensor(context_length_batch, dtype=torch.int64),
            "target_position_ids": torch.tensor(np.array(target_position_id_batch), dtype=torch.int64),
        }

    @staticmethod
    def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
        tokenizer = get_tokenizer()
        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        blank_filling = mask_id in text
        if blank_filling:
            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
            assert not use_task_mask, "Unidirectional attention doesn't support task mask"
            mask_position = text.index(mask_id)
            token = np.concatenate((token, [sop_id]))
        else:
            mask_position = len(token)
            if unidirectional:
                token = np.concatenate(([mask_id, sop_id], token))
            else:
                token = np.concatenate((token, [mask_id, sop_id]))
        context_length = len(token)

        position_id = np.arange(0, context_length, dtype=np.int64)
        target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
        if not use_task_mask:
            # With [MASK] (span infilling) the generated span reuses the mask position.
            position_id[context_length - 1 :] = mask_position
            target_position_id[:] = mask_position

        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
        if not unidirectional:
            # Bidirectional attention over the context, causal only within the generated span.
            attention_mask[: context_length - 1, : context_length - 1] = 1

        item = {
            "token": token,
            "position_id": position_id,
            "target_position_id": target_position_id,
            "attention_mask": attention_mask,
            "context_length": context_length,
        }
        return item
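    # Layout sketch (illustrative, not from the original file): for a three-token
    # context A B C with no [MASK] in the text, unidirectional=False and
    # use_task_mask=False, build_generation_sample produces
    #   token:              A  B  C  [MASK]  sop
    #   position_id:        0  1  2    3      3
    #   target_position_id: 3  3  3  ...  (all max_gen_length slots)
    # With use_task_mask=True the generated positions keep increasing instead of
    # repeating the mask position.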

    def __getitem__(self, idx):
        item = self.data[idx]
        sample = self.build_generation_sample(
            item["text"],
            max_gen_length=self.config.max_gen_length,
            use_task_mask=self.config.use_task_mask,
            unidirectional=self.config.unidirectional,
        )
        if "targets" in item:
            sample["targets"] = [np.array(target, dtype=self.dtype) for target in item["targets"]]
        return sample
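
# Usage sketch (hedged): construction of the task config lives in .configs and is
# not shown here, but once a dataset exists it plugs into a standard DataLoader,
# supplying its own collate_fn when has_collate_fn is True.
#
#   dataset = GenerationTaskDataset("data/task.jsonl", config)
#   loader = torch.utils.data.DataLoader(
#       dataset,
#       batch_size=8,
#       collate_fn=dataset.collate_fn if dataset.has_collate_fn else None,
#   )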


class MultiChoiceTaskDataset(EvaluationDataset):
    config: MultiChoiceTaskConfig

    def __init__(self, path, config: MultiChoiceTaskConfig):
        self.is_single_token = True  # set to False later in process_single_item func
        super().__init__(path, config)

    @property
    def has_collate_fn(self) -> bool:
        return True

    def collate_fn(self, samples):
        TILE = 32
        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        choices_batch, choice_target_ids_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = pad_batch(
                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            choices_batch.append(sample["choices"])
            choice_target_ids_batch.append(sample["choice_target_ids"])
        return {
            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
            "choices": choices_batch,
            "choice_target_ids": choice_target_ids_batch,
            "is_single_token": self.is_single_token,
        }

    def process_single_item(self, item, **kwargs):
        text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]

        tgt_seq_length = sum([len(choice) for choice in choices])
        if tgt_seq_length == len(choices):
            # For single-token choices, we only insert one [sop]
            tgt_seq_length = 1

        assert tgt_seq_length < self.config.max_seq_length
        if len(text) + tgt_seq_length + 2 > self.config.max_seq_length:
            text_length = self.config.max_seq_length - tgt_seq_length - 2
            text = text[len(text) - text_length : len(text)]

        assert not (
            self.mask_id in text and self.config.use_multitask_encoding
        ), "Unified multitask encoding doesn't support blank filling"

        if tgt_seq_length != 1:
            self.is_single_token = False

        return [{
            "text": text,
            "choices": choices,
            "label": label,
            **kwargs
        }]

    @staticmethod
    def build_multiple_choice_sample(
        text, choices, is_single_token, unified_multitask_encoding=False, unidirectional=False, use_task_mask=False
    ):
        tokenizer = get_tokenizer()

        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        target = np.array(text, dtype=np.int64)
        position_id = np.arange(len(text), dtype=np.int64)
        choice_target_id = []

        blank_filling = mask_id in text
        if not blank_filling:
            if unidirectional:
                assert use_task_mask
                token = np.concatenate(([mask_id, sop_id], token[:-1]))
                target = np.concatenate(([mask_id, sop_id], target[:-1]))
                position_id = np.arange(len(token), dtype=np.int64)
                mask_position = len(token)
            else:
                mask_position = len(token)
                token = np.concatenate((token, [mask_id]))
                target = np.concatenate((target, [mask_id]))
                position_id = np.concatenate((position_id, [mask_position]))
        else:
            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
            assert not use_task_mask, "Unidirectional attention doesn't support task mask"
            mask_position = text.index(mask_id)

        division = len(token)
        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
        if unidirectional:
            attention_mask[0] = np.tril(attention_mask[0])

        for choice in choices:
            if not choice:
                choice = [tokenizer.get_command("eop")]

            position_id = np.concatenate(
                (
                    position_id,
                    [mask_position] * len(choice)
                    if (blank_filling or not unified_multitask_encoding) and not use_task_mask
                    else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
                )
            )
            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
            if unidirectional:
                token = np.concatenate((token, [text[-1]], choice[:-1]))
            else:
                token = np.concatenate((token, [sop_id], choice[:-1]))
            target = np.concatenate((target, choice))

            if is_single_token:
                break

        # One context block plus one causal block per choice; choice tokens can
        # attend to the full context but not to other choices.
        attention_mask = block_diag(*attention_mask)
        attention_mask[division:, :division] = 1

        if is_single_token:
            choices = np.array(choices, dtype=np.int64).squeeze().tolist()

        item = {
            "token": token,
            "position_id": position_id,
            "attention_mask": attention_mask,
            "choices": choices,
            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
        }
        return item
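    # Shape sketch (illustrative numbers, not from the original file): a
    # bidirectional context of 4 tokens plus the appended [MASK] gives
    # division = 5; with two choices of lengths 2 and 1, block_diag yields an
    # 8x8 attention mask whose lower-left 3x5 strip is then set to 1, so each
    # choice sees the whole context while the choices stay mutually invisible.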

    def __getitem__(self, idx):
        item = self.data[idx]
        sample = self.build_multiple_choice_sample(
            item["text"],
            item["choices"],
            is_single_token=self.is_single_token,
            unified_multitask_encoding=self.config.use_multitask_encoding,
            unidirectional=self.config.unidirectional,
            use_task_mask=self.config.use_task_mask,
        )
        sample["label"] = item["label"]
        return sample


class LanguageModelTaskDataset(EvaluationDataset):
    config: LanguageModelTaskConfig
    left_weights: List[int]
    weights: List[int]

    def process_single_file(self, path):
        num_sequences = []
        with open(path, "r", encoding="utf-8") as file:
            raw_text = file.read()
            tokens = self.tokenizer.tokenize(raw_text)
            self.data.append(
                {
                    "raw_text": tokens,
                    "num_original_tokens": len(raw_text.strip().split(" ")),
                    "num_sequences": max(
                        math.ceil(
                            max(len(tokens) - (self.config.max_seq_length - 1), 0) / self.config.generation_length
                        )
                        + 1,
                        1,
                    ),
                }
            )
            num_sequences.append(self.data[-1]["num_sequences"])
        self.weights = list(accumulate(num_sequences))
        self.left_weights = [0] + self.weights[:-1]

    def process_single_item(self, item):
        pass

    def __len__(self):
        # Note: the weights bookkeeping above effectively assumes a single input file per task.
        return self.data[0]["num_sequences"]

    def __getitem__(self, idx):
        document_idx = bisect_right(self.weights, idx)
        idx = idx - self.left_weights[document_idx]

        start_idx = idx * self.config.generation_length
        end_idx = start_idx + self.config.max_seq_length - 1  # for additional [gMASK]
        tokens = self.data[document_idx]["raw_text"][start_idx:end_idx]

        mask_id = self.gmask_id if self.config.use_task_mask else self.mask_id
        sop_id = self.tokenizer.get_command("sop")

        if idx == 0 or self.config.unidirectional:
            prompt, text = [], tokens
        else:
            prompt_length = self.config.max_seq_length - 1 - self.config.generation_length
            prompt, text = tokens[:prompt_length], tokens[prompt_length:]

        seq_length = len(prompt) + len(text) + 1
        attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
        attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1

        return {
            "tokens": np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64),
            "targets": np.array(prompt + [mask_id] + text, dtype=np.int64),
            "position_ids": np.arange(0, seq_length, dtype=np.int64),
            "attention_mask": attention_mask < 0.5,
            "loss_masks": np.array([0] * (len(prompt) + 1) + [1] * len(text), dtype=np.int64),
        }
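    # Indexing sketch (illustrative numbers, not from the original file): with
    # max_seq_length = 2049 and generation_length = 256, a 10,000-token document
    # yields ceil((10000 - 2048) / 256) + 1 = 33 windows. Window i covers raw
    # tokens [256 * i, 256 * i + 2048); when unidirectional is False, all but the
    # last 256 of them are re-fed as a loss-masked prompt, so only newly seen
    # tokens contribute to the perplexity (the first window scores all its tokens).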