Browse Source

Merge pull request #13 from duzx16/batch-generation

Implement batch generation
Aohan Zeng 2 years ago
parent
commit
335667a88d
9 changed files with 312 additions and 128 deletions
  1. evaluation/configs.py (+0 -3)
  2. evaluation/dataset.py (+37 -7)
  3. evaluation/model.py (+110 -21)
  4. evaluation/tasks.py (+5 -6)
  5. generate.py (+17 -17)
  6. generation/__init__.py (+1 -1)
  7. generation/strategies.py (+122 -61)
  8. tasks/lambada/strategy.py (+4 -3)
  9. tasks/lambada/task.py (+16 -9)

+ 0 - 3
evaluation/configs.py

@@ -50,9 +50,6 @@ class GenerationTaskConfig(BaseConfig):
     min_gen_length: int = 0
     max_gen_length: int = 128
 
-    def __post_init__(self):
-        assert self.micro_batch_size == 1, "Only support micro batch size = 1 for generation task"
-
 
 @dataclass
 class LanguageModelTaskConfig(BaseConfig):

+ 37 - 7
evaluation/dataset.py

@@ -84,6 +84,34 @@ class GenerationTaskDataset(EvaluationDataset):
             text = text[len(text) - text_length : len(text)]
         return {"text": text, "targets": targets}
 
+    @property
+    def has_collate_fn(self) -> bool:
+        return True
+
+    def collate_fn(self, samples):
+        TILE = 32
+        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE
+
+        token_batch, position_id_batch, attention_mask_batch = [], [], []
+        context_length_batch, target_position_id_batch = [], []
+
+        for sample in samples:
+            token, position_id, attention_mask = pad_batch(
+                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
+            )
+            token_batch.append(token)
+            position_id_batch.append(position_id)
+            attention_mask_batch.append(attention_mask)
+            context_length_batch.append(sample['context_length'])
+            target_position_id_batch.append(sample['target_position_id'])
+        return {
+            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
+            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
+            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
+            "context_length": torch.tensor(context_length_batch, dtype=torch.int64),
+            "target_position_ids": torch.tensor(np.array(target_position_id_batch), dtype=torch.int64),
+        }
+
     @staticmethod
     def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
         tokenizer = get_tokenizer()
@@ -106,20 +134,22 @@ class GenerationTaskDataset(EvaluationDataset):
             else:
                 token = np.concatenate((token, [mask_id, sop_id]))
         context_length = len(token)
-        max_seq_length = context_length + max_gen_length
 
-        position_id = np.arange(0, max_seq_length, dtype=np.int64)
+        position_id = np.arange(0, context_length, dtype=np.int64)
+        target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
         if not use_task_mask:
-            position_id[context_length - 1 :] = mask_position
+            position_id[context_length - 1:] = mask_position
+            target_position_id[:] = mask_position
 
-        attention_mask = np.tril(np.ones((max_seq_length, max_seq_length), dtype=np.int64))
+        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
         if not unidirectional:
             attention_mask[: context_length - 1, : context_length - 1] = 1
 
         item = {
-            "tokens": np.concatenate((token, np.zeros(max_seq_length - len(token), dtype=np.int64))),
-            "position_ids": position_id,
-            "attention_mask": attention_mask < 0.5,
+            "token": token,
+            "position_id": position_id,
+            "target_position_id": target_position_id,
+            "attention_mask": attention_mask,
             "context_length": context_length,
         }
         return item
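
A note on the new collate path: collate_fn pads every sample in the micro batch up to a common length rounded to a multiple of TILE = 32 (lengths 17 and 40 both become 64), then stacks the fields into batch tensors. The pad_batch helper it calls lies outside this hunk; a minimal sketch of what it plausibly does, assuming the NumPy arrays produced by build_generation_sample (hypothetical, not the repo's actual implementation):

import numpy as np

def pad_batch(token, position_id, attention_mask, length_to_pad):
    # Hypothetical sketch: right-pad one sample to the shared tiled length.
    pad = length_to_pad - len(token)
    token = np.concatenate((token, np.zeros(pad, dtype=np.int64)))
    position_id = np.concatenate((position_id, np.zeros(pad, dtype=np.int64)))
    # The per-sample mask is square (context_length x context_length), so both
    # dimensions grow; padded positions attend to nothing and are attended by nothing.
    attention_mask = np.pad(attention_mask, ((0, pad), (0, pad)), constant_values=0)
    return token, position_id, attention_mask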

+ 110 - 21
evaluation/model.py

@@ -2,10 +2,79 @@ import torch
 
 from typing import List, Union
 
-from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
+from SwissArmyTransformer.generation.autoregressive_sampling import update_mems, get_masks_and_position_ids_default
 from SwissArmyTransformer.mpu import vocab_parallel_cross_entropy
 
 
+def batch_filling_sequence(
+        model,
+        seqs,
+        context_lengths,
+        strategy,
+        max_memory_length=100000,
+        get_masks_and_position_ids=get_masks_and_position_ids_default,
+        mems=None,
+        **kw_args
+        ):
+    '''
+        seqs: [batch_size, seq_len], each row like [2, 3, 5, ..., -1 (to be generated), -1, ...]
+        mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
+            cache, should be first mems.shape[1] parts of context_tokens.
+            mems are first-class citizens here, but we don't assume what is memorized.
+            input mems are used in multi-phase generation.
+    '''
+    assert len(seqs.shape) == 2
+
+    # building the initial tokens, attention_mask, and position_ids
+    batch_size, context_length = seqs.shape
+    seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
+    tokens = seqs[..., :context_length]
+    if attention_mask.dtype != torch.bool:
+        attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
+    # initialize generation
+    counter = context_length - 1 # Last fixed index is ``counter''
+    index = 0 if mems is None else mems.shape[2] # Next forward starting index, also the length of cache.
+    num_beams = 1
+    # step-by-step generation
+    while counter < seqs.shape[1] - 1:
+        # Now, we want to generate seq[counter + 1],
+        # token[:, index: counter+1] needs forwarding.
+        # forward
+        tokens = tokens.reshape(batch_size * num_beams, -1)
+        mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1]) if mems is not None else None
+        logits, *output_per_layers = model(
+            tokens[:, index:],
+            position_ids[..., index: counter+1],
+            attention_mask[..., index: counter+1, :counter+1], # TODO memlen
+            mems=mems,
+            **kw_args
+        )
+        mem_kv = [o['mem_kv'] for o in output_per_layers]
+        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
+        if counter == context_length - 1:
+            logits = logits[torch.arange(batch_size), context_lengths - 1]
+        else:
+            logits = logits[:, -1]
+        counter += 1
+        index = counter
+        # if torch.distributed.get_rank() == 0:
+        #     print(f"counter: {counter}: logits: {logits.float().abs().mean()}")
+        # sampling
+        logits = logits.reshape(batch_size, num_beams, -1)
+        tokens = tokens.reshape(batch_size, num_beams, -1)
+        mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
+        tokens, mems = strategy.forward(logits, tokens, mems)
+        if len(tokens.shape) == 3 and num_beams == 1:
+            num_beams = tokens.shape[1]
+            position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, -1).reshape(batch_size * num_beams, -1)
+            attention_mask_shape = attention_mask.shape[-3:]
+            attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
+                batch_size * num_beams, *attention_mask_shape)
+        if strategy.is_done:
+            break
+    return strategy.finalize(tokens, mems)
+
+
 class ModelForEvaluation(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
@@ -48,45 +117,65 @@ class ModelForEvaluation(torch.nn.Module):
                 log_probs.append(log_probs_single)
         return log_probs
 
-    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[int], List[List[int]]]:
+    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[
+        List[List[int]], List[List[List[int]]]]:
         """
         """
         @return: A list of text model generated, sorted by score in descending order
         @return: A list of text model generated, sorted by score in descending order
         """
         """
 
 
-        seq = torch.squeeze(sample["tokens"].to(device=torch.cuda.current_device()).long())
-        context_length = sample["context_length"].to(device=torch.cuda.current_device()).long()
-        seq[context_length:] = -1
+        seqs = sample["tokens"].to(device=torch.cuda.current_device()).long()
+        context_lengths = sample["context_length"].long()
 
         def get_masks_and_position_ids(seq):
-            tokens = seq.unsqueeze(0)
-            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device()).bool().unsqueeze(1)
-            position_ids = sample["position_ids"].to(device=torch.cuda.current_device()).long()
+            batch_size = seq.shape[0]
+            max_gen_length = sample['target_position_ids'].shape[-1]
+            tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
+            position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1)
+            position_ids = position_ids.to(device=torch.cuda.current_device()).long()
+            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device())
+            context_mask = attention_mask[torch.arange(batch_size), context_lengths - 1].unsqueeze(1).repeat(
+                1, max_gen_length, 1
+            )
+            causal_mask = torch.tril(context_mask.new_ones((batch_size, max_gen_length, max_gen_length))) < 0.5
+            generation_mask = torch.cat(
+                (context_mask, causal_mask), dim=-1)
+            attention_mask = torch.nn.functional.pad(attention_mask, (0, max_gen_length), mode='constant', value=1)
+            attention_mask = torch.cat((attention_mask, generation_mask), dim=1)
+            attention_mask = attention_mask.bool().unsqueeze(1)
             return tokens, attention_mask, position_ids
 
         self.model.eval()
         with torch.no_grad():
-            output = filling_sequence(
+            output = batch_filling_sequence(
                 self.model,
-                seq,
+                seqs,
+                context_lengths,
                 get_masks_and_position_ids=get_masks_and_position_ids,
-                batch_size=strategy.num_beams if hasattr(strategy, "num_beams") else 1,
                 strategy=strategy,
             )[0]
 
         if isinstance(output, torch.Tensor):  # different strategies
-            output = list(output)
+            output = output.tolist()
 
         output_targets = []
+        context_length = seqs.shape[1]
+        for lines in output:
+            lines = lines.tolist() if isinstance(lines, torch.Tensor) else lines
+            output_target = []
+            if not isinstance(lines, list):
+                lines = [lines]
+            for line in lines:
+                unfinished = line.index(-1) if -1 in line else len(line)
+                if line[unfinished - 1] in strategy.end_tokens:
+                    unfinished -= 1
+                line = line[context_length:unfinished]
+                output_target.append(line)
+            if not return_all_beams:
+                output_targets.append(output_target[0])
+            else:
+                output_targets.append(output_target)
+        return output_targets
 
-        for line in output:
-            line = line.tolist()
-            unfinished = line.index(-1) if -1 in line else len(line)
-            if line[unfinished - 1] in strategy.end_tokens:
-                unfinished -= 1
-            line = line[context_length:unfinished]
-            output_targets.append(line)
-
-        return output_targets if return_all_beams else output_targets[0]
 
     def calculate_loss(self, batch) -> List[float]:
         tokens, position_ids, attention_mask = self.process_data(batch)
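
Since padded contexts end at different positions, batch_filling_sequence cannot take the logits of the shared last column on the first step; it gathers each sample's logits at its own context_lengths[i] - 1. A standalone illustration of that gather with toy shapes:

import torch

logits = torch.arange(2 * 4 * 3, dtype=torch.float).reshape(2, 4, 3)  # [batch, seq, vocab]
context_lengths = torch.tensor([2, 4])
first_step = logits[torch.arange(2), context_lengths - 1]  # [batch, vocab]
assert torch.equal(first_step[0], logits[0, 1])  # sample 0: last context token at index 1
assert torch.equal(first_step[1], logits[1, 3])  # sample 1: last context token at index 3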

+ 5 - 6
evaluation/tasks.py

@@ -9,10 +9,9 @@ from glob import glob
 from os.path import join, relpath
 from collections import defaultdict
 
-from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
 from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceTokenizer
 
-from generation import BeamSearchStrategy
+from generation import BaseStrategy, BeamSearchStrategy
 from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig, LanguageModelTaskConfig
 from .model import ModelForEvaluation
 from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset
@@ -171,9 +170,11 @@ class GenerationTask(BaseTask, ABC):
 
         end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
         if self.config.sampling_strategy == "BaseStrategy":
-            self.strategy = BaseStrategy(temperature=1.0, top_k=1, end_tokens=end_tokens)
+            self.strategy = BaseStrategy(batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1,
+                                         end_tokens=end_tokens)
         elif self.config.sampling_strategy == "BeamSearchStrategy":
             self.strategy = BeamSearchStrategy(
+                self.config.micro_batch_size,
                 self.config.num_beams,
                 length_penalty=self.config.length_penalty,
                 consider_end=True,
@@ -186,10 +187,8 @@ class GenerationTask(BaseTask, ABC):
             raise ValueError(f"unknown strategy {self.config.sampling_strategy}")
 
     def predict_single_batch(self, batch) -> List[List[int]]:
-        # micro batch size = 1 for generation task,
-        # but we still need to return a list of predictions for consistency
         output = self.model.generate_text(batch, self.strategy, return_all_beams=False)
-        return [output]
+        return output
 
 
 class MultiChoiceTask(BaseTask, ABC):
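
Both strategies now take the batch size as a constructor argument, which is what lets predict_single_batch return the micro batch's predictions directly. A usage sketch (the end-token ids and config values here are illustrative, not from the repo):

from generation import BaseStrategy, BeamSearchStrategy

end_tokens = [20002, 20000]  # illustrative ids for "eop"/"eos"
greedy = BaseStrategy(batch_size=4, temperature=1.0, top_k=1, end_tokens=end_tokens)
beam = BeamSearchStrategy(
    4,  # batch_size, matching micro_batch_size
    2,  # num_beams
    length_penalty=0.9,
    consider_end=True,
    end_tokens=end_tokens,
)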

+ 17 - 17
generate.py

@@ -7,9 +7,8 @@ from functools import partial
 from typing import List, Tuple
 
 from SwissArmyTransformer import mpu
-from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
-from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
-from generation import BeamSearchStrategy
+from evaluation.model import batch_filling_sequence
+from generation import BeamSearchStrategy, BaseStrategy
 from SwissArmyTransformer.generation.utils import timed_name, generate_continually
 from initialize import initialize, initialize_model_and_tokenizer
 
@@ -31,16 +30,16 @@ def isEnglish(s):
         return True
 
 
-def get_masks_and_position_ids(seq, mask_position, context_length, gmask=False):
-    tokens = seq.unsqueeze(0)
-
-    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
+    context_length = seq.shape[1]
+    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
+    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
     attention_mask.tril_()
     attention_mask[..., : context_length - 1] = 1
     attention_mask.unsqueeze_(1)
     attention_mask = (attention_mask < 0.5).bool()
 
-    position_ids = torch.arange(len(seq), dtype=torch.long, device=tokens.device)
+    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
     if not gmask:
         position_ids[context_length - 1 :] = mask_position
 
@@ -99,30 +98,29 @@ def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], L
         output_list = []
 
         input_seq = torch.cuda.LongTensor(
-            seq + [tokenizer.get_command("sop")] + [-1] * (args.out_seq_length - len(seq) - 1),
+            [seq + [tokenizer.get_command("sop")]],
             device=args.device,
         )
-        output, _ = filling_sequence(
+        output, _ = batch_filling_sequence(
             model,
             input_seq,
-            batch_size=num_output,
+            torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
             strategy=strategy,
-            log_attention_weights=None,
             get_masks_and_position_ids=partial(
                 get_masks_and_position_ids,
                 mask_position=mask_position,
-                context_length=len(seq) + 1,
+                max_gen_length=args.out_seq_length - input_seq.shape[-1],
                 gmask=use_gmask,
             ),
         )
         if isinstance(output, torch.Tensor):  # different strategies
-            output = list(output)
-
+            output = output.tolist()
+        output = output[0]  # batch_size = 1
         output_list.extend(output)
 
         # clip -1s and fill back generated things into seq
         for i in range(len(output_list)):
-            output = output_list[i].tolist()
+            output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
             try:
                 unfinished = output.index(-1)
             except ValueError:
@@ -160,9 +158,11 @@ def main(args):
     end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
 
     if args.sampling_strategy == "BaseStrategy":
-        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens)
+        strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
+                                end_tokens=end_tokens)
     elif args.sampling_strategy == "BeamSearchStrategy":
         strategy = BeamSearchStrategy(
+            1,
             args.num_beams,
             length_penalty=args.length_penalty,
             consider_end=True,
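
The rewritten get_masks_and_position_ids pads the input with -1 generation slots and builds GLM's mask in one piece: full (bidirectional) attention over the prefix, causal attention afterwards. A standalone reproduction for a 3-token context with 2 generation slots, mirroring the lines above (True marks blocked positions):

import torch

context_length, max_gen_length = 3, 2
total = context_length + max_gen_length
attention_mask = torch.ones((1, total, total))
attention_mask.tril_()                          # causal by default
attention_mask[..., : context_length - 1] = 1   # every row sees the full prefix
attention_mask.unsqueeze_(1)                    # [1, 1, total, total]
attention_mask = (attention_mask < 0.5).bool()  # invert: True = blocked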

+ 1 - 1
generation/__init__.py

@@ -1 +1 @@
-from .strategies import BeamSearchStrategy
+from .strategies import BaseStrategy, BeamSearchStrategy

+ 122 - 61
generation/strategies.py

@@ -1,10 +1,56 @@
+import numpy as np
 import torch
 import torch.nn.functional as F
+from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits
+
+class BaseStrategy:
+    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
+        self.batch_size = batch_size
+        self.invalid_slices = invalid_slices
+        self.temperature = temperature
+        self.topk = top_k
+        self.top_p = top_p
+        self.eps = eps
+        if end_tokens is None:
+            end_tokens = []
+        self.end_tokens = end_tokens
+        self._is_done = np.zeros(self.batch_size, dtype=bool)
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done.all()
+
+    def forward(self, logits, tokens, mems, temperature=None):
+        logits = logits.view(-1, logits.size(-1))
+        batch_size = tokens.shape[0]
+        if temperature is None:
+            temperature = self.temperature
+        logits = logits / temperature
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -65504
+
+        logits = top_k_logits(logits, self.topk, self.top_p)
+        probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in PyTorch
+        pred = torch.multinomial(probs, num_samples=1)
+        for i in range(self.batch_size):
+            if i >= batch_size:
+                self._is_done[i] = True
+            elif self._is_done[i]:
+                pred[i] = -1
+            elif pred[i].item() in self.end_tokens:
+                self._is_done[i] = True
+        tokens = torch.cat((tokens, pred.view(tokens.shape[:-1] + (1,))), dim=-1)
+        return tokens, mems
+
+    def finalize(self, tokens, mems):
+        self._is_done = np.zeros(self.batch_size, dtype=bool)
+        return tokens, mems
 
 
 class BeamSearchStrategy:
     def __init__(
         self,
+        batch_size,
         num_beams,
         length_penalty=1.0,
         consider_end=False,
@@ -14,6 +60,7 @@ class BeamSearchStrategy:
         min_gen_length=0,
         deterministic=False,
     ):
+        self.batch_size = batch_size
         self.num_beams = num_beams
         self.length_penalty = length_penalty
         self.end_tokens = end_tokens
@@ -25,26 +72,30 @@ class BeamSearchStrategy:
         self._init_cache()
 
     def _init_cache(self):
-        self.end_beams = []  # list of LongTensors
-        self.end_beams_penalized_scores = []  # list of LongTensors
+        self.end_beams = [[] for _ in range(self.batch_size)]  # list of LongTensors
+        self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)]  # list of LongTensors
         self.cached_beam_scores = 0  # [batch_size]
-        self.cached_beam_ngram_bans = [{} for i in range(self.num_beams)]
+        self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
         self.length_generated = 0
-        self.is_done = False
+        self._is_done = np.zeros(self.batch_size, dtype=bool)
 
-    def _add_end_beams(self, score, beam):
+    def _add_end_beams(self, score, beam, batch_idx):
         score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty  # Magic number for OpenNMT
-        for i in range(len(self.end_beams), -1, -1):
-            if i == 0 or score < self.end_beams_penalized_scores[i - 1]:
+        for i in range(len(self.end_beams[batch_idx]), -1, -1):
+            if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
                 break
-        self.end_beams.insert(i, beam)
-        self.end_beams_penalized_scores.insert(i, score)
+        self.end_beams[batch_idx].insert(i, beam)
+        self.end_beams_penalized_scores[batch_idx].insert(i, score)
 
-        self.end_beams = self.end_beams[: self.num_beams]
-        self.end_beams_penalized_scores = self.end_beams_penalized_scores[: self.num_beams]
+        self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
+        self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done.all()
 
     def forward(self, logits, tokens, mems):
-        batch_size, vocab_size = logits.shape
+        batch_size, num_beams, vocab_size = logits.shape
         seq_len = tokens.shape[-1]
         logits = logits.float()
         for invalid_slice in self.invalid_slices:
@@ -53,79 +104,89 @@ class BeamSearchStrategy:
             for end_token in self.end_tokens:
                 logits[..., end_token] = -65504
         if self.ngram > 0 and seq_len > self.ngram:
-            for i in range(batch_size):
-                ngram_prefix = tokens[i, -(self.ngram - 1) :].tolist()  # TODO ngram=1
-                for banned_index in self.cached_beam_ngram_bans[i].get(tuple(ngram_prefix), []):
-                    logits[i, banned_index] = -65504
+            for batch_idx in range(batch_size):
+                for i in range(num_beams):
+                    ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1) :].tolist()  # TODO ngram=1
+                    for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
+                        logits[batch_idx, i, banned_index] = -65504
 
         next_token_scores = F.log_softmax(logits, dim=-1)  # [batch_size, vocab_size]
         prev_scores = self.cached_beam_scores
-        if isinstance(self.cached_beam_scores, torch.Tensor):
-            prev_scores = prev_scores[:, None].expand_as(next_token_scores)
+        if isinstance(prev_scores, torch.Tensor):
+            prev_scores = prev_scores[..., None].expand_as(next_token_scores)
         next_token_scores = next_token_scores + prev_scores
 
-        next_token_scores = next_token_scores.view(batch_size * vocab_size)
+        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
-        probs = F.softmax(next_token_scores, dim=0)
+        probs = F.softmax(next_token_scores, dim=-1)
+        if num_beams < self.num_beams:  # First token
+            probs = probs[..., :vocab_size]
         if self.deterministic:
-            if mems.shape[1] < batch_size:  # First token
-                probs = probs[:vocab_size]
             next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices  # [2*nb]
         else:
             next_tokens = torch.multinomial(
                 probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
             )  # [2*nb]
-        next_token_scores = next_token_scores[next_tokens]
-        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=0)
-        next_tokens = next_tokens[_indices]
+        next_token_scores = next_token_scores[torch.arange(batch_size).unsqueeze(1), next_tokens]
+        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+        next_tokens = next_tokens[torch.arange(batch_size).unsqueeze(1), _indices]
 
         next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
         next_tokens = next_tokens % vocab_size
 
         # select out end beams or continue beams
-        if mems.shape[1] < batch_size:
-            mems = mems.expand(-1, batch_size, -1, -1)
-        beam_continue = []
-        scores_continue = []
-        bans_continue = []
-        mems_contiue = []
-        for i in range(len(next_tokens)):
-            beam = torch.cat((tokens[next_indices[i]], next_tokens[i : i + 1]))
-            if int(next_tokens[i]) in self.end_tokens:
-                self._add_end_beams(next_token_scores[i], beam)
-            elif len(beam_continue) < self.num_beams:
-                beam_continue.append(beam)
-                mems_contiue.append(mems[:, next_indices[i]])
-                # update caches
-                scores_continue.append(next_token_scores[i])
-                if self.ngram > 0:
-                    bans = self.cached_beam_ngram_bans[next_indices[i]].copy()
-                    ngram_prefix = tuple(tokens[next_indices[i], -(self.ngram - 1) :].tolist())  # TODO ngram=1
-                    bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[i],)
-                    bans_continue.append(bans)
-            else:
-                break
-        tokens = torch.stack(beam_continue)
-        mems = torch.stack(mems_contiue, dim=1)
-        self.cached_beam_scores = torch.tensor(scores_continue, device=logits.device)
-        self.cached_beam_ngram_bans = bans_continue
+        beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
+        for batch_idx in range(batch_size):
+            beam_continue = []
+            scores_continue = []
+            bans_continue = []
+            mems_contiue = []
+            for i in range(len(next_tokens[batch_idx])):
+                beam = torch.cat((tokens[batch_idx, next_indices[batch_idx, i]], next_tokens[batch_idx, i : i + 1]))
+                if not self._is_done[batch_idx] and int(next_tokens[batch_idx, i]) in self.end_tokens:
+                    self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
+                elif len(beam_continue) < self.num_beams:
+                    beam_continue.append(beam)
+                    mems_contiue.append(mems[:, batch_idx, next_indices[batch_idx, i]])
+                    # update caches
+                    scores_continue.append(next_token_scores[batch_idx, i])
+                    if self.ngram > 0:
+                        bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
+                        # TODO ngram=1
+                        ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
+                        bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
+                        bans_continue.append(bans)
+                else:
+                    break
+            beam_continue_batch.append(torch.stack(beam_continue))
+            mems_continue_batch.append(torch.stack(mems_contiue, dim=1))
+            score_continue_batch.append(scores_continue)
+            self.cached_beam_ngram_bans[batch_idx] = bans_continue
+        tokens = torch.stack(beam_continue_batch)
+        mems = torch.stack(mems_continue_batch, dim=1)
+        self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
         self.length_generated += 1
-
-        if (
-            len(self.end_beams) == self.num_beams
-            and self.end_beams_penalized_scores[-1]
-            >= self.cached_beam_scores.max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
-        ):  # We're done if none of current tokens will better than the worst in end_beams
-            self.is_done = True
+        for batch_idx in range(self.batch_size):
+            if batch_idx >= batch_size:
+                self._is_done[batch_idx] = True
+            elif (
+                len(self.end_beams[batch_idx]) == self.num_beams
+                and self.end_beams_penalized_scores[batch_idx][-1]
+                >= self.cached_beam_scores[batch_idx].max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
+            ):  # We're done if none of the current tokens will be better than the worst in end_beams
+                self._is_done[batch_idx] = True
 
         return tokens, mems
 
     def finalize(self, tokens, mems):
         if self.consider_end:
-            for i in range(tokens.shape[0]):
-                self._add_end_beams(self.cached_beam_scores[i], tokens[i])
+            batch_size, num_beams = tokens.shape[:2]
+            for batch_idx in range(batch_size):
+                if not self._is_done[batch_idx]:
+                    for i in range(num_beams):
+                        self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
             mems = None
-            ret = self.end_beams
+            ret = self.end_beams[:batch_size]
         else:
             ret = tokens
         self._init_cache()
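
Both strategies replace the single is_done flag with a per-sample _is_done vector, so one finished sample neither stops nor is held up by the rest of the batch; a finished row keeps receiving -1 so the token tensor stays rectangular. The bookkeeping in isolation, outside the class:

import torch

end_tokens = {2}
is_done = [False, False]
tokens = torch.tensor([[5, 7], [5, 9]])
pred = torch.tensor([[2], [8]])  # this step: sample 0 emits an end token
for i in range(len(is_done)):
    if is_done[i]:
        pred[i] = -1              # already finished: emit filler instead
    elif pred[i].item() in end_tokens:
        is_done[i] = True         # finished as of this step
tokens = torch.cat((tokens, pred), dim=-1)  # tensor([[5, 7, 2], [5, 9, 8]])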

+ 4 - 3
tasks/lambada/strategy.py

@@ -7,7 +7,7 @@ class BeamSearchStrategyForLAMBADA(BeamSearchStrategy):
         self.banned_prefix = banned_prefix
 
     def forward(self, logits, tokens, mems):
-        batch_size, vocab_size = logits.shape
+        batch_size, num_beams, vocab_size = logits.shape
         logits = logits.float()
         for prefix in self.banned_prefix:
             if self.length_generated == len(prefix) - 1:
@@ -15,6 +15,7 @@ class BeamSearchStrategyForLAMBADA(BeamSearchStrategy):
                     logits[..., prefix[0]] = -65504
                 else:
                     for i in range(batch_size):
-                        if tokens[i, -(len(prefix) - 1) :].tolist() == prefix[:-1]:
-                            logits[i, prefix[-1]] = -65504
+                        for j in range(num_beams):
+                            if tokens[i, j, -(len(prefix) - 1) :].tolist() == prefix[:-1]:
+                                logits[i, j, prefix[-1]] = -65504
         return super().forward(logits, tokens, mems)

+ 16 - 9
tasks/lambada/task.py

@@ -28,7 +28,8 @@ class LAMBADA(GenerationTask):
                     invalid_slices.append(pp[0])
                 banned_prefix.append(pp)
             self.strategy = BeamSearchStrategyForLAMBADA(
-                self.config.num_beams,
+                batch_size=self.config.micro_batch_size,
+                num_beams=self.config.num_beams,
                 length_penalty=self.config.length_penalty,
                 consider_end=True,
                 end_tokens=self.strategy.end_tokens,
@@ -44,11 +45,17 @@ class LAMBADA(GenerationTask):
         return self.tokenizer.tokenize(text.split(" ")[0])
 
     def predict_single_batch(self, batch):
-        # micro batch size = 1 here, but we still need to return a list of predictions for consistency
-        outputs: List[List[int]] = self.model.generate_text(batch, self.strategy, return_all_beams=True)
-        for output in outputs:
-            text = self.tokenizer.tokenizer.decode(output).strip()
-            spl = text.split(" ")
-            if len(spl) >= 2 and spl[1] in punctuation:
-                return [self.get_first_word_tokens(output)]
-        return [self.get_first_word_tokens(outputs[0])]
+        outputs_batch: List[List[List[int]]] = self.model.generate_text(batch, self.strategy, return_all_beams=True)
+        predictions = []
+        for outputs in outputs_batch:
+            found = False
+            for output in outputs:
+                text = self.tokenizer.tokenizer.decode(output).strip()
+                spl = text.split(" ")
+                if len(spl) >= 2 and spl[1] in punctuation:
+                    predictions.append(self.get_first_word_tokens(output))
+                    found = True
+                    break
+            if not found:
+                predictions.append(self.get_first_word_tokens(outputs[0]))
+        return predictions
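
The LAMBADA prediction now loops over the batch: for each sample it decodes the beams in score order and keeps the first whose second whitespace-separated token is punctuation, falling back to the top beam. The selection rule in isolation (decoding is stubbed out; the real code goes through the icetk tokenizer):

from string import punctuation

def pick_beam(decoded_beams):
    # decoded_beams: beam texts for one sample, sorted by score (descending)
    for text in decoded_beams:
        spl = text.strip().split(" ")
        if len(spl) >= 2 and spl[1] in punctuation:
            return text
    return decoded_beams[0]

print(pick_beam(["word . rest", "word other"]))  # -> "word . rest"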