# dataset.py

import os
import json
import numpy as np
import torch
from abc import ABC, abstractmethod
from scipy.linalg import block_diag
from SwissArmyTransformer import get_tokenizer
from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig
from .utils import get_tokenized_input


def pad_batch(tokens, position_ids, attention_mask, max_seq_length):
    """Right-pad a single sample to max_seq_length.

    Tokens and position ids are zero-padded; the square attention mask gains
    all-zero rows and columns, so padded positions neither attend nor are attended to.
    """
    attention_mask = np.pad(
        attention_mask,
        pad_width=((0, max_seq_length - len(tokens)),),
        mode="constant",
        constant_values=0,
    )
    tokens = np.concatenate((tokens, np.zeros(max_seq_length - len(tokens), dtype=np.int64)))
    position_ids = np.concatenate((position_ids, np.zeros(max_seq_length - len(position_ids), dtype=np.int64)))
    return tokens, position_ids, attention_mask
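

# Illustrative behaviour of pad_batch (a sketch added for clarity, not part of
# the original source): padding a 3-token sample to max_seq_length=5 gives
#
#   tokens:         [11, 12, 13]     -> [11, 12, 13, 0, 0]
#   position_ids:   [0, 1, 2]        -> [0, 1, 2, 0, 0]
#   attention_mask: np.ones((3, 3))  -> a 5x5 matrix whose two new rows and
#                                       columns are all zeros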


class EvaluationDataset(torch.utils.data.Dataset, ABC):
    """
    Jsonlines of {
        "text": context,
        "choices": [choice_id1, ...]; if not None, len(target) == 1
        "label": -1 for a generation task, otherwise an index in [0, len(choices))
    }
    If [MASK] is not in the context, a [MASK] token is appended after the text.
    """

    def __init__(self, path, config: BaseConfig):
        self.path = path
        self.config = config
        self.max_seq_length = self.config.max_seq_length
        self.dtype = np.int64

        tokenizer = get_tokenizer(tokenizer_type="icetk-glm-130B")
        self.mask_id = tokenizer.get_command("[MASK]")
        self.gmask_id = tokenizer.get_command("[gMASK]")

        self.data = []
        with open(path, "r", encoding="utf-8") as file:
            for line in file:
                item = json.loads(line)
                self.data.append(self.process_single_item(item))

    @property
    def has_collate_fn(self) -> bool:
        return False

    def collate_fn(self, samples):
        return None

    @abstractmethod
    def process_single_item(self, item) -> dict:
        pass

    def __len__(self):
        return len(self.data)
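

# Illustrative record shapes for the jsonl files consumed above (a sketch, not
# from the original source). Field values are resolved through
# `get_tokenized_input`, so the exact on-disk key names (e.g. pretokenized vs.
# token-id variants) depend on `.utils`:
#
#   generation task:   {"inputs": [...context token ids...],
#                       "targets": [[...reference 1...], [...reference 2...]]}
#   multi-choice task: {"inputs": [...context token ids...],
#                       "choices": [[...choice 0...], [...choice 1...]],
#                       "label": 0}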


class GenerationTaskDataset(EvaluationDataset):
    config: GenerationTaskConfig

    def process_single_item(self, item):
        text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
        if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
            # Truncate the context from the left so that context + mask + sop
            # + max_gen_length generated tokens fit in max_seq_length.
            text_length = self.config.max_seq_length - self.config.max_gen_length - 2
            text = text[len(text) - text_length : len(text)]
        return {"text": text, "targets": targets}

    @property
    def has_collate_fn(self) -> bool:
        return True

    def collate_fn(self, samples):
        TILE = 32
        # Pad every sample in the batch to a common length, rounded up to a multiple of TILE.
        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        context_length_batch, target_position_id_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = pad_batch(
                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            context_length_batch.append(sample["context_length"])
            target_position_id_batch.append(sample["target_position_id"])
        return {
            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
            # Convert the 0/1 mask into a boolean mask (True = attention is blocked).
            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
            "context_length": torch.tensor(context_length_batch, dtype=torch.int64),
            "target_position_ids": torch.tensor(np.array(target_position_id_batch), dtype=torch.int64),
        }

    @staticmethod
    def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
        tokenizer = get_tokenizer()
        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        blank_filling = mask_id in text
        if blank_filling:
            # The context already contains the mask token: generate at that blank.
            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
            assert not use_task_mask, "Blank filling doesn't support the task mask [gMASK]"
            mask_position = text.index(mask_id)
            token = np.concatenate((token, [sop_id]))
        else:
            mask_position = len(token)
            if unidirectional:
                token = np.concatenate(([mask_id, sop_id], token))
            else:
                token = np.concatenate((token, [mask_id, sop_id]))
        context_length = len(token)

        position_id = np.arange(0, context_length, dtype=np.int64)
        target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
        if not use_task_mask:
            # Without [gMASK], generated tokens reuse the mask position (blank-infilling style).
            position_id[context_length - 1 :] = mask_position
            target_position_id[:] = mask_position

        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
        if not unidirectional:
            # Bidirectional attention over the context (everything before the trailing sop).
            attention_mask[: context_length - 1, : context_length - 1] = 1

        item = {
            "token": token,
            "position_id": position_id,
            "target_position_id": target_position_id,
            "attention_mask": attention_mask,
            "context_length": context_length,
        }
        return item

    def __getitem__(self, idx):
        item = self.data[idx]
        sample = self.build_generation_sample(
            item["text"],
            max_gen_length=self.config.max_gen_length,
            use_task_mask=self.config.use_task_mask,
            unidirectional=self.config.unidirectional,
        )
        sample["targets"] = [np.array(target, dtype=self.dtype) for target in item["targets"]]
        return sample
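

# Sketch of what GenerationTaskDataset.build_generation_sample produces (added
# for clarity, not part of the original source). For a context of four tokens
# t0..t3 with use_task_mask=True and unidirectional=True:
#
#   token              = [gMASK, sop, t0, t1, t2, t3]
#   position_id        = [0, 1, 2, 3, 4, 5]
#   target_position_id = [6, 7, ..., 6 + max_gen_length - 1]
#   attention_mask     = 6x6 lower-triangular matrix (fully autoregressive)
#
# With unidirectional=False the mask and sop are appended instead
# (token = [t0, t1, t2, t3, MASK/gMASK, sop]) and the context tokens before the
# trailing sop attend to each other bidirectionally; with use_task_mask=False
# the final context position and every target position are pinned to
# mask_position.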


class MultiChoiceTaskDataset(EvaluationDataset):
    config: MultiChoiceTaskConfig

    def __init__(self, path, config: MultiChoiceTaskConfig):
        self.is_single_token = True  # set to False in process_single_item if any choice has more than one token
        super().__init__(path, config)

    @property
    def has_collate_fn(self) -> bool:
        return True

    def collate_fn(self, samples):
        TILE = 32
        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        choices_batch, choice_target_ids_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = pad_batch(
                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            choices_batch.append(sample["choices"])
            choice_target_ids_batch.append(sample["choice_target_ids"])
        return {
            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
            "choices": choices_batch,
            "choice_target_ids": choice_target_ids_batch,
            "is_single_token": self.is_single_token,
        }

    def process_single_item(self, item):
        text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]

        tgt_seq_length = sum([len(choice) for choice in choices])
        if tgt_seq_length == len(choices):
            # Every choice is a single token, so only one [sop] is inserted.
            tgt_seq_length = 1

        assert tgt_seq_length < self.config.max_seq_length
        if len(text) + tgt_seq_length + 2 > self.config.max_seq_length:
            text_length = self.config.max_seq_length - tgt_seq_length - 2
            text = text[len(text) - text_length : len(text)]

        assert not (
            self.mask_id in text and self.config.use_multitask_encoding
        ), "Unified multitask encoding doesn't support blank filling"

        if tgt_seq_length != 1:
            self.is_single_token = False

        return {
            "text": text,
            "choices": choices,
            "label": label,
        }

    @staticmethod
    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False):
        tokenizer = get_tokenizer()

        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        target = np.array(text, dtype=np.int64)
        position_id = np.arange(len(text), dtype=np.int64)
        choice_target_id = []

        blank_filling = mask_id in text
        if not blank_filling:
            mask_position = len(token)
            token = np.concatenate((token, [mask_id]))
            target = np.concatenate((target, [mask_id]))
            position_id = np.concatenate((position_id, [mask_position]))
        else:
            mask_position = text.index(mask_id)

        division = len(token)
        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]

        # Append each choice after the shared context: a choice attends to the
        # context and (causally) to itself, but never to the other choices.
        for choice in choices:
            position_id = np.concatenate(
                (
                    position_id,
                    [mask_position] * len(choice)
                    if blank_filling or not unified_multitask_encoding
                    else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
                )
            )
            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
            token = np.concatenate((token, [sop_id], choice[:-1]))
            target = np.concatenate((target, choice))

            if is_single_token:
                break

        attention_mask = block_diag(*attention_mask)
        # Every position may attend to the full bidirectional context block.
        attention_mask[: len(token), :division] = 1

        if is_single_token:
            choices = np.array(choices, dtype=np.int64).squeeze().tolist()

        item = {
            "token": token,
            "position_id": position_id,
            "attention_mask": attention_mask,
            "choices": choices,
            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
        }
        return item
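
    # Sketch of the layout produced above (added for clarity, not part of the
    # original source). For a context [t0, t1, t2] without [MASK] and two
    # choices c0 = [a, b] and c1 = [c]:
    #
    #   token             = [t0, t1, t2, MASK, sop, a, sop]
    #   target            = [t0, t1, t2, MASK, a, b, c]
    #   position_id       = [0, 1, 2, 3, 3, 3, 3]   (choice tokens pinned to the
    #                                                mask position unless unified
    #                                                multitask encoding is used)
    #   choice_target_ids = [[4, 5], [6]]           (logit positions scoring each choice)
    #   attention_mask    = block_diag(4x4 context block, 2x2 tril, 1x1), with
    #                       every row additionally allowed to attend to the
    #                       context columns, so choices score independently.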

    def __getitem__(self, idx):
        item = self.data[idx]
        sample = self.build_multiple_choice_sample(
            item["text"],
            item["choices"],
            is_single_token=self.is_single_token,
            unified_multitask_encoding=self.config.use_multitask_encoding,
        )
        sample["label"] = item["label"]
        return sample
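

# Minimal usage sketch (an assumption, not from the original repo): wiring one
# of these datasets into a torch DataLoader. The config construction below is
# hypothetical -- the real constructor arguments live in `.configs` and may differ.
#
#   from torch.utils.data import DataLoader
#
#   config = MultiChoiceTaskConfig(...)  # hypothetical construction
#   dataset = MultiChoiceTaskDataset("/path/to/task.jsonl", config)
#   loader = DataLoader(
#       dataset,
#       batch_size=8,
#       collate_fn=dataset.collate_fn if dataset.has_collate_fn else None,
#   )
#   for batch in loader:
#       tokens, attention_mask = batch["tokens"], batch["attention_mask"]
#       # feed tokens, batch["position_ids"] and attention_mask to the model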