@@ -81,7 +81,7 @@ class BeamSearchStrategy:
 
     def _add_end_beams(self, score, beam, batch_idx):
         score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty  # Magic number for OpenNMT
-        for i in range(len(self.end_beams), -1, -1):
+        for i in range(len(self.end_beams[batch_idx]), -1, -1):
             if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
                 break
         self.end_beams[batch_idx].insert(i, beam)
@@ -90,6 +90,10 @@ class BeamSearchStrategy:
         self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
         self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]
 
+    @property
+    def is_done(self) -> bool:
+        return self._is_done.all()
+
     def forward(self, logits, tokens, mems):
         if len(logits.shape) == 2:
             logits = logits.unsqueeze(1)
@@ -135,8 +139,6 @@ class BeamSearchStrategy:
         next_tokens = next_tokens % vocab_size
 
         # select out end beams or continue beams
-        if mems.shape[2] < self.num_beams:
-            mems = mems.expand(-1, batch_size, self.num_beams, -1, -1)
         beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
         for batch_idx in range(batch_size):
             beam_continue = []
@@ -156,7 +158,7 @@ class BeamSearchStrategy:
                     bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
                     # TODO ngram=1
                     ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
-                    bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[i],)
+                    bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
                     bans_continue.append(bans)
                 else:
                     break
@@ -165,7 +167,7 @@ class BeamSearchStrategy:
             score_continue_batch.append(scores_continue)
             self.cached_beam_ngram_bans[batch_idx] = bans_continue
         tokens = torch.stack(beam_continue_batch)
-        mems = torch.stack(mems_continue_batch)
+        mems = torch.stack(mems_continue_batch, dim=1)
         self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
         self.length_generated += 1
         for batch_idx in range(self.batch_size):
@@ -182,12 +184,13 @@ class BeamSearchStrategy:
 
     def finalize(self, tokens, mems):
         if self.consider_end:
-            for batch_idx in range(tokens.shape[0]):
+            batch_size = tokens.shape[0]
+            for batch_idx in range(batch_size):
                 if not self._is_done[batch_idx]:
-                    for i in range(tokens.shape[0]):
+                    for i in range(batch_size):
                         self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[i], batch_idx)
             mems = None
-            ret = self.end_beams
+            ret = self.end_beams[:batch_size]
         else:
             ret = tokens
         self._init_cache()
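
The first two hunks make the finished-beam bookkeeping per-sample: end_beams and end_beams_penalized_scores are now lists indexed by batch_idx, so the sorted insert in _add_end_beams has to measure the length of the right sample's list. Below is a minimal standalone sketch of that bookkeeping, assuming the same OpenNMT-style length penalty; the function and variable names are illustrative, not the library's API.

def add_end_beam(end_beams, end_scores, beam, raw_score, length_penalty, num_beams):
    # OpenNMT-style penalty: divide by ((5 + len) / 6) ** alpha, so longer
    # finished beams are not unfairly dominated by short ones.
    score = raw_score / ((5.0 + len(beam)) / 6) ** length_penalty
    # end_scores is kept in descending order; walk up from the tail to find
    # the first slot whose predecessor outscores the new beam.
    for i in range(len(end_beams), -1, -1):
        if i == 0 or score < end_scores[i - 1]:
            break
    end_beams.insert(i, beam)
    end_scores.insert(i, score)
    # Keep only the best num_beams finished hypotheses.
    del end_beams[num_beams:]
    del end_scores[num_beams:]

The pre-patch bug was exactly the missing [batch_idx]: len(self.end_beams) counted batch entries rather than the finished beams of that sample, so the insert position was computed against the wrong list.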
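The new is_done property reduces the per-sample done flags to a single stopping signal. A runnable sketch of the semantics, assuming _is_done is a per-sample boolean tensor with one flag per batch element:

import torch

_is_done = torch.tensor([True, False, True])  # samples 0 and 2 finished
print(bool(_is_done.all()))  # False: keep decoding while any sample is live
print(bool(_is_done.any()))  # True, but .any() would stop the batch too early

A driver loop can therefore poll strategy.is_done instead of tracking completion itself, and decoding only halts once every sample in the batch has accumulated its end beams.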
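The dim=1 in the restacked mems matters because of the memory layout implied elsewhere in the diff: the removed expand call spells it out as (num_layers, batch, num_beams, ...), so the per-sample slices collected in mems_continue_batch must be stacked back along dim 1 to recreate the batch axis. A small shape check, with sizes chosen arbitrarily for illustration:

import torch

num_layers, num_beams, seq_len, hidden = 2, 4, 7, 16
batch_size = 3
# One entry per batch sample, shaped [num_layers, num_beams, seq_len, hidden].
per_sample = [torch.zeros(num_layers, num_beams, seq_len, hidden) for _ in range(batch_size)]

mems = torch.stack(per_sample, dim=1)
print(mems.shape)  # torch.Size([2, 3, 4, 7, 16]): batch lands on dim 1

The old call, torch.stack(mems_continue_batch) with the default dim=0, would instead put the batch axis where the layer axis belongs and break every mems[:, batch_idx, ...] lookup downstream.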