strategies.py

import numpy as np
import torch
import torch.nn.functional as F

from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits


class BaseStrategy:
    def __init__(self, batch_size, invalid_slices=None, temperature=1.0, top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
        self.batch_size = batch_size
        self.invalid_slices = invalid_slices if invalid_slices is not None else []
        self.temperature = temperature
        self.topk = top_k
        self.top_p = top_p
        self.eps = eps
        if end_tokens is None:
            end_tokens = []
        self.end_tokens = end_tokens
        # np.bool was removed in NumPy 1.24; use the builtin bool dtype instead.
        self._is_done = np.zeros(self.batch_size, dtype=bool)

    @property
    def is_done(self) -> bool:
        return self._is_done.all()

    def forward(self, logits, tokens, mems, temperature=None):
        batch_size = tokens.shape[0]
        if temperature is None:
            temperature = self.temperature
        logits = logits / temperature
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -65504  # most negative fp16 value
        # Sampling path, currently disabled in favor of greedy decoding:
        # logits = top_k_logits(logits, self.topk, self.top_p)
        # probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in PyTorch
        # pred = torch.multinomial(probs, num_samples=1)
        pred = torch.argmax(logits, dim=-1)
        for i in range(self.batch_size):
            if i >= batch_size:
                self._is_done[i] = True
            elif self._is_done[i]:
                pred[i] = -1
            elif pred[i].item() in self.end_tokens:
                self._is_done[i] = True
        tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
        return tokens, mems

    def finalize(self, tokens, mems):
        self._is_done = np.zeros(self.batch_size, dtype=bool)
        return tokens, mems
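

# Usage sketch (illustrative, not part of the original file): BaseStrategy is
# stepped once per generated token and passes `mems` through untouched.
# `model_step`, `eos_id`, and the shape below are assumptions.
#
#     strategy = BaseStrategy(batch_size=2, end_tokens=[eos_id])
#     while not strategy.is_done:
#         logits = model_step(tokens, mems)  # [batch_size, vocab_size]
#         tokens, mems = strategy.forward(logits, tokens, mems)
#     tokens, mems = strategy.finalize(tokens, mems)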


class BeamSearchStrategy:
    def __init__(
        self,
        batch_size,
        num_beams,
        length_penalty=1.0,
        consider_end=False,
        end_tokens=None,
        invalid_slices=None,
        no_repeat_ngram_size=0,
        min_gen_length=0,
        deterministic=False,
    ):
        self.batch_size = batch_size
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.end_tokens = end_tokens if end_tokens is not None else []
        self.ngram = no_repeat_ngram_size
        self.min_gen_length = min_gen_length
        self.invalid_slices = invalid_slices if invalid_slices is not None else []
        self.consider_end = consider_end
        self.deterministic = deterministic
        self._init_cache()
    def _init_cache(self):
        self.end_beams = [[] for _ in range(self.batch_size)]  # list of LongTensors
        self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)]  # list of floats
        self.cached_beam_scores = 0  # becomes a [batch_size, num_beams] tensor after the first step
        self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
        self.length_generated = 0
        self._is_done = np.zeros(self.batch_size, dtype=bool)

    def _add_end_beams(self, score, beam, batch_idx):
        score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty  # Magic number for OpenNMT
        # Insertion sort: find the position that keeps scores in descending order.
        for i in range(len(self.end_beams[batch_idx]), -1, -1):
            if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
                break
        self.end_beams[batch_idx].insert(i, beam)
        self.end_beams_penalized_scores[batch_idx].insert(i, score)
        self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
        self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]
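
    # Note: _add_end_beams applies the GNMT/OpenNMT length penalty
    #     lp(Y) = ((5 + |Y|) / 6) ** alpha
    # before ranking finished beams, so longer hypotheses are not unfairly
    # dominated by shorter ones.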

    def forward(self, logits, tokens, mems):
        if len(logits.shape) == 2:
            # First step: add a beam dimension of size 1.
            logits = logits.unsqueeze(1)
            tokens = tokens.unsqueeze(1)
            mems = mems.unsqueeze(2)
        batch_size, num_beams, vocab_size = logits.shape
        seq_len = tokens.shape[-1]
        logits = logits.float()
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -65504
        if self.min_gen_length > self.length_generated:
            for end_token in self.end_tokens:
                logits[..., end_token] = -65504
        if self.ngram > 0 and seq_len > self.ngram:
            for batch_idx in range(batch_size):
                for i in range(num_beams):
                    ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1):].tolist()  # TODO ngram=1
                    for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
                        logits[batch_idx, i, banned_index] = -65504

        next_token_scores = F.log_softmax(logits, dim=-1)  # [batch_size, num_beams, vocab_size]
        prev_scores = self.cached_beam_scores
        if isinstance(prev_scores, torch.Tensor):
            prev_scores = prev_scores[..., None].expand_as(next_token_scores)
        next_token_scores = next_token_scores + prev_scores
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        probs = F.softmax(next_token_scores, dim=-1)
        if self.deterministic:
            if mems.shape[2] < self.num_beams:  # First token
                probs = probs[..., :vocab_size]
            next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices
        else:
            next_tokens = torch.multinomial(
                probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
            )
        # next_tokens: [batch_size, (max(1, len(end_tokens)) + 1) * num_beams] flat indices
        next_token_scores = next_token_scores[torch.arange(batch_size).unsqueeze(1), next_tokens]
        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
        next_tokens = next_tokens[torch.arange(batch_size).unsqueeze(1), _indices]
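        # The candidates index the flattened [num_beams * vocab_size] score
        # matrix: integer division recovers the source beam, the remainder the
        # token id within the vocabulary.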
        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
        next_tokens = next_tokens % vocab_size

        # Select out end beams or continue beams.
        if mems.shape[2] < self.num_beams:
            mems = mems.expand(-1, batch_size, self.num_beams, -1, -1)
        beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
        for batch_idx in range(batch_size):
            beam_continue = []
            scores_continue = []
            bans_continue = []
            mems_continue = []
            for i in range(len(next_tokens[batch_idx])):
                beam = torch.cat((tokens[batch_idx, next_indices[batch_idx, i]], next_tokens[batch_idx, i : i + 1]))
                if not self._is_done[batch_idx] and int(next_tokens[batch_idx, i]) in self.end_tokens:
                    self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
                elif len(beam_continue) < self.num_beams:
                    beam_continue.append(beam)
                    mems_continue.append(mems[:, batch_idx, next_indices[batch_idx, i]])
                    # update caches
                    scores_continue.append(next_token_scores[batch_idx, i])
                    if self.ngram > 0:
                        bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
                        # TODO ngram=1
                        ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
                        bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
                        bans_continue.append(bans)
                else:
                    break
            beam_continue_batch.append(torch.stack(beam_continue))
            mems_continue_batch.append(torch.stack(mems_continue, dim=1))
            score_continue_batch.append(scores_continue)
            self.cached_beam_ngram_bans[batch_idx] = bans_continue
        tokens = torch.stack(beam_continue_batch)
        mems = torch.stack(mems_continue_batch, dim=1)
        self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
        self.length_generated += 1

        for batch_idx in range(self.batch_size):
            if batch_idx >= batch_size:
                self._is_done[batch_idx] = True
            elif (
                len(self.end_beams[batch_idx]) == self.num_beams
                and self.end_beams_penalized_scores[batch_idx][-1]
                >= self.cached_beam_scores[batch_idx].max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
            ):  # We are done once no current beam can beat the worst finished beam.
                self._is_done[batch_idx] = True
        return tokens, mems

    def finalize(self, tokens, mems):
        if self.consider_end:
            for batch_idx in range(tokens.shape[0]):
                if not self._is_done[batch_idx]:
                    for i in range(tokens.shape[1]):
                        self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
            mems = None
            ret = self.end_beams
        else:
            ret = tokens
        self._init_cache()
        return ret, mems
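

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original API): drive
# BaseStrategy with random logits for a few steps. The vocabulary size, end
# token id, and starting tokens are made-up values for illustration, and
# running the file still requires SwissArmyTransformer for the import above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size, batch_size, end_token = 100, 2, 7

    strategy = BaseStrategy(batch_size, end_tokens=[end_token])
    tokens = torch.zeros(batch_size, 1, dtype=torch.long)  # dummy prompt
    mems = None  # BaseStrategy passes mems through untouched
    for _ in range(5):
        logits = torch.randn(batch_size, vocab_size)  # stand-in for model output
        tokens, mems = strategy.forward(logits, tokens, mems)
        if strategy.is_done:
            break
    tokens, mems = strategy.finalize(tokens, mems)
    print(tokens)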