# dataset.py

import os
import math
import json
import numpy as np
import torch
from typing import List, Union
from abc import ABC, abstractmethod
from scipy.linalg import block_diag
from itertools import accumulate
from bisect import bisect_right
from SwissArmyTransformer import get_tokenizer
from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig, LanguageModelTaskConfig
from .utils import get_tokenized_input
from .model import ModelForEvaluation

def pad_batch(tokens, position_ids, attention_mask, max_seq_length):
    """Right-pad a single sample's tokens, position ids and 2D attention mask to max_seq_length."""
    pad_length = max_seq_length - len(tokens)
    attention_mask = np.pad(
        attention_mask,
        pad_width=((0, pad_length),),  # a single (before, after) pair broadcasts to both axes
        mode="constant",
        constant_values=0,
    )
    tokens = np.concatenate((tokens, np.zeros(pad_length, dtype=np.int64)))
    position_ids = np.concatenate(
        (position_ids, np.zeros_like(position_ids[..., -1:], dtype=np.int64).repeat(pad_length, -1)), axis=-1
    )
    return tokens, position_ids, attention_mask
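
# A minimal usage sketch for pad_batch (values invented for illustration): pad a
# 3-token sample out to length 8. Padded positions carry token id 0, position id
# 0, and all-zero rows/columns in the attention mask.
#
#   tokens = np.array([5, 6, 7], dtype=np.int64)
#   position_ids = np.arange(3, dtype=np.int64)
#   attention_mask = np.tril(np.ones((3, 3), dtype=np.int64))
#   tokens, position_ids, attention_mask = pad_batch(tokens, position_ids, attention_mask, 8)
#   # tokens.shape == (8,); position_ids.shape == (8,); attention_mask.shape == (8, 8)
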

class EvaluationDataset(torch.utils.data.Dataset, ABC):
    """
    Base dataset over jsonlines of {
        "text": context,
        "choices": [choice_id1, ...]; if not None, len(targets) == 1,
        "label": -1 for generation tasks, else an index in [0, len(choices))
    }
    If [MASK] is not in the context, a [MASK] token is appended after the text.
    """
    def __init__(self, path: Union[str, List[str]], model: ModelForEvaluation, config: BaseConfig):
        self.path = path if isinstance(path, list) else [path]
        self.model = model
        self.config = config
        self.max_seq_length = self.config.max_seq_length
        self.dtype = np.int64

        self.tokenizer = get_tokenizer()
        self.mask_id = self.tokenizer.get_command("[MASK]")
        self.gmask_id = self.tokenizer.get_command("[gMASK]")

        self.data = []
        for p in self.path:
            self.process_single_file(p)

    @property
    def has_collate_fn(self) -> bool:
        return False

    def collate_fn(self, samples):
        return None

    def process_single_file(self, path):
        with open(os.path.join(path), "r", encoding="utf-8") as file:
            for line in file:
                item = json.loads(line)
                self.data.extend(self.process_single_item(item))

    @abstractmethod
    def process_single_item(self, item, **kwargs) -> List[dict]:
        pass

    def __len__(self):
        return len(self.data)

class GenerationTaskDataset(EvaluationDataset):
    config: GenerationTaskConfig

    def process_single_item(self, item, **kwargs):
        text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
        if len(targets) and not isinstance(targets[0], list):
            targets = [targets]  # normalize a single reference into a list of references
        # Truncate the context from the left so that the context, max_gen_length
        # generated tokens, and the [MASK]/[gMASK] + sop pair fit in max_seq_length
        if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
            text_length = self.config.max_seq_length - self.config.max_gen_length - 2
            text = text[len(text) - text_length : len(text)]
        return [{"text": text, "targets": targets, **kwargs}]
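
    # Worked example of the truncation above (config values invented for
    # illustration): with max_seq_length=512 and max_gen_length=128, at most
    # 512 - 128 - 2 = 382 context tokens are kept, taken from the end of the input.
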
    @property
    def has_collate_fn(self) -> bool:
        return True

    def collate_fn(self, samples):
        TILE = 32
        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        context_length_batch, target_position_id_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = pad_batch(
                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            context_length_batch.append(sample["context_length"])
            target_position_id_batch.append(sample["target_position_id"])
        return {
            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
            # `< 0.5` flips the 0/1 mask into a boolean tensor that is True where attention is *not* allowed
            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
            "context_length": torch.tensor(context_length_batch, dtype=torch.int64),
            "target_position_ids": torch.tensor(np.array(target_position_id_batch), dtype=torch.int64),
        }
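
    # The TILE rounding pads every batch to a multiple of 32 tokens, e.g. a
    # longest sample of 45 tokens gives (45 + 31) // 32 * 32 = 64.
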
    @staticmethod
    def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
        tokenizer = get_tokenizer()
        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        blank_filling = mask_id in text
        if blank_filling:
            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
            assert not use_task_mask, "Blank filling doesn't support task mask"
            mask_position = text.index(mask_id)
            token = np.concatenate((token, [sop_id]))
        else:
            mask_position = len(token)
            if unidirectional:
                token = np.concatenate(([mask_id, sop_id], token))
            else:
                token = np.concatenate((token, [mask_id, sop_id]))
        context_length = len(token)

        position_id = np.arange(0, context_length, dtype=np.int64)
        target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
        if not use_task_mask:
            # With [MASK], all generated tokens share the mask's position id
            position_id[context_length - 1 :] = mask_position
            target_position_id[:] = mask_position

        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
        if not unidirectional:
            # Bidirectional attention over the context; only the trailing sop stays causal
            attention_mask[: context_length - 1, : context_length - 1] = 1

        item = {
            "token": token,
            "position_id": position_id,
            "target_position_id": target_position_id,
            "attention_mask": attention_mask,
            "context_length": context_length,
        }
        return item
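
    # Sketch of the resulting token layouts (t_i are context tokens):
    #   unidirectional:           [gMASK]/[MASK] sop t_1 ... t_n  -> generate after t_n
    #   bidirectional, no blank:  t_1 ... t_n [gMASK]/[MASK] sop  -> generate after sop
    #   blank filling:            t_1 ... [MASK] ... t_n sop      -> fill the [MASK] slot
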
    def __getitem__(self, idx):
        item = self.data[idx]
        sample = self.build_generation_sample(
            item["text"],
            max_gen_length=self.config.max_gen_length,
            use_task_mask=self.config.use_task_mask,
            unidirectional=self.config.unidirectional,
        )
        return sample

class MultiChoiceTaskDataset(EvaluationDataset):
    config: MultiChoiceTaskConfig

    def __init__(self, path: Union[str, List[str]], model: ModelForEvaluation, config: BaseConfig):
        self.is_single_token = True  # set to False in process_single_item if any choice spans multiple tokens
        super().__init__(path, model, config)

    @property
    def has_collate_fn(self) -> bool:
        return True
    def collate_fn(self, samples):
        TILE = 32
        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        choices_batch, choice_target_ids_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = pad_batch(
                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            choices_batch.append(sample["choices"])
            choice_target_ids_batch.append(sample["choice_target_ids"])
        return {
            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
            "choices": choices_batch,
            "choice_target_ids": choice_target_ids_batch,
            "is_single_token": self.is_single_token,
        }
    def process_single_item(self, item, **kwargs):
        text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]

        tgt_seq_length = sum([len(choice) for choice in choices])
        if tgt_seq_length == len(choices):
            # Every choice is a single token, so only one [sop] is inserted
            tgt_seq_length = 1

        assert tgt_seq_length < self.config.max_seq_length
        if len(text) + tgt_seq_length + 2 > self.config.max_seq_length:
            text_length = self.config.max_seq_length - tgt_seq_length - 2
            text = text[len(text) - text_length : len(text)]

        assert not (
            self.mask_id in text and self.config.use_multitask_encoding
        ), "Unified multitask encoding doesn't support blank filling"

        if tgt_seq_length != 1:
            self.is_single_token = False

        return [{"text": text, "choices": choices, "label": label, **kwargs}]
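
    # Worked example of the single-token shortcut above (token ids invented): with
    # choices = [[17], [23]], sum of choice lengths == 2 == len(choices), so every
    # choice is one token and tgt_seq_length collapses to 1; with
    # choices = [[17, 4], [23]] it stays 3 and is_single_token becomes False.
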
    def __getitem__(self, idx):
        item = self.data[idx]
        sample = self.model.build_multiple_choice_sample(
            item["text"],
            item["choices"],
            is_single_token=self.is_single_token,
            unified_multitask_encoding=self.config.use_multitask_encoding,
            unidirectional=self.config.unidirectional,
            use_task_mask=self.config.use_task_mask,
        )
        return sample

class LanguageModelTaskDataset(EvaluationDataset):
    config: LanguageModelTaskConfig
    left_weights: List[int]
    weights: List[int]

    def process_single_file(self, path):
        num_sequences = []
        with open(os.path.join(path), "r", encoding="utf-8") as file:
            raw_text = file.read()
            tokens = self.tokenizer.tokenize(raw_text)
            self.data.append(
                {
                    "raw_text": tokens,
                    "num_original_tokens": len(raw_text.strip().split(" ")),
                    # Number of sliding windows of max_seq_length - 1 tokens,
                    # advancing generation_length tokens at a time
                    "num_sequences": max(
                        math.ceil(
                            max(len(tokens) - (self.config.max_seq_length - 1), 0) / self.config.generation_length
                        )
                        + 1,
                        1,
                    ),
                }
            )
            num_sequences.append(self.data[-1]["num_sequences"])
        # Prefix sums over documents, used to map a global index to a (document, window) pair
        self.weights = list(accumulate(num_sequences))
        self.left_weights = [0] + self.weights[:-1]
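
    # Worked example of num_sequences (config values invented): with
    # max_seq_length=2048 and generation_length=256, a 5000-token document yields
    # ceil((5000 - 2047) / 256) + 1 = 12 + 1 = 13 overlapping windows.
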
    def process_single_item(self, item):
        pass

    def __len__(self):
        # Total windows across all documents (equals self.data[0]["num_sequences"]
        # for a single file, but stays correct when multiple paths are given)
        return self.weights[-1]
    def __getitem__(self, idx):
        # Locate the document containing the idx-th window, then the window's
        # offset within that document
        document_idx = bisect_right(self.weights, idx)
        idx = idx - self.left_weights[document_idx]

        start_idx = idx * self.config.generation_length
        end_idx = start_idx + self.config.max_seq_length - 1  # leave room for the [gMASK]
        tokens = self.data[document_idx]["raw_text"][start_idx:end_idx]

        return self.model.build_language_model_sample(
            tokens,
            is_first_segment=idx == 0,
            max_seq_length=self.config.max_seq_length,
            generation_length=self.config.generation_length,
            unidirectional=self.config.unidirectional,
            use_gmask=self.config.use_task_mask,
        )
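
    # Index-mapping sketch (window counts invented): with per-document window
    # counts [13, 7], weights == [13, 20] and left_weights == [0, 13]; a global
    # idx of 15 gives document_idx == bisect_right([13, 20], 15) == 1 and local
    # window 15 - 13 == 2.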