# strategies.py
import numpy as np
import torch
import torch.nn.functional as F

from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits

class BaseStrategy:
    def __init__(self, batch_size, invalid_slices=None, temperature=1.0, top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
        self.batch_size = batch_size
        self.invalid_slices = invalid_slices if invalid_slices is not None else []
        self.temperature = temperature
        self.topk = top_k
        self.top_p = top_p
        self.eps = eps
        if end_tokens is None:
            end_tokens = []
        self.end_tokens = end_tokens
        self._is_done = np.zeros(self.batch_size, dtype=np.bool_)

    @property
    def is_done(self) -> bool:
        return self._is_done.all()

    def forward(self, logits, tokens, mems, temperature=None):
        batch_size = tokens.shape[0]
        if temperature is None:
            temperature = self.temperature
        logits = logits / temperature
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -65504  # most negative finite fp16 value, i.e. masked out
        logits = top_k_logits(logits, self.topk, self.top_p)
        probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a precision bug in PyTorch
        pred = torch.multinomial(probs, num_samples=1)
        for i in range(self.batch_size):
            if i >= batch_size:
                self._is_done[i] = True
            elif self._is_done[i]:
                pred[i] = -1
            elif pred[i].item() in self.end_tokens:
                self._is_done[i] = True
        tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
        return tokens, mems

    def finalize(self, tokens, mems):
        self._is_done = np.zeros(self.batch_size, dtype=np.bool_)
        return tokens, mems

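# Usage sketch (illustrative): a minimal sampling loop driving BaseStrategy.
# `get_logits`, `prompt_ids`, and `eos_id` are hypothetical stand-ins for the
# caller's model forward pass (returning next-token logits of shape
# [batch, vocab]), tokenized prompt, and end-of-sequence token id.
#
#     strategy = BaseStrategy(batch_size=1, temperature=0.9, top_k=40, end_tokens=[eos_id])
#     tokens, mems = prompt_ids, None
#     while not strategy.is_done:
#         logits = get_logits(tokens, mems)
#         tokens, mems = strategy.forward(logits, tokens, mems)
#     tokens, mems = strategy.finalize(tokens, mems)
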
class BeamSearchStrategy:
    def __init__(
        self,
        batch_size,
        num_beams,
        length_penalty=1.0,
        consider_end=False,
        end_tokens=None,
        invalid_slices=None,
        no_repeat_ngram_size=0,
        min_gen_length=0,
        deterministic=False,
    ):
        self.batch_size = batch_size
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.end_tokens = end_tokens if end_tokens is not None else []
        self.ngram = no_repeat_ngram_size
        self.min_gen_length = min_gen_length
        self.invalid_slices = invalid_slices if invalid_slices is not None else []
        self.consider_end = consider_end
        self.deterministic = deterministic
        self._init_cache()

    def _init_cache(self):
        self.end_beams = [[] for _ in range(self.batch_size)]  # per batch item: list of LongTensors
        self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)]  # per batch item: list of penalized scores
        self.cached_beam_scores = 0  # 0 initially, then a [batch_size, num_beams] tensor
        self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
        self.length_generated = 0
        self._is_done = np.zeros(self.batch_size, dtype=np.bool_)

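    # The length penalty below follows GNMT/OpenNMT: a hypothesis's summed
    # log-probability is divided by lp(Y) = ((5 + |Y|) / 6) ** length_penalty,
    # so longer finished beams are not unfairly dominated by shorter ones.
    # For example, with length_penalty=1.0 and |Y|=7: lp = 12 / 6 = 2.0.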
    def _add_end_beams(self, score, beam, batch_idx):
        score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty  # Magic number for OpenNMT
        for i in range(len(self.end_beams[batch_idx]), -1, -1):
            if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
                break
        self.end_beams[batch_idx].insert(i, beam)
        self.end_beams_penalized_scores[batch_idx].insert(i, score)

        self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
        self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]

    @property
    def is_done(self) -> bool:
        return self._is_done.all()

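    # Shape contract for forward(): on the first call logits is [batch, vocab]
    # and tokens is [batch, seq]; a beam dimension is inserted here. On later
    # calls logits is [batch, num_beams, vocab], tokens is
    # [batch, num_beams, seq], and mems carries its beam dimension at dim 2.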
    def forward(self, logits, tokens, mems):
        if len(logits.shape) == 2:  # first step: add the beam dimension
            logits = logits.unsqueeze(1)
            tokens = tokens.unsqueeze(1)
            mems = mems.unsqueeze(2)
        batch_size, num_beams, vocab_size = logits.shape
        seq_len = tokens.shape[-1]
        logits = logits.float()
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -65504
        if self.min_gen_length > self.length_generated:
            for end_token in self.end_tokens:
                logits[..., end_token] = -65504
        if self.ngram > 0 and seq_len > self.ngram:
            for batch_idx in range(batch_size):
                for i in range(num_beams):
                    ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1):].tolist()  # TODO ngram=1
                    for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
                        logits[batch_idx, i, banned_index] = -65504

        next_token_scores = F.log_softmax(logits, dim=-1)  # [batch_size, num_beams, vocab_size]
        prev_scores = self.cached_beam_scores
        if isinstance(prev_scores, torch.Tensor):
            prev_scores = prev_scores[..., None].expand_as(next_token_scores)
        next_token_scores = next_token_scores + prev_scores

        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        probs = F.softmax(next_token_scores, dim=-1)
        if self.deterministic:
            if num_beams < self.num_beams:  # first token: all beams are identical, keep one copy
                probs = probs[..., :vocab_size]
            next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices  # [batch, k] ids in the flattened beam*vocab space
        else:
            next_tokens = torch.multinomial(
                probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
            )  # [batch, k]
        next_token_scores = next_token_scores[torch.arange(batch_size).unsqueeze(1), next_tokens]
        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
        next_tokens = next_tokens[torch.arange(batch_size).unsqueeze(1), _indices]

        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")  # which beam each candidate extends
        next_tokens = next_tokens % vocab_size  # token id within the vocab
        # separate finished beams from beams that continue
        beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
        for batch_idx in range(batch_size):
            beam_continue = []
            scores_continue = []
            bans_continue = []
            mems_continue = []
            for i in range(len(next_tokens[batch_idx])):
                beam = torch.cat((tokens[batch_idx, next_indices[batch_idx, i]], next_tokens[batch_idx, i : i + 1]))
                if not self._is_done[batch_idx] and int(next_tokens[batch_idx, i]) in self.end_tokens:
                    self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
                elif len(beam_continue) < self.num_beams:
                    beam_continue.append(beam)
                    mems_continue.append(mems[:, batch_idx, next_indices[batch_idx, i]])
                    # update caches
                    scores_continue.append(next_token_scores[batch_idx, i])
                    if self.ngram > 0:
                        bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
                        # TODO ngram=1
                        ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
                        bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
                        bans_continue.append(bans)
                else:
                    break
            beam_continue_batch.append(torch.stack(beam_continue))
            mems_continue_batch.append(torch.stack(mems_continue, dim=1))
            score_continue_batch.append(scores_continue)
            self.cached_beam_ngram_bans[batch_idx] = bans_continue
        tokens = torch.stack(beam_continue_batch)
        mems = torch.stack(mems_continue_batch, dim=1)
        self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
        self.length_generated += 1

        for batch_idx in range(self.batch_size):
            if batch_idx >= batch_size:
                self._is_done[batch_idx] = True
            elif (
                len(self.end_beams[batch_idx]) == self.num_beams
                and self.end_beams_penalized_scores[batch_idx][-1]
                >= self.cached_beam_scores[batch_idx].max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
            ):  # we are done if no current beam can beat the worst kept end beam
                self._is_done[batch_idx] = True

        return tokens, mems
    def finalize(self, tokens, mems):
        if self.consider_end:
            batch_size, num_beams = tokens.shape[:2]
            for batch_idx in range(batch_size):
                if not self._is_done[batch_idx]:
                    for i in range(num_beams):
                        self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
            mems = None
            ret = self.end_beams[:batch_size]
        else:
            ret = tokens
        self._init_cache()
        return ret, mems
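
# Minimal smoke test (illustrative): random logits stand in for a real model's
# output. Shapes follow the contract of BeamSearchStrategy.forward above; a
# real caller would obtain `logits` and updated `mems` from its transformer.
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size, n_layers, hidden = 16, 2, 4
    strategy = BeamSearchStrategy(
        batch_size=1, num_beams=2, consider_end=True,
        end_tokens=[vocab_size - 1], deterministic=True,
    )
    tokens = torch.zeros(1, 3, dtype=torch.long)  # fake prompt, [batch, seq]
    mems = torch.zeros(n_layers, 1, hidden)  # fake cache, [layers, batch, hidden]
    logits = torch.randn(1, vocab_size)  # first step: [batch, vocab]
    for _ in range(8):
        if strategy.is_done:
            break
        tokens, mems = strategy.forward(logits, tokens, mems)
        logits = torch.randn(tokens.shape[0], tokens.shape[1], vocab_size)
    beams, _ = strategy.finalize(tokens, mems)
    print([beam.tolist() for beam in beams[0]])  # finished beams, best first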