# dataset.py

import os
import json

import numpy as np
import torch

from typing import List
from abc import ABC, abstractmethod
from scipy.linalg import block_diag

from SwissArmyTransformer import get_tokenizer

from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig
from .utils import get_tokenized_input


def pad_batch(tokens, position_ids, attention_mask, max_seq_length):
    attention_mask = np.pad(
        attention_mask,
        pad_width=((0, max_seq_length - len(tokens)),),
        mode="constant",
        constant_values=0,
    )
    tokens = np.concatenate((tokens, np.zeros(max_seq_length - len(tokens), dtype=np.int64)))
    position_ids = np.concatenate((position_ids, np.zeros(max_seq_length - len(position_ids), dtype=np.int64)))
    return tokens, position_ids, attention_mask
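
# Illustration of pad_batch (lengths and values here are hypothetical): a sample of
# length 3 padded to max_seq_length=5 gives
#   tokens          [t0, t1, t2]       -> [t0, t1, t2, 0, 0]
#   position_ids    [0, 1, 2]          -> [0, 1, 2, 0, 0]
#   attention_mask  (3, 3) 0/1 matrix  -> (5, 5) matrix, new rows/columns filled with 0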


class EvaluationDataset(torch.utils.data.Dataset, ABC):
    """
    Jsonlines of {
        "text": context
        "choices": [choice_id1, ...], if not None, len(target) == 1
        "label": -1 for a generation task, otherwise an index in [0, len(choices))
    }
    If [MASK] is not in the context, a [MASK] token will be appended after the text.
    """

    def __init__(self, path, config: BaseConfig):
        self.path = path
        self.config = config
        self.max_seq_length = self.config.max_seq_length
        self.dtype = np.int64

        tokenizer = get_tokenizer(tokenizer_type="icetk-glm-130B")
        self.mask_id = tokenizer.get_command("[MASK]")
        self.gmask_id = tokenizer.get_command("[gMASK]")

        self.data = []
        for p in self.path:
            self.process_single_file(p)

    @property
    def has_collate_fn(self) -> bool:
        return False

    def collate_fn(self, samples):
        return None

    def process_single_file(self, path):
        if not path.endswith("jsonl"):
            # Plain JSON: the whole file is a list of items
            try:
                with open(os.path.join(path), "r", encoding="utf-8") as file:
                    dataset = json.load(file)
                    for item in dataset:
                        self.data.extend(self.process_single_item(item))
                return
            except json.decoder.JSONDecodeError:
                pass
        # JSON Lines: one item per line
        with open(os.path.join(path), "r", encoding="utf-8") as file:
            for line in file:
                item = json.loads(line)
                self.data.extend(self.process_single_item(item))

    @abstractmethod
    def process_single_item(self, item, **kwargs) -> List[dict]:
        pass

    def __len__(self):
        return len(self.data)


class GenerationTaskDataset(EvaluationDataset):
    config: GenerationTaskConfig

    def process_single_item(self, item, **kwargs):
        text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
        if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
            # Keep the tail of the context so that it, the two special tokens and
            # max_gen_length generated tokens still fit into max_seq_length
            text_length = self.config.max_seq_length - self.config.max_gen_length - 2
            text = text[len(text) - text_length : len(text)]
        return [{"text": text, "targets": targets, **kwargs}]

    @staticmethod
    def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
        tokenizer = get_tokenizer()

        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)

        blank_filling = mask_id in text
        if blank_filling:
            # The context already contains a [MASK]: generation fills in that blank
            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
            assert not use_task_mask, "Unidirectional attention doesn't support task mask"
            mask_position = text.index(mask_id)
            token = np.concatenate((token, [sop_id]))
        else:
            mask_position = len(token)
            if unidirectional:
                # Prepend the mask and sop tokens so the whole sequence is processed left to right
                token = np.concatenate(([mask_id, sop_id], token))
            else:
                token = np.concatenate((token, [mask_id, sop_id]))
        context_length = len(token)
        max_seq_length = context_length + max_gen_length

        position_id = np.arange(0, max_seq_length, dtype=np.int64)
        if not use_task_mask:
            # With [MASK] (not [gMASK]), the sop token and every generated position reuse the mask position
            position_id[context_length - 1 :] = mask_position

        attention_mask = np.tril(np.ones((max_seq_length, max_seq_length), dtype=np.int64))
        if not unidirectional:
            # Bidirectional context: every token before sop can attend to the whole context
            attention_mask[: context_length - 1, : context_length - 1] = 1

        item = {
            "tokens": np.concatenate((token, np.zeros(max_seq_length - len(token), dtype=np.int64))),
            "position_ids": position_id,
            "attention_mask": attention_mask < 0.5,
            "context_length": context_length,
        }
        return item
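
    # Shape sketch for build_generation_sample (token ids x1..x3 are hypothetical):
    # with use_task_mask=True, unidirectional=True, context [x1, x2, x3] and
    # max_gen_length=2, the sample comes out as
    #   tokens          [gMASK] sop x1 x2 x3 0 0   (trailing zeros are generation slots)
    #   position_ids    0 1 2 3 4 5 6
    #   attention_mask  lower-triangular (causal), stored as a boolean "blocked" mask
    #                   via the `< 0.5` conversion.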

    def __getitem__(self, idx):
        item = self.data[idx]
        sample = self.build_generation_sample(
            item["text"],
            max_gen_length=self.config.max_gen_length,
            use_task_mask=self.config.use_task_mask,
            unidirectional=self.config.unidirectional,
        )
        if "targets" in item:
            sample["targets"] = [np.array(target, dtype=self.dtype) for target in item["targets"]]
        return sample


class MultiChoiceTaskDataset(EvaluationDataset):
    config: MultiChoiceTaskConfig

    def __init__(self, path, config: MultiChoiceTaskConfig):
        self.is_single_token = True  # set to False later in process_single_item func
        super().__init__(path, config)

    @property
    def has_collate_fn(self) -> bool:
        return True

    def collate_fn(self, samples):
        TILE = 32
        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        choices_batch, choice_target_ids_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = pad_batch(
                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            choices_batch.append(sample["choices"])
            choice_target_ids_batch.append(sample["choice_target_ids"])

        return {
            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
            "choices": choices_batch,
            "choice_target_ids": choice_target_ids_batch,
            "is_single_token": self.is_single_token,
        }
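
    # Note on collate_fn: every sample in the batch is right-padded to the next
    # multiple of TILE (32) tokens, and the 0/1 attention mask is turned into a
    # boolean mask via `< 0.5`, so True marks positions that must not be attended to
    # (the same convention build_generation_sample uses).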

    def process_single_item(self, item, **kwargs):
        text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]

        tgt_seq_length = sum([len(choice) for choice in choices])
        if tgt_seq_length == len(choices):
            # For single token, we only insert one [sop]
            tgt_seq_length = 1

        assert tgt_seq_length < self.config.max_seq_length
        if len(text) + tgt_seq_length + 2 > self.config.max_seq_length:
            text_length = self.config.max_seq_length - tgt_seq_length - 2
            text = text[len(text) - text_length : len(text)]

        assert not (
            self.mask_id in text and self.config.use_multitask_encoding
        ), "Unified multitask encoding doesn't support blank filling"

        if tgt_seq_length != 1:
            self.is_single_token = False

        return [{
            "text": text,
            "choices": choices,
            "label": label,
            **kwargs,
        }]

    @staticmethod
    def build_multiple_choice_sample(
        text, choices, is_single_token, unified_multitask_encoding=False, unidirectional=False, use_task_mask=False
    ):
        tokenizer = get_tokenizer()

        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        target = np.array(text, dtype=np.int64)
        position_id = np.arange(len(text), dtype=np.int64)
        choice_target_id = []

        blank_filling = mask_id in text
        if not blank_filling:
            if unidirectional:
                assert use_task_mask
                token = np.concatenate(([mask_id, sop_id], token[:-1]))
                target = np.concatenate(([mask_id, sop_id], target[:-1]))
                position_id = np.arange(len(token), dtype=np.int64)
                mask_position = len(token)
            else:
                mask_position = len(token)
                token = np.concatenate((token, [mask_id]))
                target = np.concatenate((target, [mask_id]))
                position_id = np.concatenate((position_id, [mask_position]))
        else:
            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
            assert not use_task_mask, "Unidirectional attention doesn't support task mask"
            mask_position = text.index(mask_id)

        # Everything before `division` is context; each choice gets its own block after it
        division = len(token)
        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
        if unidirectional:
            attention_mask[0] = np.tril(attention_mask[0])

        for choice in choices:
            if not choice:
                choice = [tokenizer.get_command("eop")]

            position_id = np.concatenate(
                (
                    position_id,
                    [mask_position] * len(choice)
                    if (blank_filling or not unified_multitask_encoding) and not use_task_mask
                    else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
                )
            )
            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
            # Each choice attends causally within its own block of the block-diagonal mask
            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
            if unidirectional:
                # The context was shifted right by [gMASK] sop, so the dropped last
                # context token is placed right before the choice tokens
                token = np.concatenate((token, [text[-1]], choice[:-1]))
            else:
                token = np.concatenate((token, [sop_id], choice[:-1]))
            target = np.concatenate((target, choice))

            if is_single_token:
                break

        attention_mask = block_diag(*attention_mask)
        # Every choice token can additionally attend to the full context
        attention_mask[division:, :division] = 1

        if is_single_token:
            choices = np.array(choices, dtype=np.int64).squeeze().tolist()

        item = {
            "token": token,
            "position_id": position_id,
            "attention_mask": attention_mask,
            "choices": choices,
            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
        }
        return item
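
    # Mask layout sketch for build_multiple_choice_sample (sizes are hypothetical):
    # with a 4-token context and two 2-token choices, block_diag plus the
    # `attention_mask[division:, :division] = 1` line yields
    #   context rows : see the full context (bidirectional unless unidirectional=True)
    #   choice rows  : see the full context plus a causal view of their own choice,
    #                  and nothing of the other choice.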

    def __getitem__(self, idx):
        item = self.data[idx]
        sample = self.build_multiple_choice_sample(
            item["text"],
            item["choices"],
            is_single_token=self.is_single_token,
            unified_multitask_encoding=self.config.use_multitask_encoding,
            unidirectional=self.config.unidirectional,
            use_task_mask=self.config.use_task_mask,
        )
        sample["label"] = item["label"]
        return sample
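

# Minimal usage sketch (assumptions: `config` is a GenerationTaskConfig built by the
# evaluation pipeline, and "task.jsonl" follows the record format documented in
# EvaluationDataset; the DataLoader settings below are illustrative only):
#
#     dataset = GenerationTaskDataset(["task.jsonl"], config)
#     loader = torch.utils.data.DataLoader(
#         dataset,
#         batch_size=4,
#         collate_fn=dataset.collate_fn if dataset.has_collate_fn else None,
#     )
#     batch = next(iter(loader))  # dict with "tokens", "position_ids", "attention_mask", ...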