Browse Source

Merge pull request #13 from duzx16/batch-generation

Implement batch generation
Aohan Zeng 2 years ago
Parent
Commit
335667a88d

+ 0 - 3
evaluation/configs.py

@@ -50,9 +50,6 @@ class GenerationTaskConfig(BaseConfig):
     min_gen_length: int = 0
     max_gen_length: int = 128
 
-    def __post_init__(self):
-        assert self.micro_batch_size == 1, "Only support micro batch size = 1 for generation task"
-
 
 @dataclass
 class LanguageModelTaskConfig(BaseConfig):
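Note: with the __post_init__ assertion removed, generation tasks are no longer pinned to micro batch size 1. A hypothetical instantiation (field names other than micro_batch_size and max_gen_length are assumptions about BaseConfig, not taken from this diff):

    config = GenerationTaskConfig(
        name="my-gen-task",       # assumed BaseConfig field
        path="data/my-gen-task",  # assumed BaseConfig field
        micro_batch_size=4,       # previously forced to 1
        max_gen_length=128,
    )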

+ 37 - 7
evaluation/dataset.py

@@ -84,6 +84,34 @@ class GenerationTaskDataset(EvaluationDataset):
             text = text[len(text) - text_length : len(text)]
         return {"text": text, "targets": targets}
 
+    @property
+    def has_collate_fn(self) -> bool:
+        return True
+
+    def collate_fn(self, samples):
+        TILE = 32
+        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE
+
+        token_batch, position_id_batch, attention_mask_batch = [], [], []
+        context_length_batch, target_position_id_batch = [], []
+
+        for sample in samples:
+            token, position_id, attention_mask = pad_batch(
+                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
+            )
+            token_batch.append(token)
+            position_id_batch.append(position_id)
+            attention_mask_batch.append(attention_mask)
+            context_length_batch.append(sample['context_length'])
+            target_position_id_batch.append(sample['target_position_id'])
+        return {
+            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
+            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
+            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
+            "context_length": torch.tensor(context_length_batch, dtype=torch.int64),
+            "target_position_ids": torch.tensor(np.array(target_position_id_batch), dtype=torch.int64),
+        }
+
     @staticmethod
     def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
         tokenizer = get_tokenizer()
@@ -106,20 +134,22 @@ class GenerationTaskDataset(EvaluationDataset):
             else:
                 token = np.concatenate((token, [mask_id, sop_id]))
         context_length = len(token)
-        max_seq_length = context_length + max_gen_length
 
-        position_id = np.arange(0, max_seq_length, dtype=np.int64)
+        position_id = np.arange(0, context_length, dtype=np.int64)
+        target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
         if not use_task_mask:
-            position_id[context_length - 1 :] = mask_position
+            position_id[context_length - 1:] = mask_position
+            target_position_id[:] = mask_position
 
-        attention_mask = np.tril(np.ones((max_seq_length, max_seq_length), dtype=np.int64))
+        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
         if not unidirectional:
             attention_mask[: context_length - 1, : context_length - 1] = 1
 
         item = {
-            "tokens": np.concatenate((token, np.zeros(max_seq_length - len(token), dtype=np.int64))),
-            "position_ids": position_id,
-            "attention_mask": attention_mask < 0.5,
+            "token": token,
+            "position_id": position_id,
+            "target_position_id": target_position_id,
+            "attention_mask": attention_mask,
             "context_length": context_length,
         }
         return item
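The new collate_fn pads every sample in a micro batch to one shared length, rounded up to a multiple of TILE = 32. A minimal sketch of the arithmetic plus a pad_batch-style helper (the real pad_batch lives elsewhere in the repo; this stand-in only illustrates the assumed behavior):

    import numpy as np

    TILE = 32

    def round_up_to_tile(length, tile=TILE):
        # Ceil-divide, then multiply back: 45 -> 64, 64 -> 64, 65 -> 96.
        return (length + tile - 1) // tile * tile

    def pad_batch_sketch(token, position_id, attention_mask, length_to_pad):
        # Hypothetical stand-in for the repo's pad_batch: right-pad the 1-D
        # token/position arrays with zeros and embed the (ctx, ctx) mask in
        # the top-left corner of a zeroed (length_to_pad, length_to_pad) mask.
        pad = length_to_pad - len(token)
        token = np.concatenate((token, np.zeros(pad, dtype=token.dtype)))
        position_id = np.concatenate((position_id, np.zeros(pad, dtype=position_id.dtype)))
        mask = np.zeros((length_to_pad, length_to_pad), dtype=attention_mask.dtype)
        mask[: attention_mask.shape[0], : attention_mask.shape[1]] = attention_mask
        return token, position_id, mask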

+ 110 - 21
evaluation/model.py

@@ -2,10 +2,79 @@ import torch
 
 from typing import List, Union
 
-from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
+from SwissArmyTransformer.generation.autoregressive_sampling import update_mems, get_masks_and_position_ids_default
 from SwissArmyTransformer.mpu import vocab_parallel_cross_entropy
 
 
+def batch_filling_sequence(
+        model,
+        seqs,
+        context_lengths,
+        strategy,
+        max_memory_length=100000,
+        get_masks_and_position_ids=get_masks_and_position_ids_default,
+        mems=None,
+        **kw_args
+        ):
+    '''
+        seqs: [batch_size, seq_len] token ids, e.g. [[2, 3, 5, ..., -1(to be generated), -1, ...]]
+        mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
+            cache, should be the first mems.shape[2] tokens of the context.
+            mems are first-class citizens here, but we don't assume what is memorized.
+            input mems are used for multi-phase generation.
+    '''
+    assert len(seqs.shape) == 2
+
+    # building the initial tokens, attention_mask, and position_ids
+    batch_size, context_length = seqs.shape
+    seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
+    tokens = seqs[..., :context_length]
+    if attention_mask.dtype != torch.bool:
+        attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
+    # initialize generation
+    counter = context_length - 1 # Last fixed index is ``counter''
+    index = 0 if mems is None else mems.shape[2] # Next forward starting index, also the length of cache.
+    num_beams = 1
+    # step-by-step generation
+    while counter < seqs.shape[1] - 1:
+        # Now, we want to generate seq[counter + 1],
+        # token[:, index: counter+1] needs forwarding.
+        # forward
+        tokens = tokens.reshape(batch_size * num_beams, -1)
+        mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1]) if mems is not None else None
+        logits, *output_per_layers = model(
+            tokens[:, index:],
+            position_ids[..., index: counter+1],
+            attention_mask[..., index: counter+1, :counter+1], # TODO memlen
+            mems=mems,
+            **kw_args
+        )
+        mem_kv = [o['mem_kv'] for o in output_per_layers]
+        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
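+        # First step: contexts are right-padded to a common length, so gather
+        # each sample's next-token logits at its own last real context token.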
+        if counter == context_length - 1:
+            logits = logits[torch.arange(batch_size), context_lengths - 1]
+        else:
+            logits = logits[:, -1]
+        counter += 1
+        index = counter
+        # if torch.distributed.get_rank() == 0:
+        #     print(f"counter: {counter}: logits: {logits.float().abs().mean()}")
+        # sampling
+        logits = logits.reshape(batch_size, num_beams, -1)
+        tokens = tokens.reshape(batch_size, num_beams, -1)
+        mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
+        tokens, mems = strategy.forward(logits, tokens, mems)
+        if len(tokens.shape) == 3 and num_beams == 1:
+            num_beams = tokens.shape[1]
+            position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, -1).reshape(batch_size * num_beams, -1)
+            attention_mask_shape = attention_mask.shape[-3:]
+            attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
+                batch_size * num_beams, *attention_mask_shape)
+        if strategy.is_done:
+            break
+    return strategy.finalize(tokens, mems)
+
+
 class ModelForEvaluation(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
@@ -48,45 +117,65 @@ class ModelForEvaluation(torch.nn.Module):
                 log_probs.append(log_probs_single)
         return log_probs
 
-    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[int], List[List[int]]]:
+    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[List[int]], List[List[List[int]]]]:
         """
        @return: A list of texts the model generated, sorted by score in descending order
         """
 
-        seq = torch.squeeze(sample["tokens"].to(device=torch.cuda.current_device()).long())
-        context_length = sample["context_length"].to(device=torch.cuda.current_device()).long()
-        seq[context_length:] = -1
+        seqs = sample["tokens"].to(device=torch.cuda.current_device()).long()
+        context_lengths = sample["context_length"].long()
 
         def get_masks_and_position_ids(seq):
-            tokens = seq.unsqueeze(0)
-            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device()).bool().unsqueeze(1)
-            position_ids = sample["position_ids"].to(device=torch.cuda.current_device()).long()
+            batch_size = seq.shape[0]
+            max_gen_length = sample['target_position_ids'].shape[-1]
+            tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
+            position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1)
+            position_ids = position_ids.to(device=torch.cuda.current_device()).long()
+            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device())
+            context_mask = attention_mask[torch.arange(batch_size), context_lengths - 1]
+            context_mask = context_mask.unsqueeze(1).repeat(1, max_gen_length, 1)
+            causal_mask = torch.tril(context_mask.new_ones((batch_size, max_gen_length, max_gen_length))) < 0.5
+            generation_mask = torch.cat((context_mask, causal_mask), dim=-1)
+            attention_mask = torch.nn.functional.pad(attention_mask, (0, max_gen_length), mode='constant', value=1)
+            attention_mask = torch.cat((attention_mask, generation_mask), dim=1)
+            attention_mask = attention_mask.bool().unsqueeze(1)
             return tokens, attention_mask, position_ids
 
         self.model.eval()
         with torch.no_grad():
-            output = filling_sequence(
+            output = batch_filling_sequence(
                 self.model,
-                seq,
+                seqs,
+                context_lengths,
                 get_masks_and_position_ids=get_masks_and_position_ids,
-                batch_size=strategy.num_beams if hasattr(strategy, "num_beams") else 1,
                 strategy=strategy,
             )[0]
 
         if isinstance(output, torch.Tensor):  # different strategies
-            output = list(output)
+            output = output.tolist()
 
         output_targets = []
+        context_length = seqs.shape[1]
+        for lines in output:
+            lines = lines.tolist() if isinstance(lines, torch.Tensor) else lines
+            output_target = []
+            if not isinstance(lines, list):
+                lines = [lines]
+            for line in lines:
+                unfinished = line.index(-1) if -1 in line else len(line)
+                if line[unfinished - 1] in strategy.end_tokens:
+                    unfinished -= 1
+                line = line[context_length:unfinished]
+                output_target.append(line)
+            if not return_all_beams:
+                output_targets.append(output_target[0])
+            else:
+                output_targets.append(output_target)
+        return output_targets
 
-        for line in output:
-            line = line.tolist()
-            unfinished = line.index(-1) if -1 in line else len(line)
-            if line[unfinished - 1] in strategy.end_tokens:
-                unfinished -= 1
-            line = line[context_length:unfinished]
-            output_targets.append(line)
-
-        return output_targets if return_all_beams else output_targets[0]
 
     def calculate_loss(self, batch) -> List[float]:
         tokens, position_ids, attention_mask = self.process_data(batch)
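The subtle part of the new get_masks_and_position_ids is growing a ragged-context batch mask into the generation region: each generated token reuses the visibility row of its sample's last real context token, plus a causal triangle over previously generated tokens, while context rows never see generated positions. A self-contained sketch (True means "masked out", matching the attention_mask < 0.5 convention of collate_fn):

    import torch

    def build_generation_mask(ctx_mask, context_lengths, max_gen_length):
        # ctx_mask: (batch, ctx, ctx) bool, True = masked out.
        batch_size, ctx, _ = ctx_mask.shape
        # Visibility row of each sample's last real context token, repeated
        # once per generated position.
        context_rows = ctx_mask[torch.arange(batch_size), context_lengths - 1]
        context_rows = context_rows.unsqueeze(1).repeat(1, max_gen_length, 1)
        # Generated tokens attend causally among themselves.
        causal = torch.tril(torch.ones(batch_size, max_gen_length, max_gen_length,
                                       device=ctx_mask.device)) < 0.5
        gen_rows = torch.cat((context_rows, causal), dim=-1)
        # Context rows never look at generated positions.
        pad_block = torch.ones(batch_size, ctx, max_gen_length,
                               dtype=torch.bool, device=ctx_mask.device)
        full_ctx_rows = torch.cat((ctx_mask, pad_block), dim=-1)
        return torch.cat((full_ctx_rows, gen_rows), dim=1)  # (batch, ctx+gen, ctx+gen)

As in generate_text, the result still needs unsqueeze(1) to add the head dimension before being passed to the model.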

+ 5 - 6
evaluation/tasks.py

@@ -9,10 +9,9 @@ from glob import glob
 from os.path import join, relpath
 from collections import defaultdict
 
-from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
 from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceTokenizer
 
-from generation import BeamSearchStrategy
+from generation import BaseStrategy, BeamSearchStrategy
 from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig, LanguageModelTaskConfig
 from .model import ModelForEvaluation
 from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset
@@ -171,9 +170,11 @@ class GenerationTask(BaseTask, ABC):
 
         end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
         if self.config.sampling_strategy == "BaseStrategy":
-            self.strategy = BaseStrategy(temperature=1.0, top_k=1, end_tokens=end_tokens)
+            self.strategy = BaseStrategy(batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1,
+                                         end_tokens=end_tokens)
         elif self.config.sampling_strategy == "BeamSearchStrategy":
             self.strategy = BeamSearchStrategy(
+                self.config.micro_batch_size,
                 self.config.num_beams,
                 length_penalty=self.config.length_penalty,
                 consider_end=True,
@@ -186,10 +187,8 @@ class GenerationTask(BaseTask, ABC):
             raise ValueError(f"unknown strategy {self.config.sampling_strategy}")
 
     def predict_single_batch(self, batch) -> List[List[int]]:
-        # micro batch size = 1 for generation task,
-        # but we still need to return a list of predictions for consistency
         output = self.model.generate_text(batch, self.strategy, return_all_beams=False)
-        return [output]
+        return output
 
 
 class MultiChoiceTask(BaseTask, ABC):
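Both strategies now take the micro batch size so they can track per-sample completion. A hedged construction sketch (the end-token ids are made up; real ones come from tokenizer.get_command):

    from generation import BaseStrategy, BeamSearchStrategy

    end_tokens = [20002, 20003]  # hypothetical eop/eos command ids

    strategy = BeamSearchStrategy(
        4,                       # batch_size == config.micro_batch_size
        2,                       # num_beams
        length_penalty=0.9,
        consider_end=True,
        end_tokens=end_tokens,
    )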

+ 17 - 17
generate.py

@@ -7,9 +7,8 @@ from functools import partial
 from typing import List, Tuple
 
 from SwissArmyTransformer import mpu
-from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
-from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
-from generation import BeamSearchStrategy
+from evaluation.model import batch_filling_sequence
+from generation import BeamSearchStrategy, BaseStrategy
 from SwissArmyTransformer.generation.utils import timed_name, generate_continually
 from initialize import initialize, initialize_model_and_tokenizer
 
@@ -31,16 +30,16 @@ def isEnglish(s):
         return True
 
 
-def get_masks_and_position_ids(seq, mask_position, context_length, gmask=False):
-    tokens = seq.unsqueeze(0)
-
-    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
+    context_length = seq.shape[1]
+    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
+    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
     attention_mask.tril_()
     attention_mask[..., : context_length - 1] = 1
     attention_mask.unsqueeze_(1)
     attention_mask = (attention_mask < 0.5).bool()
 
-    position_ids = torch.arange(len(seq), dtype=torch.long, device=tokens.device)
+    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
     if not gmask:
         position_ids[context_length - 1 :] = mask_position
 
@@ -99,30 +98,29 @@ def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], L
         output_list = []
 
         input_seq = torch.cuda.LongTensor(
-            seq + [tokenizer.get_command("sop")] + [-1] * (args.out_seq_length - len(seq) - 1),
+            [seq + [tokenizer.get_command("sop")]],
             device=args.device,
         )
-        output, _ = filling_sequence(
+        output, _ = batch_filling_sequence(
             model,
             input_seq,
-            batch_size=num_output,
+            torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
             strategy=strategy,
-            log_attention_weights=None,
             get_masks_and_position_ids=partial(
                 get_masks_and_position_ids,
                 mask_position=mask_position,
-                context_length=len(seq) + 1,
+                max_gen_length=args.out_seq_length - input_seq.shape[-1],
                 gmask=use_gmask,
             ),
         )
         if isinstance(output, torch.Tensor):  # different strategies
-            output = list(output)
-
+            output = output.tolist()
+        output = output[0]  # batch_size = 1
         output_list.extend(output)
 
         # clip -1s and fill back generated things into seq
         for i in range(len(output_list)):
-            output = output_list[i].tolist()
+            output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
             try:
                 unfinished = output.index(-1)
             except ValueError:
@@ -160,9 +158,11 @@ def main(args):
     end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
 
     if args.sampling_strategy == "BaseStrategy":
-        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens)
+        strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
+                                end_tokens=end_tokens)
     elif args.sampling_strategy == "BeamSearchStrategy":
         strategy = BeamSearchStrategy(
+            1,
             args.num_beams,
             length_penalty=args.length_penalty,
             consider_end=True,
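The interactive script now goes through batch_filling_sequence with an explicit batch dimension of 1; the -1 placeholders for the generation region are appended inside get_masks_and_position_ids instead of being built into the input sequence up front. A schematic of the new token layout (ids are made up):

    import torch

    seq = torch.tensor([[15, 27, 8, 901, 902]])  # context + mask + sop, fake ids
    max_gen_length = 4
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), value=-1)
    print(tokens)  # tensor([[ 15,  27,   8, 901, 902,  -1,  -1,  -1,  -1]])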

+ 1 - 1
generation/__init__.py

@@ -1 +1 @@
-from .strategies import BeamSearchStrategy
+from .strategies import BaseStrategy, BeamSearchStrategy

+ 122 - 61
generation/strategies.py

@@ -1,10 +1,56 @@
+import numpy as np
 import torch
 import torch.nn.functional as F
+from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits
+
+class BaseStrategy:
+    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
+        self.batch_size = batch_size
+        self.invalid_slices = invalid_slices
+        self.temperature = temperature
+        self.topk = top_k
+        self.top_p = top_p
+        self.eps = eps
+        if end_tokens is None:
+            end_tokens = []
+        self.end_tokens = end_tokens
+        self._is_done = np.zeros(self.batch_size, dtype=bool)
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done.all()
+
+    def forward(self, logits, tokens, mems, temperature=None):
+        logits = logits.view(-1, logits.size(-1))
+        batch_size = tokens.shape[0]
+        if temperature is None:
+            temperature = self.temperature
+        logits = logits / temperature
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -65504
+
+        logits = top_k_logits(logits, self.topk, self.top_p)
+        probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in PyTorch
+        pred = torch.multinomial(probs, num_samples=1)
+        for i in range(self.batch_size):
+            if i >= batch_size:
+                self._is_done[i] = True
+            elif self._is_done[i]:
+                pred[i] = -1
+            elif pred[i].item() in self.end_tokens:
+                self._is_done[i] = True
+        tokens = torch.cat((tokens, pred.view(tokens.shape[:-1] + (1,))), dim=-1)
+        return tokens, mems
+
+    def finalize(self, tokens, mems):
+        self._is_done = np.zeros(self.batch_size, dtype=bool)
+        return tokens, mems
 
 
 class BeamSearchStrategy:
     def __init__(
         self,
+        batch_size,
         num_beams,
         length_penalty=1.0,
         consider_end=False,
@@ -14,6 +60,7 @@ class BeamSearchStrategy:
         min_gen_length=0,
         deterministic=False,
     ):
+        self.batch_size = batch_size
         self.num_beams = num_beams
         self.length_penalty = length_penalty
         self.end_tokens = end_tokens
@@ -25,26 +72,30 @@ class BeamSearchStrategy:
         self._init_cache()
 
     def _init_cache(self):
-        self.end_beams = []  # list of LongTensors
-        self.end_beams_penalized_scores = []  # list of LongTensors
+        self.end_beams = [[] for _ in range(self.batch_size)]  # per-sample lists of finished beams (LongTensors)
+        self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)]  # per-sample lists of penalized scores
         self.cached_beam_scores = 0  # [batch_size]
-        self.cached_beam_ngram_bans = [{} for i in range(self.num_beams)]
+        self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
         self.length_generated = 0
-        self.is_done = False
+        self._is_done = np.zeros(self.batch_size, dtype=bool)
 
-    def _add_end_beams(self, score, beam):
+    def _add_end_beams(self, score, beam, batch_idx):
         score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty  # Magic number for OpenNMT
-        for i in range(len(self.end_beams), -1, -1):
-            if i == 0 or score < self.end_beams_penalized_scores[i - 1]:
+        for i in range(len(self.end_beams[batch_idx]), -1, -1):
+            if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
                 break
-        self.end_beams.insert(i, beam)
-        self.end_beams_penalized_scores.insert(i, score)
+        self.end_beams[batch_idx].insert(i, beam)
+        self.end_beams_penalized_scores[batch_idx].insert(i, score)
 
-        self.end_beams = self.end_beams[: self.num_beams]
-        self.end_beams_penalized_scores = self.end_beams_penalized_scores[: self.num_beams]
+        self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
+        self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done.all()
 
     def forward(self, logits, tokens, mems):
-        batch_size, vocab_size = logits.shape
+        batch_size, num_beams, vocab_size = logits.shape
         seq_len = tokens.shape[-1]
         logits = logits.float()
         for invalid_slice in self.invalid_slices:
@@ -53,79 +104,89 @@ class BeamSearchStrategy:
             for end_token in self.end_tokens:
                 logits[..., end_token] = -65504
         if self.ngram > 0 and seq_len > self.ngram:
-            for i in range(batch_size):
-                ngram_prefix = tokens[i, -(self.ngram - 1) :].tolist()  # TODO ngram=1
-                for banned_index in self.cached_beam_ngram_bans[i].get(tuple(ngram_prefix), []):
-                    logits[i, banned_index] = -65504
+            for batch_idx in range(batch_size):
+                for i in range(num_beams):
+                    ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1) :].tolist()  # TODO ngram=1
+                    for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
+                        logits[batch_idx, i, banned_index] = -65504
 
         next_token_scores = F.log_softmax(logits, dim=-1)  # [batch_size, vocab_size]
         prev_scores = self.cached_beam_scores
-        if isinstance(self.cached_beam_scores, torch.Tensor):
-            prev_scores = prev_scores[:, None].expand_as(next_token_scores)
+        if isinstance(prev_scores, torch.Tensor):
+            prev_scores = prev_scores[..., None].expand_as(next_token_scores)
         next_token_scores = next_token_scores + prev_scores
 
-        next_token_scores = next_token_scores.view(batch_size * vocab_size)
+        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
-        probs = F.softmax(next_token_scores, dim=0)
+        probs = F.softmax(next_token_scores, dim=-1)
+        if num_beams < self.num_beams:  # First token
+            probs = probs[..., :vocab_size]
         if self.deterministic:
-            if mems.shape[1] < batch_size:  # First token
-                probs = probs[:vocab_size]
+            next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices  # [batch_size, 2*nb]
         else:
             next_tokens = torch.multinomial(
                 probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
             )  # [batch_size, 2*nb]
-        next_token_scores = next_token_scores[next_tokens]
-        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=0)
-        next_tokens = next_tokens[_indices]
+        next_token_scores = next_token_scores[torch.arange(batch_size).unsqueeze(1), next_tokens]
+        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+        next_tokens = next_tokens[torch.arange(batch_size).unsqueeze(1), _indices]
 
         next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
         next_tokens = next_tokens % vocab_size
 
         # select out end beams or continue beams
-        if mems.shape[1] < batch_size:
-            mems = mems.expand(-1, batch_size, -1, -1)
-        beam_continue = []
-        scores_continue = []
-        bans_continue = []
-        mems_contiue = []
-        for i in range(len(next_tokens)):
-            beam = torch.cat((tokens[next_indices[i]], next_tokens[i : i + 1]))
-            if int(next_tokens[i]) in self.end_tokens:
-                self._add_end_beams(next_token_scores[i], beam)
-            elif len(beam_continue) < self.num_beams:
-                beam_continue.append(beam)
-                mems_contiue.append(mems[:, next_indices[i]])
-                # update caches
-                scores_continue.append(next_token_scores[i])
-                if self.ngram > 0:
-                    bans = self.cached_beam_ngram_bans[next_indices[i]].copy()
-                    ngram_prefix = tuple(tokens[next_indices[i], -(self.ngram - 1) :].tolist())  # TODO ngram=1
-                    bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[i],)
-                    bans_continue.append(bans)
-            else:
-                break
-        tokens = torch.stack(beam_continue)
-        mems = torch.stack(mems_contiue, dim=1)
-        self.cached_beam_scores = torch.tensor(scores_continue, device=logits.device)
-        self.cached_beam_ngram_bans = bans_continue
+        beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
+        for batch_idx in range(batch_size):
+            beam_continue = []
+            scores_continue = []
+            bans_continue = []
+            mems_continue = []
+            for i in range(len(next_tokens[batch_idx])):
+                beam = torch.cat((tokens[batch_idx, next_indices[batch_idx, i]], next_tokens[batch_idx, i : i + 1]))
+                if not self._is_done[batch_idx] and int(next_tokens[batch_idx, i]) in self.end_tokens:
+                    self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
+                elif len(beam_continue) < self.num_beams:
+                    beam_continue.append(beam)
+                    mems_continue.append(mems[:, batch_idx, next_indices[batch_idx, i]])
+                    # update caches
+                    scores_continue.append(next_token_scores[batch_idx, i])
+                    if self.ngram > 0:
+                        bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
+                        # TODO ngram=1
+                        ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
+                        bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
+                        bans_continue.append(bans)
+                else:
+                    break
+            beam_continue_batch.append(torch.stack(beam_continue))
+            mems_continue_batch.append(torch.stack(mems_continue, dim=1))
+            score_continue_batch.append(scores_continue)
+            self.cached_beam_ngram_bans[batch_idx] = bans_continue
+        tokens = torch.stack(beam_continue_batch)
+        mems = torch.stack(mems_continue_batch, dim=1)
+        self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
         self.length_generated += 1
-
-        if (
-            len(self.end_beams) == self.num_beams
-            and self.end_beams_penalized_scores[-1]
-            >= self.cached_beam_scores.max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
-        ):  # We're done if none of current tokens will better than the worst in end_beams
-            self.is_done = True
+        for batch_idx in range(self.batch_size):
+            if batch_idx >= batch_size:
+                self._is_done[batch_idx] = True
+            elif (
+                len(self.end_beams[batch_idx]) == self.num_beams
+                and self.end_beams_penalized_scores[batch_idx][-1]
+                >= self.cached_beam_scores[batch_idx].max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
+            ):  # We're done if none of the current tokens will be better than the worst in end_beams
+                self._is_done[batch_idx] = True
 
         return tokens, mems
 
     def finalize(self, tokens, mems):
         if self.consider_end:
-            for i in range(tokens.shape[0]):
-                self._add_end_beams(self.cached_beam_scores[i], tokens[i])
+            batch_size, num_beams = tokens.shape[:2]
+            for batch_idx in range(batch_size):
+                if not self._is_done[batch_idx]:
+                    for i in range(num_beams):
+                        self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
             mems = None
-            ret = self.end_beams
+            ret = self.end_beams[:batch_size]
         else:
             ret = tokens
         self._init_cache()
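A toy run of the batched BaseStrategy showing the per-sample done bookkeeping: with greedy sampling (top_k=1), the sample that emits an end token is marked done while the other keeps generating (vocabulary size and token ids are made up):

    import torch
    from generation import BaseStrategy

    strategy = BaseStrategy(batch_size=2, top_k=1, end_tokens=[99])

    vocab = 100
    logits = torch.full((2, 1, vocab), -1e4)  # (batch, num_beams=1, vocab)
    logits[0, 0, 99] = 0.0                    # sample 0 picks the end token
    logits[1, 0, 7] = 0.0                     # sample 1 picks a normal token
    tokens = torch.zeros(2, 1, 3, dtype=torch.long)

    tokens, _ = strategy.forward(logits, tokens, mems=None)
    print(tokens[:, 0, -1])  # tensor([99, 7])
    print(strategy.is_done)  # False: sample 1 has not finished yet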

+ 4 - 3
tasks/lambada/strategy.py

@@ -7,7 +7,7 @@ class BeamSearchStrategyForLAMBADA(BeamSearchStrategy):
         self.banned_prefix = banned_prefix
 
     def forward(self, logits, tokens, mems):
-        batch_size, vocab_size = logits.shape
+        batch_size, num_beams, vocab_size = logits.shape
         logits = logits.float()
         for prefix in self.banned_prefix:
             if self.length_generated == len(prefix) - 1:
@@ -15,6 +15,7 @@ class BeamSearchStrategyForLAMBADA(BeamSearchStrategy):
                     logits[..., prefix[0]] = -65504
                 else:
                     for i in range(batch_size):
-                        if tokens[i, -(len(prefix) - 1) :].tolist() == prefix[:-1]:
-                            logits[i, prefix[-1]] = -65504
+                        for j in range(num_beams):
+                            if tokens[i, j, -(len(prefix) - 1) :].tolist() == prefix[:-1]:
+                                logits[i, j, prefix[-1]] = -65504
         return super().forward(logits, tokens, mems)

+ 16 - 9
tasks/lambada/task.py

@@ -28,7 +28,8 @@ class LAMBADA(GenerationTask):
                     invalid_slices.append(pp[0])
                 banned_prefix.append(pp)
             self.strategy = BeamSearchStrategyForLAMBADA(
-                self.config.num_beams,
+                batch_size=self.config.micro_batch_size,
+                num_beams=self.config.num_beams,
                 length_penalty=self.config.length_penalty,
                 consider_end=True,
                 end_tokens=self.strategy.end_tokens,
@@ -44,11 +45,17 @@ class LAMBADA(GenerationTask):
         return self.tokenizer.tokenize(text.split(" ")[0])
 
     def predict_single_batch(self, batch):
-        # micro batch size = 1 here, but we still need to return a list of predictions for consistency
-        outputs: List[List[int]] = self.model.generate_text(batch, self.strategy, return_all_beams=True)
-        for output in outputs:
-            text = self.tokenizer.tokenizer.decode(output).strip()
-            spl = text.split(" ")
-            if len(spl) >= 2 and spl[1] in punctuation:
-                return [self.get_first_word_tokens(output)]
-        return [self.get_first_word_tokens(outputs[0])]
+        outputs_batch: List[List[List[int]]] = self.model.generate_text(batch, self.strategy, return_all_beams=True)
+        predictions = []
+        for outputs in outputs_batch:
+            found = False
+            for output in outputs:
+                text = self.tokenizer.tokenizer.decode(output).strip()
+                spl = text.split(" ")
+                if len(spl) >= 2 and spl[1] in punctuation:
+                    predictions.append(self.get_first_word_tokens(output))
+                    found = True
+                    break
+            if not found:
+                predictions.append(self.get_first_word_tokens(outputs[0]))
+        return predictions
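Schematically, generate_text(..., return_all_beams=True) now returns one beam list per sample, so for micro_batch_size=2 and num_beams=2 the structure is (token ids made up):

    outputs_batch = [
        [[5, 9, 31], [5, 12]],   # sample 0: beams sorted best-first
        [[77, 3], [77, 4, 4]],   # sample 1
    ]

For each sample, the first beam whose second decoded word is punctuation wins; otherwise the top beam's first word becomes the prediction.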