# model.py
import numpy as np
import torch
from typing import List, Union
from scipy.linalg import block_diag
from SwissArmyTransformer.generation.autoregressive_sampling import update_mems, get_masks_and_position_ids_default
from SwissArmyTransformer.mpu import vocab_parallel_cross_entropy
from SwissArmyTransformer import get_tokenizer


def batch_filling_sequence(
    model,
    seqs,
    context_lengths,
    strategy,
    max_memory_length=100000,
    get_masks_and_position_ids=get_masks_and_position_ids_default,
    mems=None,
    **kw_args
):
    """
    seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
    mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
        cache, should be first mems.shape[1] parts of context_tokens.
        mems are first-class citizens here, but we don't assume what is memorized.
        input mems are only used for multi-phase generation.
    """
    assert len(seqs.shape) == 2

    # building the initial tokens, attention_mask, and position_ids
    batch_size, context_length = seqs.shape
    seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
    tokens = seqs[..., :context_length]
    if attention_mask.dtype != torch.bool:
        attention_mask = attention_mask.type_as(next(model.parameters()))  # if fp16

    # initialize generation
    counter = context_length - 1  # Last fixed index is `counter`
    index = 0 if mems is None else mems.shape[2]  # Next forward starting index, also the length of cache.
    num_beams = 1

    # step-by-step generation
    while counter < seqs.shape[1] - 1:
        # Now, we want to generate seq[counter + 1],
        # token[:, index: counter+1] needs forwarding.
        # forward
        tokens = tokens.reshape(batch_size * num_beams, -1)
        mems = (
            mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1])
            if mems is not None
            else None
        )
        logits, *output_per_layers = model(
            tokens[:, index:],
            position_ids[..., index : counter + 1],
            attention_mask[..., index : counter + 1, : counter + 1],  # TODO memlen
            mems=mems,
            **kw_args
        )
        mem_kv = [o["mem_kv"] for o in output_per_layers]
        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
        if counter == context_length - 1:
            logits = logits[torch.arange(batch_size), context_lengths - 1]
        else:
            logits = logits[:, -1]
        counter += 1
        index = counter
        # if torch.distributed.get_rank() == 0:
        #     print(f"counter: {counter}: logits: {logits.float().abs().mean()}")

        # sampling
        logits = logits.reshape(batch_size, num_beams, -1)
        tokens = tokens.reshape(batch_size, num_beams, -1)
        mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
        tokens, mems = strategy.forward(logits, tokens, mems)
        if len(tokens.shape) == 3 and num_beams == 1:
            num_beams = tokens.shape[1]
            position_ids = (
                position_ids.unsqueeze(1)
                .expand((batch_size, num_beams) + position_ids.shape[1:])
                .reshape((batch_size * num_beams,) + position_ids.shape[1:])
            )
            attention_mask_shape = attention_mask.shape[-3:]
            attention_mask = (
                attention_mask.unsqueeze(1)
                .expand(batch_size, num_beams, -1, -1, -1)
                .reshape(batch_size * num_beams, *attention_mask_shape)
            )
        if strategy.is_done:
            break
    return strategy.finalize(tokens, mems)
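

# Illustrative sketch (not part of the original file): a minimal greedy strategy
# showing the interface `batch_filling_sequence` relies on -- `forward(logits,
# tokens, mems)`, `is_done`, `finalize(tokens, mems)`, and (for `generate_text`
# below) `end_tokens`. Real evaluation runs use the sampling/beam-search
# strategies shipped with SwissArmyTransformer; this class only documents the
# contract and is a hypothetical stand-in.
class _GreedyStrategySketch:
    def __init__(self, end_tokens=()):
        self.end_tokens = list(end_tokens)
        self.is_done = False

    def forward(self, logits, tokens, mems):
        # logits: [batch, num_beams, vocab]; append the argmax token to each beam.
        next_token = logits.argmax(dim=-1, keepdim=True)  # [batch, num_beams, 1]
        if self.end_tokens:
            ends = torch.tensor(self.end_tokens, device=next_token.device)
            # done when every beam just produced an end token
            self.is_done = bool((next_token.unsqueeze(-1) == ends).any(dim=-1).all())
        return torch.cat((tokens, next_token), dim=-1), mems

    def finalize(self, tokens, mems):
        return tokens, mems

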
class ModelForEvaluation(torch.nn.Module):
    def __init__(self, model, position_encoding_2d):
        super().__init__()
        self.model = model
        self.position_encoding_2d = position_encoding_2d
        self.device = next(self.model.parameters()).device

    @staticmethod
    def process_data(batch, device):
        return (
            batch["tokens"].to(device=device).long(),
            batch["position_ids"].to(device=device).long(),
            batch["attention_mask"].to(device=device).bool().unsqueeze(1),
        )
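
    # Illustrative note (not part of the original file): `process_data` expects a
    # collated batch roughly shaped like
    #
    #     batch = {
    #         "tokens":         LongTensor [batch, seq_len],
    #         "position_ids":   LongTensor [batch, seq_len]  (or [batch, 2, seq_len] with 2D encoding),
    #         "attention_mask": Tensor     [batch, seq_len, seq_len],
    #     }
    #
    # and returns the three tensors on `device`, with the mask cast to bool and
    # unsqueezed to [batch, 1, seq_len, seq_len]. Extra keys such as "targets" or
    # "choices" are read directly by the methods below.
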
    def build_multiple_choice_sample(
        self,
        text,
        choices,
        is_single_token,
        unified_multitask_encoding=False,
        unidirectional=False,
        use_task_mask=False,
    ):
        tokenizer = get_tokenizer()

        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        target = np.array(text, dtype=np.int64)
        position_id = np.arange(len(text), dtype=np.int64)
        block_position_id = np.zeros(len(text), dtype=np.int64)
        choice_target_id = []

        blank_filling = mask_id in text
        if not blank_filling:
            if unidirectional:
                assert use_task_mask, "Unidirectional attention only supports gMASK"
                token = np.concatenate(([mask_id, sop_id], token[:-1]))
                target = np.concatenate(([mask_id, sop_id], target[:-1]))
                position_id = np.zeros(len(token), dtype=np.int64)
                if self.position_encoding_2d:
                    block_position_id = np.arange(len(token), dtype=np.int64)
                mask_position = len(token)
            else:
                mask_position = len(token)
                token = np.concatenate((token, [mask_id]))
                target = np.concatenate((target, [mask_id]))
                position_id = np.arange(len(token), dtype=np.int64)
                if self.position_encoding_2d:
                    block_position_id = np.zeros(len(token), dtype=np.int64)
        else:
            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
            assert not use_task_mask, "Blank filling only supports [MASK]"
            mask_position = text.index(mask_id)

        division = len(token)
        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
        if unidirectional:
            attention_mask[0] = np.tril(attention_mask[0])

        for choice in choices:
            if not choice:
                choice = [tokenizer.get_command("eop")]

            target = np.concatenate((target, choice))
            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))

            if unidirectional:
                if self.position_encoding_2d:
                    position_id = np.concatenate((position_id, [0] * len(choice)))
                    block_position_id = np.concatenate(
                        (block_position_id, np.arange(mask_position, mask_position + len(choice), dtype=np.int64))
                    )
                else:
                    position_id = np.concatenate(
                        (
                            position_id,
                            np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
                        )
                    )
                token = np.concatenate((token, [text[-1]], choice[:-1]))
            else:
                if self.position_encoding_2d:
                    position_id = np.concatenate((position_id, [mask_position] * len(choice)))
                    block_position_id = np.concatenate(
                        (block_position_id, np.arange(1, 1 + len(choice), dtype=np.int64))
                    )
                else:
                    position_id = np.concatenate(
                        (
                            position_id,
                            [mask_position] * len(choice)
                            if (blank_filling or not unified_multitask_encoding) and not use_task_mask
                            else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
                        )
                    )
                token = np.concatenate((token, [sop_id], choice[:-1]))

            if is_single_token:
                break

        attention_mask = block_diag(*attention_mask)
        attention_mask[division:, :division] = 1

        if is_single_token:
            choices = np.array(choices, dtype=np.int64).squeeze().tolist()

        if self.position_encoding_2d:
            position_id = np.stack((position_id, block_position_id), axis=0)

        item = {
            "token": token,
            "position_id": position_id,
            "attention_mask": attention_mask,
            "choices": choices,
            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
        }
        return item
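
    # Illustrative note (not part of the original file): for a toy call such as
    # build_multiple_choice_sample(text=[t0, t1, t2], choices=[[c0, c1]],
    # is_single_token=False) with default flags, the layout produced above is
    #
    #     token             = [t0, t1, t2, MASK, sop, c0]
    #     target            = [t0, t1, t2, MASK, c0,  c1]
    #     choice_target_ids = [[4, 5]]   # positions whose logits score c0, c1
    #     attention_mask    = block_diag(ones(4, 4), tril(ones(2, 2))), with the
    #                         choice rows additionally allowed to attend to the
    #                         full bidirectional context (columns 0..3).
    #
    # t*, c*, MASK and sop stand for token ids; the exact ids depend on the tokenizer.
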
    def cond_log_prob(self, batch) -> List[List[float]]:
        """
        @return: Conditional log probability of each option
        """
        tokens, position_ids, attention_mask = self.process_data(batch, self.device)
        choices_batch, choice_target_ids_batch = batch["choices"], batch["choice_target_ids"]
        is_single_token = batch["is_single_token"]

        self.model.eval()
        with torch.no_grad():
            logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None)
            logits_batch = torch.nn.functional.log_softmax(logits, dim=-1)

        # output: [b, sq, vocab]
        log_probs = []

        # if torch.distributed.get_rank() == 0:
        #     import pdb
        #
        #     pdb.set_trace()
        # torch.distributed.barrier()

        if is_single_token:  # Single token
            for logits, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
                log_probs.append(logits[choice_target_ids[0], choices].tolist())
        else:  # Multi token
            for output, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
                log_probs_single = []
                for choice, choice_target_id in zip(choices, choice_target_ids):
                    tmp = output[choice_target_id, choice]
                    log_probs_single.append(tmp.sum().tolist())
                log_probs.append(log_probs_single)
        return log_probs
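
    # Illustrative note (not part of the original file): the multi-token branch
    # above scores a choice by summing per-token log-probabilities at the choice's
    # target positions. Written standalone for one sample:
    #
    #     log_probs_seq = torch.nn.functional.log_softmax(logits_one, dim=-1)  # [seq, vocab]
    #     score = log_probs_seq[choice_target_id, choice].sum().item()
    #
    # `logits_one`, `choice_target_id` and `choice` are hypothetical names for one
    # row of the batch, one choice's positions, and that choice's token ids.
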
    def build_generation_sample(self, text, max_gen_length, use_task_mask, unidirectional):
        tokenizer = get_tokenizer()

        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        position_id = np.arange(len(text), dtype=np.int64)
        block_position_id = np.zeros(len(text), dtype=np.int64)
        target_position_id = np.zeros(len(text), dtype=np.int64)
        target_block_position_id = np.zeros(len(text), dtype=np.int64)

        blank_filling = mask_id in text
        if unidirectional:
            assert use_task_mask, "Unidirectional attention only supports gMASK"
            assert not blank_filling, "Unidirectional attention doesn't support blank filling"
            token = np.concatenate(([mask_id, sop_id], token))
            if self.position_encoding_2d:
                position_id = np.zeros(len(token), dtype=np.int64)
                target_position_id = np.zeros(max_gen_length, dtype=np.int64)
                block_position_id = np.arange(len(token), dtype=np.int64)
                target_block_position_id = np.arange(len(token), len(token) + max_gen_length, dtype=np.int64)
            else:
                position_id = np.arange(len(token), dtype=np.int64)
                target_position_id = np.arange(len(token), len(token) + max_gen_length, dtype=np.int64)
        else:
            if not blank_filling:
                mask_position = len(token)
                token = np.concatenate((token, [mask_id, sop_id]))
            else:
                assert not use_task_mask, "Blank filling only supports [MASK]"
                mask_position = text.index(mask_id)
                token = np.concatenate((token, [sop_id]))
            position_id = np.concatenate((np.arange(len(token) - 1, dtype=np.int64), [mask_position]))
            target_position_id = np.full(max_gen_length, mask_position, dtype=np.int64)
            if self.position_encoding_2d:
                block_position_id = np.zeros(len(token), dtype=np.int64)
                target_block_position_id = np.arange(1, max_gen_length + 1, dtype=np.int64)
            elif use_task_mask:
                position_id = np.arange(len(token), dtype=np.int64)
                target_position_id = np.arange(len(token), len(token) + max_gen_length, dtype=np.int64)

        context_length = len(token)
        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
        if not unidirectional:
            attention_mask[: context_length - 1, : context_length - 1] = 1

        if self.position_encoding_2d:
            position_id = np.stack((position_id, block_position_id), axis=0)
            target_position_id = np.stack((target_position_id, target_block_position_id), axis=0)

        item = {
            "token": token,
            "position_id": position_id,
            "target_position_id": target_position_id,
            "attention_mask": attention_mask,
            "context_length": context_length,
        }
        return item
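
    # Illustrative note (not part of the original file): with use_task_mask=True,
    # unidirectional=False and 1D positions, a prompt text=[t0, t1, t2] yields
    #
    #     token              = [t0, t1, t2, gMASK, sop]     (context_length = 5)
    #     position_id        = [0, 1, 2, 3, 4]
    #     target_position_id = [5, 6, ..., 5 + max_gen_length - 1]
    #     attention_mask     = tril(ones(5, 5)) with the first 4x4 block set to 1,
    #                          i.e. bidirectional over the prompt, causal afterwards.
    #
    # Generated tokens are appended after `sop` by `batch_filling_sequence`.
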
    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[List[int]], List[List[List[int]]]]:
        """
        @return: A list of generated token-id sequences, sorted by score in descending order
        """
        seqs = sample["tokens"].to(device=self.device).long()
        context_lengths = sample["context_length"].long()

        def get_masks_and_position_ids(seq):
            batch_size = seq.shape[0]
            max_gen_length = sample["target_position_ids"].shape[-1]
            tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)
            position_ids = torch.cat((sample["position_ids"], sample["target_position_ids"]), dim=-1)
            position_ids = position_ids.to(device=self.device).long()
            attention_mask = sample["attention_mask"].to(device=self.device)
            context_mask = (
                attention_mask[torch.arange(batch_size), context_lengths - 1].unsqueeze(1).repeat(1, max_gen_length, 1)
            )
            causal_mask = torch.tril(context_mask.new_ones((batch_size, max_gen_length, max_gen_length))) < 0.5
            generation_mask = torch.cat((context_mask, causal_mask), dim=-1)
            attention_mask = torch.nn.functional.pad(attention_mask, (0, max_gen_length), mode="constant", value=1)
            attention_mask = torch.cat((attention_mask, generation_mask), dim=1)
            attention_mask = attention_mask.bool().unsqueeze(1)
            return tokens, attention_mask, position_ids

        self.model.eval()
        with torch.no_grad():
            output = batch_filling_sequence(
                self.model,
                seqs,
                context_lengths,
                get_masks_and_position_ids=get_masks_and_position_ids,
                strategy=strategy,
            )[0]

        if isinstance(output, torch.Tensor):  # different strategies
            output = output.tolist()

        output_targets = []
        context_length = seqs.shape[1]
        for lines in output:
            lines = lines.tolist() if isinstance(lines, torch.Tensor) else lines
            output_target = []
            if not isinstance(lines, list):
                lines = [lines]
            for line in lines:
                unfinished = line.index(-1) if -1 in line else len(line)
                if line[unfinished - 1] in strategy.end_tokens:
                    unfinished -= 1
                line = line[context_length:unfinished]
                output_target.append(line)
            if not return_all_beams:
                output_targets.append(output_target[0])
            else:
                output_targets.append(output_target)
        return output_targets
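
    # Illustrative note (not part of the original file): `generate_text` expects the
    # fields from `build_generation_sample` to be batched into tensors and renamed
    # to the plural keys read above (this renaming is assumed to happen in the
    # evaluation dataset / collate step, which is not shown in this file):
    #
    #     sample = {
    #         "tokens":              LongTensor [batch, context_length],
    #         "position_ids":        LongTensor [batch, (2,) context_length],
    #         "target_position_ids": LongTensor [batch, (2,) max_gen_length],
    #         "attention_mask":      Tensor     [batch, context_length, context_length],
    #         "context_length":      LongTensor [batch],
    #     }
    #
    # The return value is a list of generated token-id lists (one per sample, or
    # one list per beam when return_all_beams=True).
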
    def build_language_model_sample(
        self,
        tokens: List[int],
        is_first_segment: bool,
        max_seq_length: int,
        generation_length: int,
        unidirectional: bool,
        use_gmask: bool,
    ):
        tokenizer = get_tokenizer()
        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_gmask else tokenizer.get_command("[MASK]")

        if is_first_segment or unidirectional:
            prompt, text = [], tokens
        else:
            prompt_length = max_seq_length - 1 - generation_length
            prompt, text = tokens[:prompt_length], tokens[prompt_length:]

        seq_length = len(prompt) + len(text) + 1
        attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
        attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1

        gen_length = min(len(text), generation_length)
        position_id = np.arange(0, seq_length, dtype=np.int64)
        if self.position_encoding_2d:
            position_id = np.concatenate(
                (np.arange(0, seq_length - gen_length, dtype=np.int64), [seq_length - gen_length - 1] * gen_length)
            )
            block_position_id = np.concatenate(
                ([0] * (seq_length - gen_length - 1), np.arange(0, gen_length + 1, dtype=np.int64))
            )
            position_id = np.stack((position_id, block_position_id), axis=0)

        return {
            "tokens": np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64),
            "targets": np.array(prompt + [mask_id] + text, dtype=np.int64),
            "position_ids": position_id,
            "attention_mask": attention_mask < 0.5,
            "loss_masks": np.array(
                [0] * (seq_length - gen_length) + [1] * gen_length,
                dtype=np.int64,
            ),
        }
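
    # Illustrative note (not part of the original file): the sample above is laid
    # out for teacher forcing, so tokens and targets are shifted by one position
    # inside the scored span:
    #
    #     tokens     = prompt + [MASK, sop] + text[:-1]
    #     targets    = prompt + [MASK]      + text
    #     loss_masks = [0] * (seq_length - gen_length) + [1] * gen_length
    #
    # i.e. the logit produced at `sop` is scored against text[0], the logit at
    # text[i] against text[i + 1], and only the last `gen_length` positions
    # contribute to the loss computed in `calculate_loss` below.
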
    def calculate_loss(self, batch) -> List[float]:
        tokens, position_ids, attention_mask = self.process_data(batch, self.device)
        targets, loss_masks = (
            batch["targets"].to(device=self.device).long(),
            batch["loss_masks"].to(device=self.device).long(),
        )

        original_parallel_output = self.model.transformer.parallel_output
        self.model.transformer.parallel_output = True
        self.model.eval()

        with torch.no_grad():
            logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None)
            losses = vocab_parallel_cross_entropy(logits.contiguous().float(), targets)
            loss = torch.sum(losses * loss_masks, dim=-1)

        self.model.transformer.parallel_output = original_parallel_output

        return loss.tolist()
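

# Illustrative sketch (not part of the original file): `calculate_loss` returns the
# summed negative log-likelihood over the masked positions of each sample. A common
# derived metric is token-level perplexity; `token_counts` is assumed to be the
# per-sample sum of "loss_masks".
def perplexity_from_loss_sums(loss_sums: List[float], token_counts: List[int]) -> List[float]:
    return [float(np.exp(loss / max(count, 1))) for loss, count in zip(loss_sums, token_counts)]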