# model.py
import torch
from typing import List, Union

from SwissArmyTransformer.generation.autoregressive_sampling import update_mems, get_masks_and_position_ids_default
from SwissArmyTransformer.mpu import vocab_parallel_cross_entropy


def batch_filling_sequence(
        model,
        seqs,
        context_lengths,
        strategy,
        max_memory_length=100000,
        get_masks_and_position_ids=get_masks_and_position_ids_default,
        mems=None,
        **kw_args
):
    '''
    seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
    mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
        cache, should be the first mems.shape[1] parts of context_tokens.
        mems are the first-class citizens here, but we don't assume what is memorized.
        input mems are used for multi-phase generation.
    '''
    assert len(seqs.shape) == 2

    # building the initial tokens, attention_mask, and position_ids
    batch_size, context_length = seqs.shape
    seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
    tokens = seqs[..., :context_length]
    if attention_mask.dtype != torch.bool:
        attention_mask = attention_mask.type_as(next(model.parameters()))  # if fp16

    # initialize generation
    counter = context_length - 1  # Last fixed index is ``counter''
    index = 0 if mems is None else mems.shape[2]  # Next forward starting index, also the length of cache.
    num_beams = 1

    # step-by-step generation
    while counter < seqs.shape[1] - 1:
        # Now, we want to generate seq[counter + 1],
        # token[:, index: counter+1] needs forwarding.
        # forward
        if num_beams > 1:
            tokens = tokens.reshape(batch_size * num_beams, -1)
            mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1])
        logits, *output_per_layers = model(
            tokens[:, index:],
            position_ids[..., index: counter+1],
            attention_mask[..., index: counter+1, :counter+1],  # TODO memlen
            mems=mems,
            **kw_args
        )
        mem_kv = [o['mem_kv'] for o in output_per_layers]
        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
        if counter == context_length - 1:
            logits = logits[torch.arange(batch_size), context_lengths - 1]
        else:
            logits = logits[:, -1]
        counter += 1
        index = counter
        # sampling
        if num_beams > 1:
            logits = logits.reshape(batch_size, num_beams, -1)
            tokens = tokens.reshape(batch_size, num_beams, -1)
            mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
        tokens, mems = strategy.forward(logits, tokens, mems)
        if len(tokens.shape) == 3 and num_beams == 1:
            # The strategy just expanded the batch into beams: broadcast the
            # position ids and attention mask to the new batch_size * num_beams layout.
            num_beams = tokens.shape[1]
            position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, -1).reshape(batch_size * num_beams, -1)
            attention_mask_shape = attention_mask.shape[-3:]
            attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
                batch_size * num_beams, *attention_mask_shape)
        if strategy.is_done:
            break
    return strategy.finalize(tokens, mems)
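

# Hypothetical usage sketch (not part of the original file): how batch_filling_sequence
# is typically driven. `model`, `strategy`, `prompt_tokens`, and
# `my_get_masks_and_position_ids` are placeholders; the strategy only needs the
# forward/is_done/finalize interface used above, and -1 padding marks the positions
# to be generated.
#
#   seqs = torch.cat(
#       [prompt_tokens, prompt_tokens.new_full((prompt_tokens.shape[0], 128), -1)], dim=-1
#   )
#   context_lengths = torch.full((prompt_tokens.shape[0],), prompt_tokens.shape[1]).long()
#   output = batch_filling_sequence(
#       model, seqs, context_lengths,
#       strategy=strategy,
#       get_masks_and_position_ids=my_get_masks_and_position_ids,  # e.g. as built in generate_text below
#   )[0]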


class ModelForEvaluation(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @staticmethod
    def process_data(batch):
        # Move the batch tensors onto the current CUDA device and cast them to the
        # dtypes the model expects (long ids, boolean mask with a broadcastable head dim).
        return (
            batch["tokens"].to(device=torch.cuda.current_device()).long(),
            batch["position_ids"].to(device=torch.cuda.current_device()).long(),
            batch["attention_mask"].to(device=torch.cuda.current_device()).bool().unsqueeze(1),
        )

    def cond_log_prob(self, batch) -> List[List[float]]:
        """
        @return: Conditional log probability of each option
        """
        tokens, position_ids, attention_mask = self.process_data(batch)
        choices_batch, choice_target_ids_batch = batch["choices"], batch["choice_target_ids"]
        is_single_token = batch["is_single_token"]

        self.model.eval()
        with torch.no_grad():
            logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None)
            logits_batch = torch.nn.functional.log_softmax(logits, dim=-1)

        # output: [b, sq, vocab]
        log_probs = []
        if is_single_token:  # Single token
            for logits, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
                log_probs.append(logits[choice_target_ids[0], choices].tolist())
        else:  # Multi token
            for output, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
                log_probs_single = []
                for choice, choice_target_id in zip(choices, choice_target_ids):
                    tmp = output[choice_target_id, choice]
                    log_probs_single.append(tmp.sum().tolist())
                log_probs.append(log_probs_single)
        return log_probs
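
    # Hypothetical usage sketch (illustrative, not from the original source): for a
    # multiple-choice task, `batch` carries the context in "tokens" and, per sample,
    # the candidate token ids in "choices" plus the positions at which they are
    # scored in "choice_target_ids".
    #
    #   evaluator = ModelForEvaluation(model)
    #   scores = evaluator.cond_log_prob(batch)     # one List[float] per sample
    #   predictions = [max(range(len(s)), key=s.__getitem__) for s in scores]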

    def generate_text(self, sample, strategy, return_all_beams=False, max_gen_length=128) -> Union[
            List[int], List[List[int]]]:
        """
        @return: A list of token-id sequences generated by the model, sorted by score in descending order
        """
        seqs = sample["tokens"].to(device=torch.cuda.current_device()).long()
        context_lengths = sample["context_length"].long()

        def get_masks_and_position_ids(seq):
            batch_size = seq.shape[0]
            # Reserve max_gen_length positions to be generated, marked with -1.
            tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
            position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1)
            position_ids = position_ids.to(device=torch.cuda.current_device()).long()
            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device())
            # Generated positions reuse the mask row of the last context token and attend
            # to each other through an additional causal block.
            context_mask = attention_mask[torch.arange(batch_size), context_lengths - 1].unsqueeze(1).repeat(
                1, max_gen_length, 1)
            causal_mask = torch.tril(context_mask.new_ones((batch_size, max_gen_length, max_gen_length))) < 0.5
            generation_mask = torch.cat((context_mask, causal_mask), dim=-1)
            attention_mask = torch.nn.functional.pad(attention_mask, (0, max_gen_length), mode='constant', value=1)
            attention_mask = torch.cat((attention_mask, generation_mask), dim=1)
            attention_mask = attention_mask.bool().unsqueeze(1)
            return tokens, attention_mask, position_ids

        self.model.eval()
        with torch.no_grad():
            output = batch_filling_sequence(
                self.model,
                seqs,
                context_lengths,
                get_masks_and_position_ids=get_masks_and_position_ids,
                strategy=strategy,
            )[0]

        if isinstance(output, torch.Tensor):  # different strategies
            output = list(output)

        output_targets = []
        context_length = seqs.shape[1]
        for lines in output:
            output_target = []
            if not isinstance(lines, list):
                lines = [lines]
            for line in lines:
                line = line.tolist()
                # Cut at the first still-unfilled position (-1) and drop a trailing end token.
                unfinished = line.index(-1) if -1 in line else len(line)
                if line[unfinished - 1] in strategy.end_tokens:
                    unfinished -= 1
                line = line[context_length:unfinished]
                output_target.append(line)
            if not return_all_beams:
                output_targets.append(output_target[0])
            else:
                output_targets.append(output_target)
        return output_targets
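
    # Hypothetical usage sketch (illustrative, not from the original source): `sample`
    # must provide "tokens", "context_length", "position_ids", "target_position_ids",
    # and "attention_mask"; `strategy` is e.g. a beam-search strategy exposing
    # forward/is_done/finalize and end_tokens. `tokenizer` is a placeholder.
    #
    #   evaluator = ModelForEvaluation(model)
    #   token_ids = evaluator.generate_text(sample, strategy, max_gen_length=128)
    #   texts = [tokenizer.detokenize(ids) for ids in token_ids]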

    def calculate_loss(self, batch) -> List[float]:
        tokens, position_ids, attention_mask = self.process_data(batch)
        targets, loss_masks = (
            batch["targets"].to(device=torch.cuda.current_device()).long(),
            batch["loss_masks"].to(device=torch.cuda.current_device()).long(),
        )

        # vocab_parallel_cross_entropy expects the un-gathered, vocab-parallel logits,
        # so temporarily switch the transformer to parallel output.
        original_parallel_output = self.model.transformer.parallel_output
        self.model.transformer.parallel_output = True
        self.model.eval()

        with torch.no_grad():
            logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None)
            losses = vocab_parallel_cross_entropy(logits.contiguous().float(), targets)
            loss = torch.sum(losses * loss_masks, dim=-1)

        self.model.transformer.parallel_output = original_parallel_output

        # return list(zip(loss.tolist(), loss_masks.sum(dim=-1).tolist()))
        return loss.tolist()
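

# Hypothetical usage sketch (illustrative, not from the original source): `batch`
# additionally carries "targets" and "loss_masks" aligned with "tokens".
#
#   evaluator = ModelForEvaluation(model)
#   summed_nll = evaluator.calculate_loss(batch)   # masked token-level NLL sum per sample
#   # Dividing each entry by the corresponding batch["loss_masks"].sum(dim=-1) value
#   # gives a per-token average, cf. the commented-out return above.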