
Implement batch generation
beam search, use_task_mask and unidirectional not tested

Zhengxiao Du 3 years ago
parent commit bb9fbe4bfc

+ 2 - 1
evaluation/configs.py

@@ -50,4 +50,5 @@ class GenerationTaskConfig(BaseConfig):
     max_gen_length: int = 128

     def __post_init__(self):
-        assert self.micro_batch_size == 1, "Only support micro batch size = 1 for generation task"
+        pass
+        # assert self.micro_batch_size == 1, "Only support micro batch size = 1 for generation task"
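Removing this assertion is what unlocks the batched code paths in the rest of this commit. A minimal sketch of the relaxed behaviour, using a stand-in dataclass (the real GenerationTaskConfig inherits from BaseConfig and has more fields, elided here):

    from dataclasses import dataclass

    @dataclass
    class _ToyGenerationConfig:  # stand-in for GenerationTaskConfig, for illustration only
        micro_batch_size: int = 1
        max_gen_length: int = 128

        def __post_init__(self):
            pass  # previously: assert self.micro_batch_size == 1

    config = _ToyGenerationConfig(micro_batch_size=4)  # accepted now, would have raised before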

+ 37 - 7
evaluation/dataset.py

@@ -76,6 +76,34 @@ class GenerationTaskDataset(EvaluationDataset):
             text = text[len(text) - text_length : len(text)]
         return {"text": text, "targets": targets}

+    @property
+    def has_collate_fn(self) -> bool:
+        return True
+
+    def collate_fn(self, samples):
+        TILE = 32
+        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE
+
+        token_batch, position_id_batch, attention_mask_batch = [], [], []
+        context_length_batch, target_position_id_batch = [], []
+
+        for sample in samples:
+            token, position_id, attention_mask = pad_batch(
+                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
+            )
+            token_batch.append(token)
+            position_id_batch.append(position_id)
+            attention_mask_batch.append(attention_mask)
+            context_length_batch.append(sample['context_length'])
+            target_position_id_batch.append(sample['target_position_id'])
+        return {
+            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
+            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
+            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
+            "context_length": torch.tensor(context_length_batch, dtype=torch.int64),
+            "target_position_ids": torch.tensor(np.array(target_position_id_batch), dtype=torch.int64),
+        }
+
     @staticmethod
     def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
         tokenizer = get_tokenizer()
@@ -98,20 +126,22 @@ class GenerationTaskDataset(EvaluationDataset):
             else:
                 token = np.concatenate((token, [mask_id, sop_id]))
         context_length = len(token)
-        max_seq_length = context_length + max_gen_length
-        position_id = np.arange(0, max_seq_length, dtype=np.int64)
+        position_id = np.arange(0, context_length, dtype=np.int64)
+        target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
         if not use_task_mask:
-            position_id[context_length - 1 :] = mask_position
+            position_id[context_length - 1:] = mask_position
+            target_position_id[:] = mask_position

-        attention_mask = np.tril(np.ones((max_seq_length, max_seq_length), dtype=np.int64))
+        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
         if not unidirectional:
             attention_mask[: context_length - 1, : context_length - 1] = 1

         item = {
-            "tokens": np.concatenate((token, np.zeros(max_seq_length - len(token), dtype=np.int64))),
-            "position_ids": position_id,
-            "attention_mask": attention_mask < 0.5,
+            "token": token,
+            "position_id": position_id,
+            "target_position_id": target_position_id,
+            "attention_mask": attention_mask,
             "context_length": context_length,
         }
         return item
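Note that pad_batch is not part of this diff; a plausible sketch of what collate_fn assumes it does, namely right-padding each per-sample array up to length_to_pad (a hypothetical implementation, not the repository's own):

    import numpy as np

    def pad_batch(token, position_id, attention_mask, length_to_pad):
        pad = length_to_pad - len(token)
        token = np.concatenate((token, np.zeros(pad, dtype=np.int64)))
        position_id = np.concatenate((position_id, np.zeros(pad, dtype=np.int64)))
        # Padded positions neither attend nor are attended to.
        attention_mask = np.pad(attention_mask, ((0, pad), (0, pad)), constant_values=0)
        return token, position_id, attention_mask

collate_fn rounds length_to_pad up to the next multiple of TILE = 32: if the longest sample in a batch has 70 tokens, every sample is padded to (70 + 32 - 1) // 32 * 32 == 96.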

+ 95 - 13
evaluation/model.py

@@ -2,7 +2,78 @@ import torch

 from typing import List, Union

-from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
+from SwissArmyTransformer.generation.autoregressive_sampling import update_mems, get_masks_and_position_ids_default
+
+
+def batch_filling_sequence(
+        model,
+        seqs,
+        context_lengths,
+        strategy,
+        max_memory_length=100000,
+        get_masks_and_position_ids=get_masks_and_position_ids_default,
+        mems=None,
+        **kw_args
+        ):
+    '''
+        seqs: [batch_size, seq_len], each row like [2, 3, 5, ..., -1 (to be generated), -1, ...]
+        mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
+            cache, should be the first mems.shape[2] parts of context_tokens.
+            mems are first-class citizens here, but we don't assume what is memorized.
+            input mems are used for multi-phase generation.
+    '''
+    assert len(seqs.shape) == 2
+
+    # building the initial tokens, attention_mask, and position_ids
+    batch_size, context_length = seqs.shape
+    seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
+    tokens = seqs[..., :context_length]
+    if attention_mask.dtype != torch.bool:
+        attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
+    # initialize generation
+    counter = context_length - 1 # Last fixed index is ``counter''
+    index = 0 if mems is None else mems.shape[2] # Next forward starting index, also the length of cache.
+    num_beams = 1
+    # step-by-step generation
+    while counter < seqs.shape[1] - 1:
+        # Now, we want to generate seq[counter + 1],
+        # token[:, index: counter+1] needs forwarding.
+        # forward
+        if num_beams > 1:
+            tokens = tokens.reshape(batch_size * num_beams, -1)
+            mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1])
+        logits, *output_per_layers = model(
+            tokens[:, index:],
+            position_ids[..., index: counter+1],
+            attention_mask[..., index: counter+1, :counter+1], # TODO memlen
+            mems=mems,
+            **kw_args
+        )
+        mem_kv = [o['mem_kv'] for o in output_per_layers]
+        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
+        if counter == context_length - 1:
+            logits = logits[torch.arange(batch_size), context_lengths - 1]
+        else:
+            logits = logits[:, -1]
+        # if torch.distributed.get_rank() == 0:
+        #     breakpoint()
+        # torch.distributed.barrier()
+        counter += 1
+        index = counter
+        # sampling
+        if num_beams > 1:
+            logits = logits.reshape(batch_size, num_beams, -1)
+            tokens = tokens.reshape(batch_size, num_beams, -1)
+            mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
+        tokens, mems = strategy.forward(logits, tokens, mems)
+        if len(tokens.shape) == 3:
+            num_beams = tokens.shape[1]
+            position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, -1).reshape(batch_size * num_beams, -1)
+            attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
+                batch_size * num_beams, *attention_mask.shape[1:])
+        if strategy.is_done:
+            break
+    return strategy.finalize(tokens, mems)


 class ModelForEvaluation(torch.nn.Module):
@@ -47,28 +118,39 @@ class ModelForEvaluation(torch.nn.Module):
                 log_probs.append(log_probs_single)
         return log_probs

-    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[int], List[List[int]]]:
+    def generate_text(self, sample, strategy, return_all_beams=False, max_gen_length=128) -> Union[
+        List[int], List[List[int]]]:
         """
         @return: A list of texts generated by the model, sorted by score in descending order
         """

-        seq = torch.squeeze(sample["tokens"].to(device=torch.cuda.current_device()).long())
-        context_length = sample["context_length"].to(device=torch.cuda.current_device()).long()
-        seq[context_length:] = -1
+        seqs = sample["tokens"].to(device=torch.cuda.current_device()).long()
+        context_lengths = sample["context_length"].long()

         def get_masks_and_position_ids(seq):
-            tokens = seq.unsqueeze(0)
-            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device()).bool().unsqueeze(1)
-            position_ids = sample["position_ids"].to(device=torch.cuda.current_device()).long()
+            batch_size = seq.shape[0]
+            tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
+            position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1)
+            position_ids = position_ids.to(device=torch.cuda.current_device()).long()
+            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device())
+            context_mask = attention_mask[torch.arange(batch_size), context_lengths - 1].unsqueeze(1).repeat(1,
+                                                                                                             max_gen_length,
+                                                                                                             1)
+            causal_mask = torch.tril(context_mask.new_ones((batch_size, max_gen_length, max_gen_length))) < 0.5
+            generation_mask = torch.cat(
+                (context_mask, causal_mask), dim=-1)
+            attention_mask = torch.nn.functional.pad(attention_mask, (0, max_gen_length), mode='constant', value=1)
+            attention_mask = torch.cat((attention_mask, generation_mask), dim=1)
+            attention_mask = attention_mask.bool().unsqueeze(1)
             return tokens, attention_mask, position_ids

         self.model.eval()
         with torch.no_grad():
-            output = filling_sequence(
+            output = batch_filling_sequence(
                 self.model,
-                seq,
+                seqs,
+                context_lengths,
                 get_masks_and_position_ids=get_masks_and_position_ids,
-                batch_size=strategy.num_beams if hasattr(strategy, "num_beams") else 1,
                 strategy=strategy,
             )[0]

@@ -76,7 +158,7 @@ class ModelForEvaluation(torch.nn.Module):
             output = list(output)

         output_targets = []
-
+        context_length = seqs.shape[1]
         for line in output:
             line = line.tolist()
             unfinished = line.index(-1) if -1 in line else len(line)
@@ -85,4 +167,4 @@ class ModelForEvaluation(torch.nn.Module):
             line = line[context_length:unfinished]
             output_targets.append(line)

-        return output_targets if return_all_beams else output_targets[0]
+        return output_targets
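To make the batched decoding control flow concrete, here is a self-contained greedy sketch of the same loop as batch_filling_sequence, without mems or beams (logits_fn, end_token and the function name are illustrative, not part of the diff):

    import torch

    def toy_batch_decode(logits_fn, tokens, max_new_tokens, end_token):
        # tokens: [batch, context_length]; finished rows keep emitting -1,
        # mirroring how BaseStrategy pads sequences that are already done.
        done = torch.zeros(tokens.shape[0], dtype=torch.bool)
        for _ in range(max_new_tokens):
            logits = logits_fn(tokens)    # [batch, vocab_size]
            pred = logits.argmax(dim=-1)  # greedy pick, i.e. top_k=1
            pred[done] = -1
            done |= pred.eq(end_token)
            tokens = torch.cat((tokens, pred.unsqueeze(1)), dim=1)
            if done.all():                # analogue of strategy.is_done
                break
        return tokens

This is also why generate_text above slices each output row at the first -1: trailing -1 entries mark positions produced after that sequence finished.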

+ 7 - 5
evaluation/tasks.py

@@ -9,10 +9,9 @@ from glob import glob
 from os.path import join, relpath
 from collections import defaultdict

-from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
 from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceTokenizer

-from generation import BeamSearchStrategy
+from generation import BaseStrategy, BeamSearchStrategy
 from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig
 from .model import ModelForEvaluation
 from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset
@@ -171,9 +170,11 @@ class GenerationTask(BaseTask, ABC):

         end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
         if self.config.sampling_strategy == "BaseStrategy":
-            self.strategy = BaseStrategy(temperature=1.0, top_k=1, end_tokens=end_tokens)
+            self.strategy = BaseStrategy(batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1,
+                                         end_tokens=end_tokens)
         elif self.config.sampling_strategy == "BeamSearchStrategy":
             self.strategy = BeamSearchStrategy(
+                self.config.micro_batch_size,
                 self.config.num_beams,
                 length_penalty=self.config.length_penalty,
                 consider_end=True,
@@ -188,8 +189,9 @@ class GenerationTask(BaseTask, ABC):
     def predict_single_batch(self, batch) -> List[List[int]]:
         # generation tasks may now run with micro batch size > 1,
         # so we return one prediction per sample in the batch
-        output = self.model.generate_text(batch, self.strategy, return_all_beams=False)
-        return [output]
+        output = self.model.generate_text(batch, self.strategy, return_all_beams=False,
+                                          max_gen_length=self.config.max_gen_length)
+        return output
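Side by side, the two construction paths now look like this (a sketch; config, tokenizer and end_tokens are assumed to be in scope as in the surrounding methods):

    from generation import BaseStrategy, BeamSearchStrategy

    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
    if config.sampling_strategy == "BaseStrategy":
        strategy = BaseStrategy(batch_size=config.micro_batch_size,
                                temperature=1.0, top_k=1, end_tokens=end_tokens)
    elif config.sampling_strategy == "BeamSearchStrategy":
        strategy = BeamSearchStrategy(config.micro_batch_size, config.num_beams,
                                      length_penalty=config.length_penalty,
                                      consider_end=True, end_tokens=end_tokens)

Passing batch_size into both strategies is what lets each of them track per-sample completion (_is_done) instead of a single global done flag.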


 class MultiChoiceTask(BaseTask, ABC):

+ 1 - 1
generation/__init__.py

@@ -1 +1 @@
-from .strategies import BeamSearchStrategy
+from .strategies import BaseStrategy, BeamSearchStrategy

+ 121 - 59
generation/strategies.py

@@ -1,10 +1,56 @@
+import numpy as np
 import torch
 import torch.nn.functional as F
+from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits
+
+class BaseStrategy:
+    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
+        self.batch_size = batch_size
+        self.invalid_slices = invalid_slices
+        self.temperature = temperature
+        self.topk = top_k
+        self.top_p = top_p
+        self.eps = eps
+        if end_tokens is None:
+            end_tokens = []
+        self.end_tokens = end_tokens
+        self._is_done = np.zeros(self.batch_size, dtype=bool)
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done.all()
+
+    def forward(self, logits, tokens, mems, temperature=None):
+        batch_size = tokens.shape[0]
+        if temperature is None:
+            temperature = self.temperature
+        logits = logits / temperature
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -65504
+
+        # logits = top_k_logits(logits, self.topk, self.top_p)
+        # probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in PyTorch
+        # pred = torch.multinomial(probs, num_samples=1)
+        pred = torch.argmax(logits, dim=-1)
+        for i in range(self.batch_size):
+            if i >= batch_size:
+                self._is_done[i] = True
+            elif self._is_done[i]:
+                pred[i] = -1
+            elif pred[i].item() in self.end_tokens:
+                self._is_done[i] = True
+        tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
+        return tokens, mems
+
+    def finalize(self, tokens, mems):
+        self._is_done = np.zeros(self.batch_size, dtype=bool)
+        return tokens, mems


 class BeamSearchStrategy:
     def __init__(
         self,
+        batch_size,
         num_beams,
         length_penalty=1.0,
         consider_end=False,
@@ -14,6 +60,7 @@ class BeamSearchStrategy:
         min_gen_length=0,
         deterministic=False,
     ):
+        self.batch_size = batch_size
         self.num_beams = num_beams
         self.length_penalty = length_penalty
         self.end_tokens = end_tokens
@@ -25,26 +72,30 @@ class BeamSearchStrategy:
         self._init_cache()

     def _init_cache(self):
-        self.end_beams = []  # list of LongTensors
-        self.end_beams_penalized_scores = []  # list of LongTensors
+        self.end_beams = [[] for _ in range(self.batch_size)]  # per-sample lists of LongTensors
+        self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)]  # per-sample lists of scores
         self.cached_beam_scores = 0  # [batch_size]
-        self.cached_beam_ngram_bans = [{} for i in range(self.num_beams)]
+        self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
         self.length_generated = 0
-        self.is_done = False
+        self._is_done = np.zeros(self.batch_size, dtype=bool)
-    def _add_end_beams(self, score, beam):
+    def _add_end_beams(self, score, beam, batch_idx):
         score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty  # Magic number for OpenNMT
         for i in range(len(self.end_beams[batch_idx]), -1, -1):
-            if i == 0 or score < self.end_beams_penalized_scores[i - 1]:
+            if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
                 break
-        self.end_beams.insert(i, beam)
-        self.end_beams_penalized_scores.insert(i, score)
+        self.end_beams[batch_idx].insert(i, beam)
+        self.end_beams_penalized_scores[batch_idx].insert(i, score)
-        self.end_beams = self.end_beams[: self.num_beams]
-        self.end_beams_penalized_scores = self.end_beams_penalized_scores[: self.num_beams]
+        self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
+        self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]
     def forward(self, logits, tokens, mems):
-        batch_size, vocab_size = logits.shape
+        if len(logits.shape) == 2:
+            logits = logits.unsqueeze(1)
+            tokens = tokens.unsqueeze(1)
+            mems = mems.unsqueeze(2)
+        batch_size, num_beams, vocab_size = logits.shape
         seq_len = tokens.shape[-1]
         logits = logits.float()
         for invalid_slice in self.invalid_slices:
@@ -53,77 +104,88 @@ class BeamSearchStrategy:
             for end_token in self.end_tokens:
                 logits[..., end_token] = -65504
         if self.ngram > 0 and seq_len > self.ngram:
-            for i in range(batch_size):
-                ngram_prefix = tokens[i, -(self.ngram - 1) :].tolist()  # TODO ngram=1
-                for banned_index in self.cached_beam_ngram_bans[i].get(tuple(ngram_prefix), []):
-                    logits[i, banned_index] = -65504
+            for batch_idx in range(batch_size):
+                for i in range(num_beams):
+                    ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1) :].tolist()  # TODO ngram=1
+                    for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
+                        logits[batch_idx, i, banned_index] = -65504
         next_token_scores = F.log_softmax(logits, dim=-1)  # [batch_size, num_beams, vocab_size]
         prev_scores = self.cached_beam_scores
-        if isinstance(self.cached_beam_scores, torch.Tensor):
-            prev_scores = prev_scores[:, None].expand_as(next_token_scores)
+        if isinstance(prev_scores, torch.Tensor):
+            prev_scores = prev_scores[..., None].expand_as(next_token_scores)
         next_token_scores = next_token_scores + prev_scores
-        next_token_scores = next_token_scores.view(batch_size * vocab_size)
+        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
-        probs = F.softmax(next_token_scores, dim=0)
+        probs = F.softmax(next_token_scores, dim=-1)
         if self.deterministic:
-            if mems.shape[1] < batch_size:  # First token
-                probs = probs[:vocab_size]
+            if mems.shape[2] < self.num_beams:  # First token
+                probs = probs[..., :vocab_size]
             next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices  # [2*nb]
         else:
             next_tokens = torch.multinomial(
                 probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
             )  # [2*nb]
-        next_token_scores = next_token_scores[next_tokens]
-        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=0)
-        next_tokens = next_tokens[_indices]
+        next_token_scores = next_token_scores[torch.arange(batch_size).unsqueeze(1), next_tokens]
+        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+        next_tokens = next_tokens[torch.arange(batch_size).unsqueeze(1), _indices]
         next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
         next_tokens = next_tokens % vocab_size
         # select out end beams or continue beams
-        if mems.shape[1] < batch_size:
-            mems = mems.expand(-1, batch_size, -1, -1)
-        beam_continue = []
-        scores_continue = []
-        bans_continue = []
-        mems_contiue = []
-        for i in range(len(next_tokens)):
-            beam = torch.cat((tokens[next_indices[i]], next_tokens[i : i + 1]))
-            if int(next_tokens[i]) in self.end_tokens:
-                self._add_end_beams(next_token_scores[i], beam)
-            elif len(beam_continue) < self.num_beams:
-                beam_continue.append(beam)
-                mems_contiue.append(mems[:, next_indices[i]])
-                # update caches
-                scores_continue.append(next_token_scores[i])
-                if self.ngram > 0:
-                    bans = self.cached_beam_ngram_bans[next_indices[i]].copy()
-                    ngram_prefix = tuple(tokens[next_indices[i], -(self.ngram - 1) :].tolist())  # TODO ngram=1
-                    bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[i],)
-                    bans_continue.append(bans)
-            else:
-                break
-        tokens = torch.stack(beam_continue)
-        mems = torch.stack(mems_contiue, dim=1)
-        self.cached_beam_scores = torch.tensor(scores_continue, device=logits.device)
-        self.cached_beam_ngram_bans = bans_continue
+        if mems.shape[2] < self.num_beams:
+            mems = mems.expand(-1, batch_size, self.num_beams, -1, -1)
+        beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
+        for batch_idx in range(batch_size):
+            beam_continue = []
+            scores_continue = []
+            bans_continue = []
+            mems_continue = []
+            for i in range(len(next_tokens[batch_idx])):
+                beam = torch.cat((tokens[batch_idx, next_indices[batch_idx, i]], next_tokens[batch_idx, i : i + 1]))
+                if not self._is_done[batch_idx] and int(next_tokens[batch_idx, i]) in self.end_tokens:
+                    self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
+                elif len(beam_continue) < self.num_beams:
+                    beam_continue.append(beam)
+                    mems_continue.append(mems[:, batch_idx, next_indices[batch_idx, i]])
+                    # update caches
+                    scores_continue.append(next_token_scores[batch_idx, i])
+                    if self.ngram > 0:
+                        bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
+                        # TODO ngram=1
+                        ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
+                        bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
+                        bans_continue.append(bans)
+                else:
+                    break
+            beam_continue_batch.append(torch.stack(beam_continue))
+            mems_continue_batch.append(torch.stack(mems_continue, dim=1))
+            score_continue_batch.append(scores_continue)
+            self.cached_beam_ngram_bans[batch_idx] = bans_continue
+        tokens = torch.stack(beam_continue_batch)
+        mems = torch.stack(mems_continue_batch)
+        self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
         self.length_generated += 1
-
-        if (
-            len(self.end_beams) == self.num_beams
-            and self.end_beams_penalized_scores[-1]
-            >= self.cached_beam_scores.max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
-        ):  # We're done if none of current tokens will better than the worst in end_beams
-            self.is_done = True
+        for batch_idx in range(self.batch_size):
+            if batch_idx >= batch_size:
+                self._is_done[batch_idx] = True
+            elif (
+                len(self.end_beams[batch_idx]) == self.num_beams
+                and self.end_beams_penalized_scores[batch_idx][-1]
+                >= self.cached_beam_scores[batch_idx].max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
+            ):  # We're done if none of the current tokens can beat the worst beam in end_beams
+                self._is_done[batch_idx] = True
         return tokens, mems
     def finalize(self, tokens, mems):
         if self.consider_end:
-            for i in range(tokens.shape[0]):
-                self._add_end_beams(self.cached_beam_scores[i], tokens[i])
+            for batch_idx in range(tokens.shape[0]):
+                if not self._is_done[batch_idx]:
+                    for i in range(tokens.shape[1]):
+                        self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
             mems = None
             ret = self.end_beams
         else: