
Fix BeamSearchStrategy

Zhengxiao Du, 3 years ago
commit 223c40b636
2 changed files with 29 additions and 21 deletions
  1. evaluation/model.py (+18, -13)
  2. generation/strategies.py (+11, -8)

evaluation/model.py (+18, -13)

@@ -55,9 +55,6 @@ def batch_filling_sequence(
             logits = logits[torch.arange(batch_size), context_lengths - 1]
         else:
             logits = logits[:, -1]
-        # if torch.distributed.get_rank() == 0:
-        #     breakpoint()
-        # torch.distributed.barrier()
         counter += 1
         index = counter
         # sampling
@@ -66,11 +63,12 @@ def batch_filling_sequence(
             tokens = tokens.reshape(batch_size, num_beams, -1)
             mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
         tokens, mems = strategy.forward(logits, tokens, mems)
-        if len(tokens.shape) == 3:
+        if len(tokens.shape) == 3 and num_beams == 1:
             num_beams = tokens.shape[1]
             position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, -1).reshape(batch_size * num_beams, -1)
+            attention_mask_shape = attention_mask.shape[-3:]
             attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
-                batch_size * num_beams, -1, -1, -1)
+                batch_size * num_beams, *attention_mask_shape)
         if strategy.is_done:
             break
     return strategy.finalize(tokens, mems)
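
Note on the hunk above: torch.reshape allows at most one inferred (-1) dimension, so the removed call reshape(batch_size * num_beams, -1, -1, -1) could not work; the fix captures the mask's trailing shape explicitly, and the added num_beams == 1 guard avoids expanding the inputs a second time once the beam axis already exists. A shape-only sketch of the expansion on dummy tensors (not the real model inputs):

    import torch

    batch_size, num_beams, seq_len = 2, 4, 8
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
    attention_mask = torch.ones(batch_size, 1, seq_len, seq_len)

    # Repeat every example num_beams times along a new beam axis,
    # then fold the beam axis into the batch axis.
    position_ids = (position_ids.unsqueeze(1)
                    .expand(batch_size, num_beams, -1)
                    .reshape(batch_size * num_beams, -1))
    mask_shape = attention_mask.shape[-3:]
    attention_mask = (attention_mask.unsqueeze(1)
                      .expand(batch_size, num_beams, -1, -1, -1)
                      .reshape(batch_size * num_beams, *mask_shape))

    assert position_ids.shape == (batch_size * num_beams, seq_len)
    assert attention_mask.shape == (batch_size * num_beams, 1, seq_len, seq_len)
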
@@ -159,12 +157,19 @@ class ModelForEvaluation(torch.nn.Module):
 
         output_targets = []
         context_length = seqs.shape[1]
-        for line in output:
-            line = line.tolist()
-            unfinished = line.index(-1) if -1 in line else len(line)
-            if line[unfinished - 1] in strategy.end_tokens:
-                unfinished -= 1
-            line = line[context_length:unfinished]
-            output_targets.append(line)
-
+        for lines in output:
+            output_target = []
+            if not isinstance(lines, list):
+                lines = [lines]
+            for line in lines:
+                line = line.tolist()
+                unfinished = line.index(-1) if -1 in line else len(line)
+                if line[unfinished - 1] in strategy.end_tokens:
+                    unfinished -= 1
+                line = line[context_length:unfinished]
+                output_target.append(line)
+            if not return_all_beams:
+                output_targets.append(output_target[0])
+            else:
+                output_targets.append(output_target)
         return output_targets
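
Note: with beam search the strategy can now return a list of beams per example, so the post-processing loop wraps a single sequence into a one-element list and trims every beam the same way: cut at the -1 padding, drop a trailing end token, strip the shared context prefix, then keep either the first beam or all of them depending on return_all_beams. A hedged sketch of the trimming step on plain Python lists; context_length, the end token id, and the -1 pad value below are stand-ins for the real model/strategy values:

    def trim_line(line, context_length, end_tokens=(50000,), pad=-1):
        # Cut at the first padding token, drop a trailing end token,
        # then strip the shared context prefix.
        unfinished = line.index(pad) if pad in line else len(line)
        if line[unfinished - 1] in end_tokens:
            unfinished -= 1
        return line[context_length:unfinished]

    beams = [[1, 2, 3, 11, 12, 50000, -1, -1],
             [1, 2, 3, 21, 22, 23, 50000, -1]]
    print([trim_line(b, context_length=3) for b in beams])
    # -> [[11, 12], [21, 22, 23]]
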

generation/strategies.py (+11, -8)

@@ -81,7 +81,7 @@ class BeamSearchStrategy:
 
     def _add_end_beams(self, score, beam, batch_idx):
         score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty  # Magic number for OpenNMT
-        for i in range(len(self.end_beams), -1, -1):
+        for i in range(len(self.end_beams[batch_idx]), -1, -1):
             if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
                 break
         self.end_beams[batch_idx].insert(i, beam)
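
Note: the fix makes the insertion scan walk the current example's finished beams (end_beams[batch_idx]) instead of the outer per-batch list, so the insertion index is found in the right list. A standalone sketch of the same bookkeeping for a single example, keeping finished hypotheses sorted by their OpenNMT-style length-penalized score; the names below are stand-ins, not the class attributes:

    def add_end_beam(end_beams, end_scores, raw_score, beam,
                     length_penalty=1.0, num_beams=4):
        # Length-normalize the score (OpenNMT-style penalty).
        score = raw_score / ((5.0 + len(beam)) / 6) ** length_penalty
        # Walk backwards to the slot that keeps scores in descending order.
        for i in range(len(end_beams), -1, -1):
            if i == 0 or score < end_scores[i - 1]:
                break
        end_beams.insert(i, beam)
        end_scores.insert(i, score)
        # Keep only the best num_beams finished hypotheses.
        del end_beams[num_beams:], end_scores[num_beams:]

    beams, scores = [], []
    for raw, hyp in [(-3.0, [1, 2, 3]), (-2.0, [4, 5]), (-4.0, [6, 7, 8, 9])]:
        add_end_beam(beams, scores, raw, hyp)
    print(scores)  # penalized scores, best first
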
@@ -90,6 +90,10 @@ class BeamSearchStrategy:
         self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
         self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]
 
+    @property
+    def is_done(self) -> bool:
+        return self._is_done.all()
+
     def forward(self, logits, tokens, mems):
         if len(logits.shape) == 2:
             logits = logits.unsqueeze(1)
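
Note: the new is_done property reduces the per-example flags to a single boolean; _is_done holds one entry per batch example (set once that example has collected enough finished beams, and checked again in finalize below), so generation stops only when every example is done. A tiny illustration, assuming _is_done is a boolean tensor:

    import torch

    _is_done = torch.zeros(3, dtype=torch.bool)  # one flag per batch example
    _is_done[1] = True                           # example 1 has finished
    print(bool(_is_done.all()))                  # False: 0 and 2 still running
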
@@ -135,8 +139,6 @@ class BeamSearchStrategy:
         next_tokens = next_tokens % vocab_size
 
         # select out end beams or continue beams
-        if mems.shape[2] < self.num_beams:
-            mems = mems.expand(-1, batch_size, self.num_beams, -1, -1)
         beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
         for batch_idx in range(batch_size):
             beam_continue = []
@@ -156,7 +158,7 @@ class BeamSearchStrategy:
                         bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
                         # TODO ngram=1
                         ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
-                        bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[i],)
+                        bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
                         bans_continue.append(bans)
                 else:
                     break
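
Note: the one-line fix above indexes next_tokens by both batch and candidate; indexing with i alone picked the wrong entry once the candidates are laid out per batch example. For reference, a hedged sketch of what one beam's ban dictionary holds for no-repeat-ngram blocking (ngram = 3 here; the names are stand-ins):

    # The dict maps an (n-1)-token prefix to the tokens already generated
    # after that prefix; those tokens later get their logits masked out.
    ngram = 3
    generated = [7, 8, 9, 7, 8]               # tokens so far for one beam
    bans = {}
    prefix = tuple(generated[-(ngram - 1):])  # (7, 8)
    next_token = 9
    bans[prefix] = bans.get(prefix, tuple()) + (next_token,)
    print(bans)                               # {(7, 8): (9,)}
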
@@ -165,7 +167,7 @@ class BeamSearchStrategy:
             score_continue_batch.append(scores_continue)
             self.cached_beam_ngram_bans[batch_idx] = bans_continue
         tokens = torch.stack(beam_continue_batch)
-        mems = torch.stack(mems_continue_batch)
+        mems = torch.stack(mems_continue_batch, dim=1)
         self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
         self.length_generated += 1
         for batch_idx in range(self.batch_size):
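
Note on the dim=1 stack: each entry of mems_continue_batch is assumed to be shaped (num_layers, num_beams, memory_length, hidden), so stacking along dim=1 puts the batch axis right after the layer axis, matching the (num_layers, batch, beams, ...) layout that batch_filling_sequence reshapes above. A shape-only sketch:

    import torch

    num_layers, batch_size, num_beams, mem_len, hidden = 2, 3, 4, 5, 8
    per_example = [torch.zeros(num_layers, num_beams, mem_len, hidden)
                   for _ in range(batch_size)]
    mems = torch.stack(per_example, dim=1)
    print(mems.shape)  # torch.Size([2, 3, 4, 5, 8]): layers, batch, beams, ...
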
@@ -182,12 +184,13 @@ class BeamSearchStrategy:
 
     def finalize(self, tokens, mems):
         if self.consider_end:
-            for batch_idx in range(tokens.shape[0]):
+            batch_size = tokens.shape[0]
+            for batch_idx in range(batch_size):
                 if not self._is_done[batch_idx]:
-                    for i in range(tokens.shape[0]):
+                    for i in range(batch_size):
                         self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[i], batch_idx)
             mems = None
-            ret = self.end_beams
+            ret = self.end_beams[:batch_size]
         else:
             ret = tokens
         self._init_cache()