# strategies.py
import torch
import torch.nn.functional as F
  3. class BeamSearchStrategy:
  4. def __init__(
  5. self,
  6. num_beams,
  7. length_penalty=1.0,
  8. consider_end=False,
  9. end_tokens=[],
  10. invalid_slices=[],
  11. no_repeat_ngram_size=0,
  12. min_gen_length=0,
  13. deterministic=False,
  14. ):
  15. self.num_beams = num_beams
  16. self.length_penalty = length_penalty
  17. self.end_tokens = end_tokens
  18. self.ngram = no_repeat_ngram_size
  19. self.min_gen_length = min_gen_length
  20. self.invalid_slices = invalid_slices
  21. self.consider_end = consider_end
  22. self.deterministic = deterministic
  23. self._init_cache()
  24. def _init_cache(self):
  25. self.end_beams = [] # list of LongTensors
  26. self.end_beams_penalized_scores = [] # list of LongTensors
  27. self.cached_beam_scores = 0 # [batch_size]
  28. self.cached_beam_ngram_bans = [{} for i in range(self.num_beams)]
  29. self.length_generated = 0
  30. self.is_done = False
  31. def _add_end_beams(self, score, beam):
  32. score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty # Magic number for OpenNMT
  33. for i in range(len(self.end_beams), -1, -1):
  34. if i == 0 or score < self.end_beams_penalized_scores[i - 1]:
  35. break
  36. self.end_beams.insert(i, beam)
  37. self.end_beams_penalized_scores.insert(i, score)
  38. self.end_beams = self.end_beams[: self.num_beams]
  39. self.end_beams_penalized_scores = self.end_beams_penalized_scores[: self.num_beams]
  40. def forward(self, logits, tokens, mems):
  41. batch_size, vocab_size = logits.shape
  42. seq_len = tokens.shape[-1]
  43. logits = logits.float()
  44. for invalid_slice in self.invalid_slices:
  45. logits[..., invalid_slice] = -65504
  46. if self.min_gen_length > self.length_generated:
  47. for end_token in self.end_tokens:
  48. logits[..., end_token] = -65504
  49. if self.ngram > 0 and seq_len > self.ngram:
  50. for i in range(batch_size):
  51. ngram_prefix = tokens[i, -(self.ngram - 1) :].tolist() # TODO ngram=1
  52. for banned_index in self.cached_beam_ngram_bans[i].get(tuple(ngram_prefix), []):
  53. logits[i, banned_index] = -65504
  54. next_token_scores = F.log_softmax(logits, dim=-1) # [batch_size, vocab_size]
  55. prev_scores = self.cached_beam_scores
  56. if isinstance(self.cached_beam_scores, torch.Tensor):
  57. prev_scores = prev_scores[:, None].expand_as(next_token_scores)
  58. next_token_scores = next_token_scores + prev_scores
  59. next_token_scores = next_token_scores.view(batch_size * vocab_size)
  60. probs = F.softmax(next_token_scores, dim=0)
  61. if self.deterministic:
  62. if mems.shape[1] < batch_size: # First token
  63. probs = probs[:vocab_size]
  64. next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices # [2*nb]
  65. else:
  66. next_tokens = torch.multinomial(
  67. probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
  68. ) # [2*nb]
  69. next_token_scores = next_token_scores[next_tokens]
  70. next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=0)
  71. next_tokens = next_tokens[_indices]
  72. next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
  73. next_tokens = next_tokens % vocab_size
  74. # select out end beams or continue beams
  75. if mems.shape[1] < batch_size:
  76. mems = mems.expand(-1, batch_size, -1, -1)
  77. beam_continue = []
  78. scores_continue = []
  79. bans_continue = []
  80. mems_contiue = []
  81. for i in range(len(next_tokens)):
  82. beam = torch.cat((tokens[next_indices[i]], next_tokens[i : i + 1]))
  83. if int(next_tokens[i]) in self.end_tokens:
  84. self._add_end_beams(next_token_scores[i], beam)
  85. elif len(beam_continue) < self.num_beams:
  86. beam_continue.append(beam)
  87. mems_contiue.append(mems[:, next_indices[i]])
  88. # update caches
  89. scores_continue.append(next_token_scores[i])
  90. if self.ngram > 0:
  91. bans = self.cached_beam_ngram_bans[next_indices[i]].copy()
  92. ngram_prefix = tuple(tokens[next_indices[i], -(self.ngram - 1) :].tolist()) # TODO ngram=1
  93. bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[i],)
  94. bans_continue.append(bans)
  95. else:
  96. break
  97. tokens = torch.stack(beam_continue)
  98. mems = torch.stack(mems_contiue, dim=1)
  99. self.cached_beam_scores = torch.tensor(scores_continue, device=logits.device)
  100. self.cached_beam_ngram_bans = bans_continue
  101. self.length_generated += 1
  102. if (
  103. len(self.end_beams) == self.num_beams
  104. and self.end_beams_penalized_scores[-1]
  105. >= self.cached_beam_scores.max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
  106. ): # We're done if none of current tokens will better than the worst in end_beams
  107. self.is_done = True
  108. return tokens, mems
  109. def finalize(self, tokens, mems):
  110. if self.consider_end:
  111. for i in range(tokens.shape[0]):
  112. self._add_end_beams(self.cached_beam_scores[i], tokens[i])
  113. mems = None
  114. ret = self.end_beams
  115. else:
  116. ret = tokens
  117. self._init_cache()
  118. return ret, mems