strategies.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import numpy as np
  2. import torch
  3. import torch.nn.functional as F
  4. from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits
  5. class BaseStrategy:
  6. def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
  7. self.batch_size = batch_size
  8. self.invalid_slices = invalid_slices
  9. self.temperature = temperature
  10. self.topk = top_k
  11. self.top_p = top_p
  12. self.eps = eps
  13. if end_tokens is None:
  14. end_tokens = []
  15. self.end_tokens = end_tokens
  16. self._is_done = np.zeros(self.batch_size, dtype=np.bool)
  17. @property
  18. def is_done(self) -> bool:
  19. return self._is_done.all()
  20. def forward(self, logits, tokens, mems, temperature=None):
  21. logits = logits.view(-1, logits.size(-1))
  22. batch_size = tokens.shape[0]
  23. if temperature is None:
  24. temperature = self.temperature
  25. logits = logits / temperature
  26. for invalid_slice in self.invalid_slices:
  27. logits[..., invalid_slice] = -65504
  28. logits = top_k_logits(logits, self.topk, self.top_p)
  29. probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
  30. pred = torch.multinomial(probs, num_samples=1)
  31. for i in range(self.batch_size):
  32. if i >= batch_size:
  33. self._is_done[i] = True
  34. elif self._is_done[i]:
  35. pred[i] = -1
  36. elif pred[i].item() in self.end_tokens:
  37. self._is_done[i] = True
  38. tokens = torch.cat((tokens, pred.view(tokens.shape[:-1] + (1,))), dim=-1)
  39. return tokens, mems
  40. def finalize(self, tokens, mems):
  41. self._is_done = np.zeros(self.batch_size, dtype=np.bool)
  42. return tokens, mems
  43. class BeamSearchStrategy:
  44. def __init__(
  45. self,
  46. batch_size,
  47. num_beams,
  48. length_penalty=1.0,
  49. consider_end=False,
  50. end_tokens=[],
  51. invalid_slices=[],
  52. no_repeat_ngram_size=0,
  53. min_gen_length=0,
  54. deterministic=False,
  55. ):
  56. self.batch_size = batch_size
  57. self.num_beams = num_beams
  58. self.length_penalty = length_penalty
  59. self.end_tokens = end_tokens
  60. self.ngram = no_repeat_ngram_size
  61. self.min_gen_length = min_gen_length
  62. self.invalid_slices = invalid_slices
  63. self.consider_end = consider_end
  64. self.deterministic = deterministic
  65. self._init_cache()
  66. def _init_cache(self):
  67. self.end_beams = [[] for _ in range(self.batch_size)] # list of LongTensors
  68. self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)] # list of LongTensors
  69. self.cached_beam_scores = 0 # [batch_size]
  70. self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
  71. self.length_generated = 0
  72. self._is_done = np.zeros(self.batch_size, dtype=np.bool)
  73. def _add_end_beams(self, score, beam, batch_idx):
  74. score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty # Magic number for OpenNMT
  75. for i in range(len(self.end_beams[batch_idx]), -1, -1):
  76. if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
  77. break
  78. self.end_beams[batch_idx].insert(i, beam)
  79. self.end_beams_penalized_scores[batch_idx].insert(i, score)
  80. self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
  81. self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]
  82. @property
  83. def is_done(self) -> bool:
  84. return self._is_done.all()
  85. def forward(self, logits, tokens, mems):
  86. batch_size, num_beams, vocab_size = logits.shape
  87. seq_len = tokens.shape[-1]
  88. logits = logits.float()
  89. for invalid_slice in self.invalid_slices:
  90. logits[..., invalid_slice] = -65504
  91. if self.min_gen_length > self.length_generated:
  92. for end_token in self.end_tokens:
  93. logits[..., end_token] = -65504
  94. if self.ngram > 0 and seq_len > self.ngram:
  95. for batch_idx in range(batch_size):
  96. for i in range(num_beams):
  97. ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1) :].tolist() # TODO ngram=1
  98. for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
  99. logits[batch_idx, i, banned_index] = -65504
  100. next_token_scores = F.log_softmax(logits, dim=-1) # [batch_size, vocab_size]
  101. prev_scores = self.cached_beam_scores
  102. if isinstance(prev_scores, torch.Tensor):
  103. prev_scores = prev_scores[..., None].expand_as(next_token_scores)
  104. next_token_scores = next_token_scores + prev_scores
  105. next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
  106. probs = F.softmax(next_token_scores, dim=-1)
  107. if num_beams < self.num_beams: # First token
  108. probs = probs[..., :vocab_size]
  109. if self.deterministic:
  110. next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices # [2*nb]
  111. else:
  112. next_tokens = torch.multinomial(
  113. probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
  114. ) # [2*nb]
  115. next_token_scores = next_token_scores[torch.arange(batch_size).unsqueeze(1), next_tokens]
  116. next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
  117. next_tokens = next_tokens[torch.arange(batch_size).unsqueeze(1), _indices]
  118. next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
  119. next_tokens = next_tokens % vocab_size
  120. # select out end beams or continue beams
  121. beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
  122. for batch_idx in range(batch_size):
  123. beam_continue = []
  124. scores_continue = []
  125. bans_continue = []
  126. mems_contiue = []
  127. for i in range(len(next_tokens[batch_idx])):
  128. beam = torch.cat((tokens[batch_idx, next_indices[batch_idx, i]], next_tokens[batch_idx, i : i + 1]))
  129. if not self._is_done[batch_idx] and int(next_tokens[batch_idx, i]) in self.end_tokens:
  130. self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
  131. elif len(beam_continue) < self.num_beams:
  132. beam_continue.append(beam)
  133. mems_contiue.append(mems[:, batch_idx, next_indices[batch_idx, i]])
  134. # update caches
  135. scores_continue.append(next_token_scores[batch_idx, i])
  136. if self.ngram > 0:
  137. bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
  138. # TODO ngram=1
  139. ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
  140. bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
  141. bans_continue.append(bans)
  142. else:
  143. break
  144. beam_continue_batch.append(torch.stack(beam_continue))
  145. mems_continue_batch.append(torch.stack(mems_contiue, dim=1))
  146. score_continue_batch.append(scores_continue)
  147. self.cached_beam_ngram_bans[batch_idx] = bans_continue
  148. tokens = torch.stack(beam_continue_batch)
  149. mems = torch.stack(mems_continue_batch, dim=1)
  150. self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
  151. self.length_generated += 1
  152. for batch_idx in range(self.batch_size):
  153. if batch_idx >= batch_size:
  154. self._is_done[batch_idx] = True
  155. elif (
  156. len(self.end_beams[batch_idx]) == self.num_beams
  157. and self.end_beams_penalized_scores[batch_idx][-1]
  158. >= self.cached_beam_scores[batch_idx].max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
  159. ): # We're done if none of current tokens will better than the worst in end_beams
  160. self._is_done[batch_idx] = True
  161. return tokens, mems
  162. def finalize(self, tokens, mems):
  163. if self.consider_end:
  164. batch_size, num_beams = tokens.shape[:2]
  165. for batch_idx in range(batch_size):
  166. if not self._is_done[batch_idx]:
  167. for i in range(num_beams):
  168. self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
  169. mems = None
  170. ret = self.end_beams[:batch_size]
  171. else:
  172. ret = tokens
  173. self._init_cache()
  174. return ret, mems