
Merge branch 'batch-generation' of github.com:duzx16/GLM-130B

Sengxian committed 3 years ago · commit 4f5910ccd2
7 changed files with 294 additions and 112 deletions
  1. + 0 - 3      evaluation/configs.py
  2. + 37 - 7     evaluation/dataset.py
  3. + 108 - 20   evaluation/model.py
  4. + 7 - 5      evaluation/tasks.py
  5. + 16 - 15    generate.py
  6. + 1 - 1      generation/__init__.py
  7. + 125 - 61   generation/strategies.py

+ 0 - 3
evaluation/configs.py

@@ -50,9 +50,6 @@ class GenerationTaskConfig(BaseConfig):
     min_gen_length: int = 0
     max_gen_length: int = 128
 
-    def __post_init__(self):
-        assert self.micro_batch_size == 1, "Only support micro batch size = 1 for generation task"
-
 
 @dataclass
 class LanguageModelTaskConfig(BaseConfig):

+ 37 - 7
evaluation/dataset.py

@@ -84,6 +84,34 @@ class GenerationTaskDataset(EvaluationDataset):
             text = text[len(text) - text_length : len(text)]
         return {"text": text, "targets": targets}
 
+    @property
+    def has_collate_fn(self) -> bool:
+        return True
+
+    def collate_fn(self, samples):
+        TILE = 32
+        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE
+
+        token_batch, position_id_batch, attention_mask_batch = [], [], []
+        context_length_batch, target_position_id_batch = [], []
+
+        for sample in samples:
+            token, position_id, attention_mask = pad_batch(
+                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
+            )
+            token_batch.append(token)
+            position_id_batch.append(position_id)
+            attention_mask_batch.append(attention_mask)
+            context_length_batch.append(sample['context_length'])
+            target_position_id_batch.append(sample['target_position_id'])
+        return {
+            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
+            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
+            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
+            "context_length": torch.tensor(context_length_batch, dtype=torch.int64),
+            "target_position_ids": torch.tensor(np.array(target_position_id_batch), dtype=torch.int64),
+        }
+
     @staticmethod
     def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
         tokenizer = get_tokenizer()
@@ -106,20 +134,22 @@ class GenerationTaskDataset(EvaluationDataset):
             else:
                 token = np.concatenate((token, [mask_id, sop_id]))
         context_length = len(token)
-        max_seq_length = context_length + max_gen_length
 
-        position_id = np.arange(0, max_seq_length, dtype=np.int64)
+        position_id = np.arange(0, context_length, dtype=np.int64)
+        target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
         if not use_task_mask:
-            position_id[context_length - 1 :] = mask_position
+            position_id[context_length - 1:] = mask_position
+            target_position_id[:] = mask_position
 
-        attention_mask = np.tril(np.ones((max_seq_length, max_seq_length), dtype=np.int64))
+        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
         if not unidirectional:
             attention_mask[: context_length - 1, : context_length - 1] = 1
 
         item = {
-            "tokens": np.concatenate((token, np.zeros(max_seq_length - len(token), dtype=np.int64))),
-            "position_ids": position_id,
-            "attention_mask": attention_mask < 0.5,
+            "token": token,
+            "position_id": position_id,
+            "target_position_id": target_position_id,
+            "attention_mask": attention_mask,
             "context_length": context_length,
         }
         return item
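
The new collate_fn right-pads every sample in a micro batch to a common length (rounded up to a multiple of TILE = 32) before stacking, delegating the per-sample padding to pad_batch. That helper is not shown in this diff; the following is a minimal sketch of what it has to do, assuming zero-padded tokens and position ids and a zero-padded square 0/1 attention mask (names and padding values here are assumptions, not the repository's exact implementation):

    import numpy as np

    def pad_batch(token, position_id, attention_mask, length_to_pad):
        # Hypothetical helper: right-pad one sample to length_to_pad.
        # Padded token/position slots are zeros; the square attention mask
        # is padded with zeros so padded positions can never be attended to
        # (collate_fn later inverts it into a boolean mask via `< 0.5`).
        pad_len = length_to_pad - len(token)
        token = np.concatenate((token, np.zeros(pad_len, dtype=np.int64)))
        position_id = np.concatenate((position_id, np.zeros(pad_len, dtype=np.int64)))
        attention_mask = np.pad(attention_mask, ((0, pad_len), (0, pad_len)), constant_values=0)
        return token, position_id, attention_mask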

+ 108 - 20
evaluation/model.py

@@ -2,10 +2,79 @@ import torch
 
 from typing import List, Union
 
-from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
+from SwissArmyTransformer.generation.autoregressive_sampling import update_mems, get_masks_and_position_ids_default
 from SwissArmyTransformer.mpu import vocab_parallel_cross_entropy
 
 
+def batch_filling_sequence(
+        model,
+        seqs,
+        context_lengths,
+        strategy,
+        max_memory_length=100000,
+        get_masks_and_position_ids=get_masks_and_position_ids_default,
+        mems=None,
+        **kw_args
+        ):
+    '''
+        seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
+        mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
+            cache, should be first mems.shape[1] parts of context_tokens.
+            mems are the first-class citizens here, but we don't assume what is memorized.
+            input mems are used for multi-phase generation.
+    '''
+    assert len(seqs.shape) == 2
+
+    # building the initial tokens, attention_mask, and position_ids
+    batch_size, context_length = seqs.shape
+    seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
+    tokens = seqs[..., :context_length]
+    if attention_mask.dtype != torch.bool:
+        attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
+    # initialize generation
+    counter = context_length - 1 # Last fixed index is ``counter''
+    index = 0 if mems is None else mems.shape[2] # Next forward starting index, also the length of cache.
+    num_beams = 1
+    # step-by-step generation
+    while counter < seqs.shape[1] - 1:
+        # Now, we want to generate seq[counter + 1],
+        # token[:, index: counter+1] needs forwarding.
+        # forward
+        if num_beams > 1:
+            tokens = tokens.reshape(batch_size * num_beams, -1)
+            mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1])
+        logits, *output_per_layers = model(
+            tokens[:, index:],
+            position_ids[..., index: counter+1],
+            attention_mask[..., index: counter+1, :counter+1], # TODO memlen
+            mems=mems,
+            **kw_args
+        )
+        mem_kv = [o['mem_kv'] for o in output_per_layers]
+        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
+        if counter == context_length - 1:
+            logits = logits[torch.arange(batch_size), context_lengths - 1]
+        else:
+            logits = logits[:, -1]
+        counter += 1
+        index = counter
+        # sampling
+        if num_beams > 1:
+            logits = logits.reshape(batch_size, num_beams, -1)
+            tokens = tokens.reshape(batch_size, num_beams, -1)
+            mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
+        tokens, mems = strategy.forward(logits, tokens, mems)
+        if len(tokens.shape) == 3 and num_beams == 1:
+            num_beams = tokens.shape[1]
+            position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, -1).reshape(batch_size * num_beams, -1)
+            attention_mask_shape = attention_mask.shape[-3:]
+            attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
+                batch_size * num_beams, *attention_mask_shape)
+        if strategy.is_done:
+            break
+    return strategy.finalize(tokens, mems)
+
+
 class ModelForEvaluation(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
@@ -48,28 +117,39 @@ class ModelForEvaluation(torch.nn.Module):
                 log_probs.append(log_probs_single)
         return log_probs
 
-    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[int], List[List[int]]]:
+    def generate_text(self, sample, strategy, return_all_beams=False, max_gen_length=128) -> Union[
+        List[int], List[List[int]]]:
         """
         @return: A list of text model generated, sorted by score in descending order
         """
 
-        seq = torch.squeeze(sample["tokens"].to(device=torch.cuda.current_device()).long())
-        context_length = sample["context_length"].to(device=torch.cuda.current_device()).long()
-        seq[context_length:] = -1
+        seqs = sample["tokens"].to(device=torch.cuda.current_device()).long()
+        context_lengths = sample["context_length"].long()
 
         def get_masks_and_position_ids(seq):
-            tokens = seq.unsqueeze(0)
-            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device()).bool().unsqueeze(1)
-            position_ids = sample["position_ids"].to(device=torch.cuda.current_device()).long()
+            batch_size = seq.shape[0]
+            tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
+            position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1)
+            position_ids = position_ids.to(device=torch.cuda.current_device()).long()
+            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device())
+            context_mask = attention_mask[torch.arange(batch_size), context_lengths - 1].unsqueeze(1).repeat(1,
+                                                                                                             max_gen_length,
+                                                                                                             1)
+            causal_mask = torch.tril(context_mask.new_ones((batch_size, max_gen_length, max_gen_length))) < 0.5
+            generation_mask = torch.cat(
+                (context_mask, causal_mask), dim=-1)
+            attention_mask = torch.nn.functional.pad(attention_mask, (0, max_gen_length), mode='constant', value=1)
+            attention_mask = torch.cat((attention_mask, generation_mask), dim=1)
+            attention_mask = attention_mask.bool().unsqueeze(1)
             return tokens, attention_mask, position_ids
 
         self.model.eval()
         with torch.no_grad():
-            output = filling_sequence(
+            output = batch_filling_sequence(
                 self.model,
-                seq,
+                seqs,
+                context_lengths,
                 get_masks_and_position_ids=get_masks_and_position_ids,
-                batch_size=strategy.num_beams if hasattr(strategy, "num_beams") else 1,
                 strategy=strategy,
             )[0]
 
@@ -77,16 +157,24 @@ class ModelForEvaluation(torch.nn.Module):
             output = list(output)
 
         output_targets = []
+        context_length = seqs.shape[1]
+        for lines in output:
+            output_target = []
+            if not isinstance(lines, list):
+                lines = [lines]
+            for line in lines:
+                line = line.tolist()
+                unfinished = line.index(-1) if -1 in line else len(line)
+                if line[unfinished - 1] in strategy.end_tokens:
+                    unfinished -= 1
+                line = line[context_length:unfinished]
+                output_target.append(line)
+            if not return_all_beams:
+                output_targets.append(output_target[0])
+            else:
+                output_targets.append(output_target)
+        return output_targets
 
-        for line in output:
-            line = line.tolist()
-            unfinished = line.index(-1) if -1 in line else len(line)
-            if line[unfinished - 1] in strategy.end_tokens:
-                unfinished -= 1
-            line = line[context_length:unfinished]
-            output_targets.append(line)
-
-        return output_targets if return_all_beams else output_targets[0]
 
     def calculate_loss(self, batch) -> List[float]:
         tokens, position_ids, attention_mask = self.process_data(batch)
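
The rewritten get_masks_and_position_ids assembles the batched mask blockwise: context rows keep the padded per-sample mask (with every generation column blocked), while each of the max_gen_length generation rows repeats whatever the last real context token could attend to, plus a causal triangle over the generated positions. The code above works directly on the already-inverted boolean mask from collate_fn (True = masked out); the toy sketch below builds the equivalent 0/1 "may attend" layout first and inverts at the end, with made-up sizes:

    import torch

    batch_size, context_len, max_gen_length = 2, 4, 3
    context_lengths = torch.tensor([4, 2])  # true prompt lengths, right-padded to 4

    # 0/1 context mask (1 = may attend), as in build_generation_sample
    ctx = torch.tril(torch.ones(batch_size, context_len, context_len))

    # generation rows: what the last real context token saw ...
    gen_ctx = ctx[torch.arange(batch_size), context_lengths - 1]
    gen_ctx = gen_ctx.unsqueeze(1).repeat(1, max_gen_length, 1)
    # ... plus previously generated tokens (causal triangle)
    gen_causal = torch.tril(torch.ones(batch_size, max_gen_length, max_gen_length))
    generation_rows = torch.cat((gen_ctx, gen_causal), dim=-1)

    # context rows: block all generation columns by padding with 0
    context_rows = torch.nn.functional.pad(ctx, (0, max_gen_length), value=0)
    full = torch.cat((context_rows, generation_rows), dim=1)  # [batch, ctx+gen, ctx+gen]
    blocked = full < 0.5  # True = masked out, matching the convention above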

+ 7 - 5
evaluation/tasks.py

@@ -9,10 +9,9 @@ from glob import glob
 from os.path import join, relpath
 from collections import defaultdict
 
-from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
 from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceTokenizer
 
-from generation import BeamSearchStrategy
+from generation import BaseStrategy, BeamSearchStrategy
 from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig, LanguageModelTaskConfig
 from .model import ModelForEvaluation
 from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset
@@ -171,9 +170,11 @@ class GenerationTask(BaseTask, ABC):
 
         end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
         if self.config.sampling_strategy == "BaseStrategy":
-            self.strategy = BaseStrategy(temperature=1.0, top_k=1, end_tokens=end_tokens)
+            self.strategy = BaseStrategy(batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1,
+                                         end_tokens=end_tokens)
         elif self.config.sampling_strategy == "BeamSearchStrategy":
             self.strategy = BeamSearchStrategy(
+                self.config.micro_batch_size,
                 self.config.num_beams,
                 length_penalty=self.config.length_penalty,
                 consider_end=True,
@@ -188,8 +189,9 @@ class GenerationTask(BaseTask, ABC):
     def predict_single_batch(self, batch) -> List[List[int]]:
         # micro batch size = 1 for generation task,
         # but we still need to return a list of predictions for consistency
-        output = self.model.generate_text(batch, self.strategy, return_all_beams=False)
-        return [output]
+        output = self.model.generate_text(batch, self.strategy, return_all_beams=False,
+                                          max_gen_length=self.config.max_gen_length)
+        return output
 
 
 class MultiChoiceTask(BaseTask, ABC):

+ 16 - 15
generate.py

@@ -7,9 +7,8 @@ from functools import partial
 from typing import List, Tuple
 
 from SwissArmyTransformer import mpu
-from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
-from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
-from generation import BeamSearchStrategy
+from evaluation.model import batch_filling_sequence
+from generation import BeamSearchStrategy, BaseStrategy
 from SwissArmyTransformer.generation.utils import timed_name, generate_continually
 from initialize import initialize, initialize_model_and_tokenizer
 
@@ -31,16 +30,16 @@ def isEnglish(s):
         return True
 
 
-def get_masks_and_position_ids(seq, mask_position, context_length, gmask=False):
-    tokens = seq.unsqueeze(0)
-
-    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
+    context_length = seq.shape[1]
+    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
+    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
     attention_mask.tril_()
     attention_mask[..., : context_length - 1] = 1
     attention_mask.unsqueeze_(1)
     attention_mask = (attention_mask < 0.5).bool()
 
-    position_ids = torch.arange(len(seq), dtype=torch.long, device=tokens.device)
+    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
     if not gmask:
         position_ids[context_length - 1 :] = mask_position
 
@@ -99,25 +98,25 @@ def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], L
         output_list = []
 
         input_seq = torch.cuda.LongTensor(
-            seq + [tokenizer.get_command("sop")] + [-1] * (args.out_seq_length - len(seq) - 1),
+            [seq + [tokenizer.get_command("sop")]],
             device=args.device,
         )
-        output, _ = filling_sequence(
+        output, _ = batch_filling_sequence(
             model,
             input_seq,
-            batch_size=num_output,
+            torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
             strategy=strategy,
-            log_attention_weights=None,
             get_masks_and_position_ids=partial(
                 get_masks_and_position_ids,
                 mask_position=mask_position,
-                context_length=len(seq) + 1,
+                max_gen_length=args.out_seq_length - input_seq.shape[-1],
                 gmask=use_gmask,
             ),
         )
         if isinstance(output, torch.Tensor):  # different strategies
             output = list(output)
-
+        else:
+            output = output[0]
         output_list.extend(output)
 
         # clip -1s and fill back generated things into seq
@@ -160,9 +159,11 @@ def main(args):
     end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
 
     if args.sampling_strategy == "BaseStrategy":
-        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens)
+        strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
+                                end_tokens=end_tokens)
     elif args.sampling_strategy == "BeamSearchStrategy":
         strategy = BeamSearchStrategy(
+            1,
             args.num_beams,
             length_penalty=args.length_penalty,
             consider_end=True,
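
Both entry points now hand the model a bare prompt and let get_masks_and_position_ids extend it: the slots still to be generated are filled with -1 and overwritten step by step (the same -1s that the "clip -1s" post-processing removes afterwards). The convention in isolation, with made-up token ids:

    import torch

    seq = torch.tensor([[101, 7, 8, 9]])  # hypothetical prompt ids, shape [1, context_length]
    max_gen_length = 4
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
    # tokens == [[101, 7, 8, 9, -1, -1, -1, -1]]; the strategy fills the -1 slots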

+ 1 - 1
generation/__init__.py

@@ -1 +1 @@
-from .strategies import BeamSearchStrategy
+from .strategies import BaseStrategy, BeamSearchStrategy

+ 125 - 61
generation/strategies.py

@@ -1,10 +1,55 @@
+import numpy as np
 import torch
 import torch.nn.functional as F
+from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits
+
+class BaseStrategy:
+    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
+        self.batch_size = batch_size
+        self.invalid_slices = invalid_slices
+        self.temperature = temperature
+        self.topk = top_k
+        self.top_p = top_p
+        self.eps = eps
+        if end_tokens is None:
+            end_tokens = []
+        self.end_tokens = end_tokens
+        self._is_done = np.zeros(self.batch_size, dtype=np.bool)
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done.all()
+
+    def forward(self, logits, tokens, mems, temperature=None):
+        batch_size = tokens.shape[0]
+        if temperature is None:
+            temperature = self.temperature
+        logits = logits / temperature
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -65504
+
+        logits = top_k_logits(logits, self.topk, self.top_p)
+        probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in PyTorch
+        pred = torch.multinomial(probs, num_samples=1)
+        for i in range(self.batch_size):
+            if i >= batch_size:
+                self._is_done[i] = True
+            elif self._is_done[i]:
+                pred[i] = -1
+            elif pred[i].item() in self.end_tokens:
+                self._is_done[i] = True
+        tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
+        return tokens, mems
+
+    def finalize(self, tokens, mems):
+        self._is_done = np.zeros(self.batch_size, dtype=np.bool)
+        return tokens, mems
 
 
 class BeamSearchStrategy:
     def __init__(
         self,
+        batch_size,
         num_beams,
         length_penalty=1.0,
         consider_end=False,
@@ -14,6 +59,7 @@ class BeamSearchStrategy:
         min_gen_length=0,
         deterministic=False,
     ):
+        self.batch_size = batch_size
         self.num_beams = num_beams
         self.length_penalty = length_penalty
         self.end_tokens = end_tokens
@@ -25,26 +71,34 @@ class BeamSearchStrategy:
         self._init_cache()
 
     def _init_cache(self):
-        self.end_beams = []  # list of LongTensors
-        self.end_beams_penalized_scores = []  # list of LongTensors
+        self.end_beams = [[] for _ in range(self.batch_size)]  # list of LongTensors
+        self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)]  # list of LongTensors
         self.cached_beam_scores = 0  # [batch_size]
-        self.cached_beam_ngram_bans = [{} for i in range(self.num_beams)]
+        self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
         self.length_generated = 0
-        self.is_done = False
+        self._is_done = np.zeros(self.batch_size, dtype=np.bool)
 
-    def _add_end_beams(self, score, beam):
+    def _add_end_beams(self, score, beam, batch_idx):
         score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty  # Magic number for OpenNMT
-        for i in range(len(self.end_beams), -1, -1):
-            if i == 0 or score < self.end_beams_penalized_scores[i - 1]:
+        for i in range(len(self.end_beams[batch_idx]), -1, -1):
+            if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
                 break
-        self.end_beams.insert(i, beam)
-        self.end_beams_penalized_scores.insert(i, score)
+        self.end_beams[batch_idx].insert(i, beam)
+        self.end_beams_penalized_scores[batch_idx].insert(i, score)
 
-        self.end_beams = self.end_beams[: self.num_beams]
-        self.end_beams_penalized_scores = self.end_beams_penalized_scores[: self.num_beams]
+        self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
+        self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done.all()
 
     def forward(self, logits, tokens, mems):
-        batch_size, vocab_size = logits.shape
+        if len(logits.shape) == 2:
+            logits = logits.unsqueeze(1)
+            tokens = tokens.unsqueeze(1)
+            mems = mems.unsqueeze(2)
+        batch_size, num_beams, vocab_size = logits.shape
         seq_len = tokens.shape[-1]
         logits = logits.float()
         for invalid_slice in self.invalid_slices:
@@ -53,79 +107,89 @@ class BeamSearchStrategy:
             for end_token in self.end_tokens:
                 logits[..., end_token] = -65504
         if self.ngram > 0 and seq_len > self.ngram:
-            for i in range(batch_size):
-                ngram_prefix = tokens[i, -(self.ngram - 1) :].tolist()  # TODO ngram=1
-                for banned_index in self.cached_beam_ngram_bans[i].get(tuple(ngram_prefix), []):
-                    logits[i, banned_index] = -65504
+            for batch_idx in range(batch_size):
+                for i in range(num_beams):
+                    ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1) :].tolist()  # TODO ngram=1
+                    for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
+                        logits[batch_idx, i, banned_index] = -65504
 
         next_token_scores = F.log_softmax(logits, dim=-1)  # [batch_size, vocab_size]
         prev_scores = self.cached_beam_scores
-        if isinstance(self.cached_beam_scores, torch.Tensor):
-            prev_scores = prev_scores[:, None].expand_as(next_token_scores)
+        if isinstance(prev_scores, torch.Tensor):
+            prev_scores = prev_scores[..., None].expand_as(next_token_scores)
         next_token_scores = next_token_scores + prev_scores
 
-        next_token_scores = next_token_scores.view(batch_size * vocab_size)
+        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
-        probs = F.softmax(next_token_scores, dim=0)
+        probs = F.softmax(next_token_scores, dim=-1)
         if self.deterministic:
-            if mems.shape[1] < batch_size:  # First token
-                probs = probs[:vocab_size]
+            if num_beams < self.num_beams:  # First token
+                probs = probs[..., :vocab_size]
             next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices  # [2*nb]
         else:
             next_tokens = torch.multinomial(
                 probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
             )  # [2*nb]
-        next_token_scores = next_token_scores[next_tokens]
-        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=0)
-        next_tokens = next_tokens[_indices]
+        next_token_scores = next_token_scores[torch.arange(batch_size).unsqueeze(1), next_tokens]
+        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+        next_tokens = next_tokens[torch.arange(batch_size).unsqueeze(1), _indices]
 
         next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
         next_tokens = next_tokens % vocab_size
 
         # select out end beams or continue beams
-        if mems.shape[1] < batch_size:
-            mems = mems.expand(-1, batch_size, -1, -1)
-        beam_continue = []
-        scores_continue = []
-        bans_continue = []
-        mems_contiue = []
-        for i in range(len(next_tokens)):
-            beam = torch.cat((tokens[next_indices[i]], next_tokens[i : i + 1]))
-            if int(next_tokens[i]) in self.end_tokens:
-                self._add_end_beams(next_token_scores[i], beam)
-            elif len(beam_continue) < self.num_beams:
-                beam_continue.append(beam)
-                mems_contiue.append(mems[:, next_indices[i]])
-                # update caches
-                scores_continue.append(next_token_scores[i])
-                if self.ngram > 0:
-                    bans = self.cached_beam_ngram_bans[next_indices[i]].copy()
-                    ngram_prefix = tuple(tokens[next_indices[i], -(self.ngram - 1) :].tolist())  # TODO ngram=1
-                    bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[i],)
-                    bans_continue.append(bans)
-            else:
-                break
-        tokens = torch.stack(beam_continue)
-        mems = torch.stack(mems_contiue, dim=1)
-        self.cached_beam_scores = torch.tensor(scores_continue, device=logits.device)
-        self.cached_beam_ngram_bans = bans_continue
+        beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
+        for batch_idx in range(batch_size):
+            beam_continue = []
+            scores_continue = []
+            bans_continue = []
+            mems_continue = []
+            for i in range(len(next_tokens[batch_idx])):
+                beam = torch.cat((tokens[batch_idx, next_indices[batch_idx, i]], next_tokens[batch_idx, i : i + 1]))
+                if not self._is_done[batch_idx] and int(next_tokens[batch_idx, i]) in self.end_tokens:
+                    self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
+                elif len(beam_continue) < self.num_beams:
+                    beam_continue.append(beam)
+                    mems_continue.append(mems[:, batch_idx, next_indices[batch_idx, i]])
+                    # update caches
+                    scores_continue.append(next_token_scores[batch_idx, i])
+                    if self.ngram > 0:
+                        bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
+                        # TODO ngram=1
+                        ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
+                        bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
+                        bans_continue.append(bans)
+                else:
+                    break
+            beam_continue_batch.append(torch.stack(beam_continue))
+            mems_continue_batch.append(torch.stack(mems_continue, dim=1))
+            score_continue_batch.append(scores_continue)
+            self.cached_beam_ngram_bans[batch_idx] = bans_continue
+        tokens = torch.stack(beam_continue_batch)
+        mems = torch.stack(mems_continue_batch, dim=1)
+        self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
         self.length_generated += 1
-
-        if (
-            len(self.end_beams) == self.num_beams
-            and self.end_beams_penalized_scores[-1]
-            >= self.cached_beam_scores.max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
-        ):  # We're done if none of current tokens will better than the worst in end_beams
-            self.is_done = True
+        for batch_idx in range(self.batch_size):
+            if batch_idx >= batch_size:
+                self._is_done[batch_idx] = True
+            elif (
+                len(self.end_beams[batch_idx]) == self.num_beams
+                and self.end_beams_penalized_scores[batch_idx][-1]
+                >= self.cached_beam_scores[batch_idx].max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
+            ):  # We're done if none of the current tokens will be better than the worst in end_beams
+                self._is_done[batch_idx] = True
 
         return tokens, mems
 
     def finalize(self, tokens, mems):
         if self.consider_end:
-            for i in range(tokens.shape[0]):
-                self._add_end_beams(self.cached_beam_scores[i], tokens[i])
+            batch_size, num_beams = tokens.shape[:2]
+            for batch_idx in range(batch_size):
+                if not self._is_done[batch_idx]:
+                    for i in range(num_beams):
+                        self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
             mems = None
-            ret = self.end_beams
+            ret = self.end_beams[:batch_size]
         else:
             ret = tokens
         self._init_cache()
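
Both strategies now share one batched protocol: forward receives next-token logits for every sequence, appends one token per still-running sequence (finished rows get -1), and reports completion through the is_done property; finalize returns the result and resets internal state. A minimal sketch of that loop using the BaseStrategy above with a random stand-in for the model (vocabulary size and end token are made up, and the class's np.bool alias assumes a NumPy of the period):

    import torch

    vocab_size, end_token = 16, 3
    strategy = BaseStrategy(batch_size=2, temperature=1.0, top_k=4, end_tokens=[end_token])

    tokens = torch.zeros(2, 1, dtype=torch.long)  # dummy one-token prompts
    mems = None
    while not strategy.is_done and tokens.shape[1] < 10:
        logits = torch.randn(2, vocab_size)  # stand-in for a model forward pass
        tokens, mems = strategy.forward(logits, tokens, mems)
    tokens, mems = strategy.finalize(tokens, mems)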