dataset.py

import os
import json

import numpy as np
import torch

from abc import ABC, abstractmethod
from scipy.linalg import block_diag
from SwissArmyTransformer import get_tokenizer

from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig
from .utils import get_tokenized_input


def pad_batch(tokens, position_ids, attention_mask, max_seq_length):
    # np.pad broadcasts a single (before, after) pair to every axis, so the 2D
    # attention mask is zero-padded on both its rows and columns.
    attention_mask = np.pad(
        attention_mask,
        pad_width=((0, max_seq_length - len(tokens)),),
        mode="constant",
        constant_values=0,
    )
    tokens = np.concatenate((tokens, np.zeros(max_seq_length - len(tokens), dtype=np.int64)))
    position_ids = np.concatenate((position_ids, np.zeros(max_seq_length - len(position_ids), dtype=np.int64)))
    return tokens, position_ids, attention_mask
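

# A minimal sketch of what pad_batch does on toy values (hypothetical numbers,
# for illustration only): with tokens=[5, 6, 7], position_ids=[0, 1, 2], a
# (3, 3) attention mask, and max_seq_length=8, tokens and position_ids come
# back zero-padded to length 8, and the mask comes back as an (8, 8) array
# whose new rows and columns are zero.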


class EvaluationDataset(torch.utils.data.Dataset, ABC):
    """
    Jsonlines of {
        "text": context
        "choices": [choice_id1, ...]; if not None, len(target) == 1
        "label": -1 for generation tasks, otherwise an index in [0, len(choices))
    }
    If [MASK] is not in the context, a [MASK] token will be appended after the text.
    """

    def __init__(self, path, config: BaseConfig):
        self.path = path
        self.config = config
        self.max_seq_length = self.config.max_seq_length
        self.dtype = np.int64

        tokenizer = get_tokenizer(tokenizer_type="icetk-glm-130B")
        self.mask_id = tokenizer.get_command("[MASK]")
        self.gmask_id = tokenizer.get_command("[gMASK]")

        self.data = []
        with open(path, "r", encoding="utf-8") as file:
            for line in file:
                item = json.loads(line)
                self.data.append(self.process_single_item(item))

    @property
    def has_collate_fn(self) -> bool:
        return False

    def collate_fn(self, samples):
        return None

    @abstractmethod
    def process_single_item(self, item) -> dict:
        pass

    def __len__(self):
        return len(self.data)


class GenerationTaskDataset(EvaluationDataset):
    config: GenerationTaskConfig

    def process_single_item(self, item):
        text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
        if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
            # Keep only the tail of the context; two slots are reserved for the
            # mask token and sop.
            text_length = self.config.max_seq_length - self.config.max_gen_length - 2
            text = text[len(text) - text_length : len(text)]
        return {"text": text, "targets": targets}

    @staticmethod
    def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
        tokenizer = get_tokenizer()
        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        blank_filling = mask_id in text
        if blank_filling:
            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
            assert not use_task_mask, "Blank filling doesn't support task mask"
            mask_position = text.index(mask_id)
            token = np.concatenate((token, [sop_id]))
        else:
            mask_position = len(token)
            if unidirectional:
                token = np.concatenate(([mask_id, sop_id], token))
            else:
                token = np.concatenate((token, [mask_id, sop_id]))
        context_length = len(token)
        max_seq_length = context_length + max_gen_length

        position_id = np.arange(0, max_seq_length, dtype=np.int64)
        if not use_task_mask:
            position_id[context_length - 1 :] = mask_position

        attention_mask = np.tril(np.ones((max_seq_length, max_seq_length), dtype=np.int64))
        if not unidirectional:
            attention_mask[: context_length - 1, : context_length - 1] = 1

        item = {
            "tokens": np.concatenate((token, np.zeros(max_seq_length - len(token), dtype=np.int64))),
            "position_ids": position_id,
            "attention_mask": attention_mask < 0.5,
            "context_length": context_length,
        }
        return item

    def __getitem__(self, idx):
        item = self.data[idx]
        sample = self.build_generation_sample(
            item["text"],
            max_gen_length=self.config.max_gen_length,
            use_task_mask=self.config.use_task_mask,
            unidirectional=self.config.unidirectional,
        )
        sample["targets"] = [np.array(target, dtype=self.dtype) for target in item["targets"]]
        return sample
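

# Sample layout produced by build_generation_sample (illustrative summary, not
# part of the original file):
#   unidirectional:  [gMASK] sop t1 ... tn <max_gen_length generation slots>
#   bidirectional:   t1 ... tn [MASK]/[gMASK] sop <max_gen_length generation slots>
# The attention mask starts lower-triangular; with unidirectional=False the
# context before sop is additionally made fully visible to itself. With
# use_task_mask=False, position ids are clamped to mask_position from the sop
# token onward, in the style of GLM blank infilling.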


class MultiChoiceTaskDataset(EvaluationDataset):
    config: MultiChoiceTaskConfig

    def __init__(self, path, config: MultiChoiceTaskConfig):
        self.is_single_token = True  # set to False later in process_single_item
        super().__init__(path, config)

    @property
    def has_collate_fn(self) -> bool:
        return True

    def collate_fn(self, samples):
        TILE = 32
        # Pad every sample in the batch up to the next multiple of TILE tokens.
        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        choices_batch, choice_target_ids_batch = [], []
        for sample in samples:
            token, position_id, attention_mask = pad_batch(
                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            choices_batch.append(sample["choices"])
            choice_target_ids_batch.append(sample["choice_target_ids"])

        return {
            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
            "choices": choices_batch,
            "choice_target_ids": choice_target_ids_batch,
            "is_single_token": self.is_single_token,
        }

    def process_single_item(self, item):
        text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]

        tgt_seq_length = sum([len(choice) for choice in choices])
        if tgt_seq_length == len(choices):
            # Every choice is a single token, so only one [sop] is inserted.
            tgt_seq_length = 1

        assert tgt_seq_length < self.config.max_seq_length
        if len(text) + tgt_seq_length + 2 > self.config.max_seq_length:
            text_length = self.config.max_seq_length - tgt_seq_length - 2
            text = text[len(text) - text_length : len(text)]

        assert not (
            self.mask_id in text and self.config.use_multitask_encoding
        ), "Unified multitask encoding doesn't support blank filling"

        if tgt_seq_length != 1:
            self.is_single_token = False

        return {
            "text": text,
            "choices": choices,
            "label": label,
        }

    @staticmethod
    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False):
        tokenizer = get_tokenizer()
        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        target = np.array(text, dtype=np.int64)
        position_id = np.arange(len(text), dtype=np.int64)
        choice_target_id = []

        blank_filling = mask_id in text
        if not blank_filling:
            mask_position = len(token)
            token = np.concatenate((token, [mask_id]))
            target = np.concatenate((target, [mask_id]))
            position_id = np.concatenate((position_id, [mask_position]))
        else:
            mask_position = text.index(mask_id)

        division = len(token)
        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]

        for choice in choices:
            position_id = np.concatenate(
                (
                    position_id,
                    [mask_position] * len(choice)
                    if blank_filling or not unified_multitask_encoding
                    else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
                )
            )
            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
            token = np.concatenate((token, [sop_id], choice[:-1]))
            target = np.concatenate((target, choice))
            if is_single_token:
                break

        attention_mask = block_diag(*attention_mask)
        attention_mask[: len(token), :division] = 1

        if is_single_token:
            choices = np.array(choices, dtype=np.int64).squeeze().tolist()

        item = {
            "token": token,
            "position_id": position_id,
            "attention_mask": attention_mask,
            "choices": choices,
            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
        }
        return item
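
    # Attention layout assembled above (illustrative): one fully-visible block
    # for the context (columns [0, division)), then one causal lower-triangular
    # block per choice combined via scipy's block_diag; the final assignment
    # lets every choice token also attend back to the whole context.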

    def __getitem__(self, idx):
        item = self.data[idx]
        sample = self.build_multiple_choice_sample(
            item["text"],
            item["choices"],
            is_single_token=self.is_single_token,
            unified_multitask_encoding=self.config.use_multitask_encoding,
        )
        sample["label"] = item["label"]
        return sample
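

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original file: exercises the static
    # builder directly so no task config or jsonl file is needed. The token ids
    # below are made up, and a real run still requires the icetk-glm-130B
    # tokenizer assets, since the builder calls get_tokenizer().
    get_tokenizer(tokenizer_type="icetk-glm-130B")  # initialize and cache the tokenizer
    sample = GenerationTaskDataset.build_generation_sample(
        text=[101, 102, 103],  # hypothetical pre-tokenized context
        max_gen_length=4,
        use_task_mask=True,
        unidirectional=True,
    )
    print(sample["tokens"].shape, sample["position_ids"].shape, sample["context_length"])
    # -> (9,) (9,) 5 : [gMASK] sop + 3 context tokens + 4 generation slots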