# strategies.py — sampling and beam-search generation strategies.
  1. import numpy as np
  2. import torch
  3. import torch.nn.functional as F
  4. from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits
  5. class BaseStrategy:
  6. def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None,
  7. deterministic=False):
  8. self.batch_size = batch_size
  9. self.invalid_slices = invalid_slices
  10. self.temperature = temperature
  11. self.topk = top_k
  12. self.top_p = top_p
  13. self.eps = eps
  14. if end_tokens is None:
  15. end_tokens = []
  16. self.end_tokens = end_tokens
  17. self.deterministic = deterministic
  18. self._is_done = np.zeros(self.batch_size, dtype=np.bool)
  19. @property
  20. def is_done(self) -> bool:
  21. return self._is_done.all()
  22. def forward(self, logits, tokens, mems, temperature=None):
  23. logits = logits.view(-1, logits.size(-1))
  24. batch_size = tokens.shape[0]
  25. if temperature is None:
  26. temperature = self.temperature
  27. logits = logits / temperature
  28. for invalid_slice in self.invalid_slices:
  29. logits[..., invalid_slice] = -65504
  30. logits = top_k_logits(logits, self.topk, self.top_p)
  31. if self.deterministic:
  32. pred = logits.max(dim=-1)[1]
  33. else:
  34. probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
  35. pred = torch.multinomial(probs, num_samples=1)
  36. for i in range(self.batch_size):
  37. if i >= batch_size:
  38. self._is_done[i] = True
  39. elif self._is_done[i]:
  40. pred[i] = -1
  41. elif pred[i].item() in self.end_tokens:
  42. self._is_done[i] = True
  43. tokens = torch.cat((tokens, pred.view(tokens.shape[:-1] + (1,))), dim=-1)
  44. return tokens, mems
  45. def finalize(self, tokens, mems):
  46. self._is_done = np.zeros(self.batch_size, dtype=np.bool)
  47. return tokens, mems
  48. class BeamSearchStrategy:
  49. def __init__(
  50. self,
  51. batch_size,
  52. num_beams,
  53. length_penalty=1.0,
  54. consider_end=False,
  55. end_tokens=[],
  56. invalid_slices=[],
  57. no_repeat_ngram_size=0,
  58. min_gen_length=0,
  59. deterministic=False,
  60. ):
  61. self.batch_size = batch_size
  62. self.num_beams = num_beams
  63. self.length_penalty = length_penalty
  64. self.end_tokens = end_tokens
  65. self.ngram = no_repeat_ngram_size
  66. self.min_gen_length = min_gen_length
  67. self.invalid_slices = invalid_slices
  68. self.consider_end = consider_end
  69. self.deterministic = deterministic
  70. self._init_cache()
  71. def _init_cache(self):
  72. self.end_beams = [[] for _ in range(self.batch_size)] # list of LongTensors
  73. self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)] # list of LongTensors
  74. self.cached_beam_scores = 0 # [batch_size]
  75. self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
  76. self.length_generated = 0
  77. self._is_done = np.zeros(self.batch_size, dtype=np.bool)
  78. def _add_end_beams(self, score, beam, batch_idx):
  79. score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty # Magic number for OpenNMT
  80. for i in range(len(self.end_beams[batch_idx]), -1, -1):
  81. if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
  82. break
  83. self.end_beams[batch_idx].insert(i, beam)
  84. self.end_beams_penalized_scores[batch_idx].insert(i, score)
  85. self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
  86. self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]
  87. @property
  88. def is_done(self) -> bool:
  89. return self._is_done.all()
  90. def forward(self, logits, tokens, mems):
  91. batch_size, num_beams, vocab_size = logits.shape
  92. seq_len = tokens.shape[-1]
  93. logits = logits.float()
  94. for invalid_slice in self.invalid_slices:
  95. logits[..., invalid_slice] = -65504
  96. if self.min_gen_length > self.length_generated:
  97. for end_token in self.end_tokens:
  98. logits[..., end_token] = -65504
  99. if self.ngram > 0 and seq_len > self.ngram:
  100. for batch_idx in range(batch_size):
  101. for i in range(num_beams):
  102. ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1) :].tolist() # TODO ngram=1
  103. for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
  104. logits[batch_idx, i, banned_index] = -65504
  105. next_token_scores = F.log_softmax(logits, dim=-1) # [batch_size, vocab_size]
  106. prev_scores = self.cached_beam_scores
  107. if isinstance(prev_scores, torch.Tensor):
  108. prev_scores = prev_scores[..., None].expand_as(next_token_scores)
  109. next_token_scores = next_token_scores + prev_scores
  110. next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
  111. probs = F.softmax(next_token_scores, dim=-1)
  112. if num_beams < self.num_beams: # First token
  113. probs = probs[..., :vocab_size]
  114. if self.deterministic:
  115. next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices # [2*nb]
  116. else:
  117. next_tokens = torch.multinomial(
  118. probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
  119. ) # [2*nb]
  120. next_token_scores = next_token_scores[torch.arange(batch_size).unsqueeze(1), next_tokens]
  121. next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
  122. next_tokens = next_tokens[torch.arange(batch_size).unsqueeze(1), _indices]
  123. next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
  124. next_tokens = next_tokens % vocab_size
  125. # select out end beams or continue beams
  126. beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
  127. for batch_idx in range(batch_size):
  128. beam_continue = []
  129. scores_continue = []
  130. bans_continue = []
  131. mems_contiue = []
  132. for i in range(len(next_tokens[batch_idx])):
  133. beam = torch.cat((tokens[batch_idx, next_indices[batch_idx, i]], next_tokens[batch_idx, i : i + 1]))
  134. if not self._is_done[batch_idx] and int(next_tokens[batch_idx, i]) in self.end_tokens:
  135. self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
  136. elif len(beam_continue) < self.num_beams:
  137. beam_continue.append(beam)
  138. mems_contiue.append(mems[:, batch_idx, next_indices[batch_idx, i]])
  139. # update caches
  140. scores_continue.append(next_token_scores[batch_idx, i])
  141. if self.ngram > 0:
  142. bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
  143. # TODO ngram=1
  144. ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
  145. bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
  146. bans_continue.append(bans)
  147. else:
  148. break
  149. beam_continue_batch.append(torch.stack(beam_continue))
  150. mems_continue_batch.append(torch.stack(mems_contiue, dim=1))
  151. score_continue_batch.append(scores_continue)
  152. self.cached_beam_ngram_bans[batch_idx] = bans_continue
  153. tokens = torch.stack(beam_continue_batch)
  154. mems = torch.stack(mems_continue_batch, dim=1)
  155. self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
  156. self.length_generated += 1
  157. for batch_idx in range(self.batch_size):
  158. if batch_idx >= batch_size:
  159. self._is_done[batch_idx] = True
  160. elif (
  161. len(self.end_beams[batch_idx]) == self.num_beams
  162. and self.end_beams_penalized_scores[batch_idx][-1]
  163. >= self.cached_beam_scores[batch_idx].max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
  164. ): # We're done if none of current tokens will better than the worst in end_beams
  165. self._is_done[batch_idx] = True
  166. return tokens, mems
  167. def finalize(self, tokens, mems):
  168. if self.consider_end:
  169. batch_size, num_beams = tokens.shape[:2]
  170. for batch_idx in range(batch_size):
  171. if not self._is_done[batch_idx]:
  172. for i in range(num_beams):
  173. self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
  174. mems = None
  175. ret = self.end_beams[:batch_size]
  176. else:
  177. ret = tokens
  178. self._init_cache()
  179. return ret, mems