
Implement batch generation
beam search, use_task_mask and unidirectional not tested

Zhengxiao Du 3 years ago
parent
commit
bb9fbe4bfc
6 changed files with 263 additions and 86 deletions
  1. evaluation/configs.py (+2 -1)
  2. evaluation/dataset.py (+37 -7)
  3. evaluation/model.py (+95 -13)
  4. evaluation/tasks.py (+7 -5)
  5. generation/__init__.py (+1 -1)
  6. generation/strategies.py (+121 -59)

+ 2 - 1
evaluation/configs.py

@@ -50,4 +50,5 @@ class GenerationTaskConfig(BaseConfig):
     max_gen_length: int = 128
 
     def __post_init__(self):
-        assert self.micro_batch_size == 1, "Only support micro batch size = 1 for generation task"
+        pass
+        # assert self.micro_batch_size == 1, "Only support micro batch size = 1 for generation task"
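
With the assertion gone, a generation task config may request a real batch. A minimal, hypothetical sketch of what `__post_init__` now accepts (any other required `BaseConfig` fields are omitted here):

    from evaluation.configs import GenerationTaskConfig

    # Hypothetical usage: micro_batch_size > 1 no longer trips the removed
    # assert; the batched generation path below depends on this.
    config = GenerationTaskConfig(micro_batch_size=8, max_gen_length=128)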

+ 37 - 7
evaluation/dataset.py

@@ -76,6 +76,34 @@ class GenerationTaskDataset(EvaluationDataset):
             text = text[len(text) - text_length : len(text)]
         return {"text": text, "targets": targets}
 
+    @property
+    def has_collate_fn(self) -> bool:
+        return True
+
+    def collate_fn(self, samples):
+        TILE = 32
+        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE
+
+        token_batch, position_id_batch, attention_mask_batch = [], [], []
+        context_length_batch, target_position_id_batch = [], []
+
+        for sample in samples:
+            token, position_id, attention_mask = pad_batch(
+                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
+            )
+            token_batch.append(token)
+            position_id_batch.append(position_id)
+            attention_mask_batch.append(attention_mask)
+            context_length_batch.append(sample['context_length'])
+            target_position_id_batch.append(sample['target_position_id'])
+        return {
+            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
+            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
+            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
+            "context_length": torch.tensor(context_length_batch, dtype=torch.int64),
+            "target_position_ids": torch.tensor(np.array(target_position_id_batch), dtype=torch.int64),
+        }
+
     @staticmethod
     def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
         tokenizer = get_tokenizer()
@@ -98,20 +126,22 @@ class GenerationTaskDataset(EvaluationDataset):
             else:
                 token = np.concatenate((token, [mask_id, sop_id]))
         context_length = len(token)
-        max_seq_length = context_length + max_gen_length
 
-        position_id = np.arange(0, max_seq_length, dtype=np.int64)
+        position_id = np.arange(0, context_length, dtype=np.int64)
+        target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
         if not use_task_mask:
-            position_id[context_length - 1 :] = mask_position
+            position_id[context_length - 1:] = mask_position
+            target_position_id[:] = mask_position
 
-        attention_mask = np.tril(np.ones((max_seq_length, max_seq_length), dtype=np.int64))
+        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
         if not unidirectional:
             attention_mask[: context_length - 1, : context_length - 1] = 1
 
         item = {
-            "tokens": np.concatenate((token, np.zeros(max_seq_length - len(token), dtype=np.int64))),
-            "position_ids": position_id,
-            "attention_mask": attention_mask < 0.5,
+            "token": token,
+            "position_id": position_id,
+            "target_position_id": target_position_id,
+            "attention_mask": attention_mask,
             "context_length": context_length,
         }
         return item
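
The `collate_fn` above delegates per-sample padding to a `pad_batch` helper that this diff does not touch. A minimal sketch of the behaviour its outputs imply, inferred from how `collate_fn` consumes them; this is an assumption, not the repository's implementation:

    import numpy as np

    def pad_batch(token, position_id, attention_mask, length_to_pad):
        # Assumed behaviour: right-pad the 1-D token and position arrays with
        # zeros and embed the square context mask into a zero-filled
        # (length_to_pad x length_to_pad) matrix. Zero padding keeps padded
        # positions masked once the caller applies `< 0.5`.
        pad_len = length_to_pad - len(token)
        token = np.concatenate((token, np.zeros(pad_len, dtype=np.int64)))
        position_id = np.concatenate((position_id, np.zeros(pad_len, dtype=np.int64)))
        padded_mask = np.zeros((length_to_pad, length_to_pad), dtype=np.int64)
        padded_mask[: attention_mask.shape[0], : attention_mask.shape[1]] = attention_mask
        return token, position_id, padded_mask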

+ 95 - 13
evaluation/model.py

@@ -2,7 +2,78 @@ import torch
 
 from typing import List, Union
 
-from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
+from SwissArmyTransformer.generation.autoregressive_sampling import update_mems, get_masks_and_position_ids_default
+
+
+def batch_filling_sequence(
+        model,
+        seqs,
+        context_lengths,
+        strategy,
+        max_memory_length=100000,
+        get_masks_and_position_ids=get_masks_and_position_ids_default,
+        mems=None,
+        **kw_args
+        ):
+    '''
+        seqs: [batch_size, seq_len], e.g. [[2, 3, 5, ..., -1 (to be generated), -1, ...]]
+        mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
+            cache, should be first mems.shape[1] parts of context_tokens.
+            mems are first-class citizens here, but we don't assume what is memorized.
+            input mems are used for multi-phase generation.
+    '''
+    assert len(seqs.shape) == 2
+
+    # building the initial tokens, attention_mask, and position_ids
+    batch_size, context_length = seqs.shape
+    seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
+    tokens = seqs[..., :context_length]
+    if attention_mask.dtype != torch.bool:
+        attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
+    # initialize generation
+    counter = context_length - 1 # Last fixed index is ``counter''
+    index = 0 if mems is None else mems.shape[2] # Next forward starting index, also the length of cache.
+    num_beams = 1
+    # step-by-step generation
+    while counter < seqs.shape[1] - 1:
+        # Now, we want to generate seq[counter + 1],
+        # token[:, index: counter+1] needs forwarding.
+        # forward
+        if num_beams > 1:
+            tokens = tokens.reshape(batch_size * num_beams, -1)
+            mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1])
+        logits, *output_per_layers = model(
+            tokens[:, index:],
+            position_ids[..., index: counter+1],
+            attention_mask[..., index: counter+1, :counter+1], # TODO memlen
+            mems=mems,
+            **kw_args
+        )
+        mem_kv = [o['mem_kv'] for o in output_per_layers]
+        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
+        if counter == context_length - 1:
+            logits = logits[torch.arange(batch_size), context_lengths - 1]
+        else:
+            logits = logits[:, -1]
+        # if torch.distributed.get_rank() == 0:
+        #     breakpoint()
+        # torch.distributed.barrier()
+        counter += 1
+        index = counter
+        # sampling
+        if num_beams > 1:
+            logits = logits.reshape(batch_size, num_beams, -1)
+            tokens = tokens.reshape(batch_size, num_beams, -1)
+            mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
+        tokens, mems = strategy.forward(logits, tokens, mems)
+        if len(tokens.shape) == 3 and num_beams == 1:
+            num_beams = tokens.shape[1]
+            position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, -1).reshape(batch_size * num_beams, -1)
+            attention_mask_shape = attention_mask.shape[-3:]
+            attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
+                batch_size * num_beams, *attention_mask_shape)  # torch.reshape allows only one -1
+        if strategy.is_done:
+            break
+    return strategy.finalize(tokens, mems)
 
 
 class ModelForEvaluation(torch.nn.Module):
@@ -47,28 +118,39 @@ class ModelForEvaluation(torch.nn.Module):
                 log_probs.append(log_probs_single)
         return log_probs
 
-    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[int], List[List[int]]]:
+    def generate_text(self, sample, strategy, return_all_beams=False, max_gen_length=128) -> Union[
+        List[int], List[List[int]]]:
         """
         @return: A list of text model generated, sorted by score in descending order
         """
 
-        seq = torch.squeeze(sample["tokens"].to(device=torch.cuda.current_device()).long())
-        context_length = sample["context_length"].to(device=torch.cuda.current_device()).long()
-        seq[context_length:] = -1
+        seqs = sample["tokens"].to(device=torch.cuda.current_device()).long()
+        context_lengths = sample["context_length"].long()
 
         def get_masks_and_position_ids(seq):
-            tokens = seq.unsqueeze(0)
-            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device()).bool().unsqueeze(1)
-            position_ids = sample["position_ids"].to(device=torch.cuda.current_device()).long()
+            batch_size = seq.shape[0]
+            tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
+            position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1)
+            position_ids = position_ids.to(device=torch.cuda.current_device()).long()
+            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device())
+            context_mask = attention_mask[torch.arange(batch_size), context_lengths - 1]
+            context_mask = context_mask.unsqueeze(1).repeat(1, max_gen_length, 1)
+            causal_mask = torch.tril(context_mask.new_ones((batch_size, max_gen_length, max_gen_length))) < 0.5
+            generation_mask = torch.cat((context_mask, causal_mask), dim=-1)
+            attention_mask = torch.nn.functional.pad(attention_mask, (0, max_gen_length), mode='constant', value=1)
+            attention_mask = torch.cat((attention_mask, generation_mask), dim=1)
+            attention_mask = attention_mask.bool().unsqueeze(1)
             return tokens, attention_mask, position_ids
 
         self.model.eval()
         with torch.no_grad():
-            output = filling_sequence(
+            output = batch_filling_sequence(
                 self.model,
-                seq,
+                seqs,
+                context_lengths,
                 get_masks_and_position_ids=get_masks_and_position_ids,
-                batch_size=strategy.num_beams if hasattr(strategy, "num_beams") else 1,
                 strategy=strategy,
             )[0]
 
@@ -76,7 +158,7 @@ class ModelForEvaluation(torch.nn.Module):
             output = list(output)
 
         output_targets = []
-
+        context_length = seqs.shape[1]
         for line in output:
             line = line.tolist()
             unfinished = line.index(-1) if -1 in line else len(line)
@@ -85,4 +167,4 @@ class ModelForEvaluation(torch.nn.Module):
             line = line[context_length:unfinished]
             output_targets.append(line)
 
-        return output_targets if return_all_beams else output_targets[0]
+        return output_targets
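
The subtlest part of the new `generate_text` is the attention mask it assembles for the generated region. A toy-sized, self-contained illustration of the block layout (shown here with 1 = visible; the code above builds the same structure in inverted boolean form, with extra batch and head dimensions):

    import torch

    # Bidirectional context (unidirectional=False), context_length=3,
    # max_gen_length=2.
    context_length, max_gen_length = 3, 2
    context = torch.tril(torch.ones(context_length, context_length))
    context[: context_length - 1, : context_length - 1] = 1  # bidirectional prefix
    # Generated rows see the last context token's full row...
    context_rows = context[-1].unsqueeze(0).repeat(max_gen_length, 1)
    # ...plus a causal triangle over the generated positions themselves.
    causal = torch.tril(torch.ones(max_gen_length, max_gen_length))
    bottom = torch.cat((context_rows, causal), dim=-1)
    # Context rows never attend to not-yet-generated positions.
    top = torch.nn.functional.pad(context, (0, max_gen_length), value=0)
    print(torch.cat((top, bottom), dim=0))
    # tensor([[1., 1., 0., 0., 0.],
    #         [1., 1., 0., 0., 0.],
    #         [1., 1., 1., 0., 0.],
    #         [1., 1., 1., 1., 0.],
    #         [1., 1., 1., 1., 1.]])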

+ 7 - 5
evaluation/tasks.py

@@ -9,10 +9,9 @@ from glob import glob
 from os.path import join, relpath
 from collections import defaultdict
 
-from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
 from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceTokenizer
 
-from generation import BeamSearchStrategy
+from generation import BaseStrategy, BeamSearchStrategy
 from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig
 from .model import ModelForEvaluation
 from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset
@@ -171,9 +170,11 @@ class GenerationTask(BaseTask, ABC):
 
         end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
         if self.config.sampling_strategy == "BaseStrategy":
-            self.strategy = BaseStrategy(temperature=1.0, top_k=1, end_tokens=end_tokens)
+            self.strategy = BaseStrategy(batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1,
+                                         end_tokens=end_tokens)
         elif self.config.sampling_strategy == "BeamSearchStrategy":
             self.strategy = BeamSearchStrategy(
+                self.config.micro_batch_size,
                 self.config.num_beams,
                 length_penalty=self.config.length_penalty,
                 consider_end=True,
@@ -188,8 +189,9 @@ class GenerationTask(BaseTask, ABC):
     def predict_single_batch(self, batch) -> List[List[int]]:
-        # micro batch size = 1 for generation task,
-        # but we still need to return a list of predictions for consistency
+        # generation now runs with micro batch size >= 1;
+        # generate_text returns one prediction per sample in the batch
-        output = self.model.generate_text(batch, self.strategy, return_all_beams=False)
-        return [output]
+        output = self.model.generate_text(batch, self.strategy, return_all_beams=False,
+                                          max_gen_length=self.config.max_gen_length)
+        return output
 
 
 class MultiChoiceTask(BaseTask, ABC):

+ 1 - 1
generation/__init__.py

@@ -1 +1 @@
-from .strategies import BeamSearchStrategy
+from .strategies import BaseStrategy, BeamSearchStrategy
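
With both strategies exported, the task code above constructs them batch-first. A hedged sketch of the two call sites (the ids in `end_tokens` are invented placeholders for the tokenizer's `eop`/`eos` commands):

    from generation import BaseStrategy, BeamSearchStrategy

    end_tokens = [20002, 20001]  # placeholder ids; the task asks the tokenizer

    # Greedy/top-k decoding, mirroring GenerationTask.__init__ above.
    greedy = BaseStrategy(batch_size=4, temperature=1.0, top_k=1,
                          end_tokens=end_tokens)

    # Beam search, with batch_size as the new first positional argument.
    beam = BeamSearchStrategy(
        4,  # batch_size
        num_beams=4,
        length_penalty=1.0,
        consider_end=True,
        end_tokens=end_tokens,
    )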

+ 121 - 59
generation/strategies.py

@@ -1,10 +1,56 @@
+import numpy as np
 import torch
 import torch.nn.functional as F
+from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits
+
+class BaseStrategy:
+    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
+        self.batch_size = batch_size
+        self.invalid_slices = invalid_slices
+        self.temperature = temperature
+        self.topk = top_k
+        self.top_p = top_p
+        self.eps = eps
+        if end_tokens is None:
+            end_tokens = []
+        self.end_tokens = end_tokens
+        self._is_done = np.zeros(self.batch_size, dtype=bool)
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done.all()
+
+    def forward(self, logits, tokens, mems, temperature=None):
+        batch_size = tokens.shape[0]
+        if temperature is None:
+            temperature = self.temperature
+        logits = logits / temperature
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -65504
+
+        # logits = top_k_logits(logits, self.topk, self.top_p)
+        # probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in PyTorch
+        # pred = torch.multinomial(probs, num_samples=1)
+        pred = torch.argmax(logits, dim=-1)
+        for i in range(self.batch_size):
+            if i >= batch_size:
+                self._is_done[i] = True
+            elif self._is_done[i]:
+                pred[i] = -1
+            elif pred[i].item() in self.end_tokens:
+                self._is_done[i] = True
+        tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
+        return tokens, mems
+
+    def finalize(self, tokens, mems):
+        self._is_done = np.zeros(self.batch_size, dtype=bool)
+        return tokens, mems
 
 
 class BeamSearchStrategy:
     def __init__(
         self,
+        batch_size,
         num_beams,
         length_penalty=1.0,
         consider_end=False,
@@ -14,6 +60,7 @@ class BeamSearchStrategy:
         min_gen_length=0,
         deterministic=False,
     ):
+        self.batch_size = batch_size
         self.num_beams = num_beams
         self.length_penalty = length_penalty
         self.end_tokens = end_tokens
@@ -25,26 +72,30 @@ class BeamSearchStrategy:
         self._init_cache()
 
     def _init_cache(self):
-        self.end_beams = []  # list of LongTensors
-        self.end_beams_penalized_scores = []  # list of LongTensors
+        self.end_beams = [[] for _ in range(self.batch_size)]  # per batch index: list of LongTensors
+        self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)]  # per batch index: list of scores
         self.cached_beam_scores = 0  # [batch_size]
-        self.cached_beam_ngram_bans = [{} for i in range(self.num_beams)]
+        self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
         self.length_generated = 0
-        self.is_done = False
+        self._is_done = np.zeros(self.batch_size, dtype=bool)
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done.all()
 
-    def _add_end_beams(self, score, beam):
+    def _add_end_beams(self, score, beam, batch_idx):
         score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty  # Magic number for OpenNMT
-        for i in range(len(self.end_beams), -1, -1):
-            if i == 0 or score < self.end_beams_penalized_scores[i - 1]:
+        for i in range(len(self.end_beams[batch_idx]), -1, -1):
+            if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
                 break
-        self.end_beams.insert(i, beam)
-        self.end_beams_penalized_scores.insert(i, score)
+        self.end_beams[batch_idx].insert(i, beam)
+        self.end_beams_penalized_scores[batch_idx].insert(i, score)
 
-        self.end_beams = self.end_beams[: self.num_beams]
-        self.end_beams_penalized_scores = self.end_beams_penalized_scores[: self.num_beams]
+        self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
+        self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]
 
     def forward(self, logits, tokens, mems):
-        batch_size, vocab_size = logits.shape
+        if len(logits.shape) == 2:
+            logits = logits.unsqueeze(1)
+            tokens = tokens.unsqueeze(1)
+            mems = mems.unsqueeze(2)
+        batch_size, num_beams, vocab_size = logits.shape
         seq_len = tokens.shape[-1]
         logits = logits.float()
         for invalid_slice in self.invalid_slices:
@@ -53,77 +104,88 @@ class BeamSearchStrategy:
             for end_token in self.end_tokens:
                 logits[..., end_token] = -65504
         if self.ngram > 0 and seq_len > self.ngram:
-            for i in range(batch_size):
-                ngram_prefix = tokens[i, -(self.ngram - 1) :].tolist()  # TODO ngram=1
-                for banned_index in self.cached_beam_ngram_bans[i].get(tuple(ngram_prefix), []):
-                    logits[i, banned_index] = -65504
+            for batch_idx in range(batch_size):
+                for i in range(num_beams):
+                    ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1) :].tolist()  # TODO ngram=1
+                    for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
+                        logits[batch_idx, i, banned_index] = -65504
 
-        next_token_scores = F.log_softmax(logits, dim=-1)  # [batch_size, vocab_size]
+        next_token_scores = F.log_softmax(logits, dim=-1)  # [batch_size, num_beams, vocab_size]
         prev_scores = self.cached_beam_scores
-        if isinstance(self.cached_beam_scores, torch.Tensor):
-            prev_scores = prev_scores[:, None].expand_as(next_token_scores)
+        if isinstance(prev_scores, torch.Tensor):
+            prev_scores = prev_scores[..., None].expand_as(next_token_scores)
         next_token_scores = next_token_scores + prev_scores
 
-        next_token_scores = next_token_scores.view(batch_size * vocab_size)
+        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
-        probs = F.softmax(next_token_scores, dim=0)
+        probs = F.softmax(next_token_scores, dim=-1)
         if self.deterministic:
-            if mems.shape[1] < batch_size:  # First token
-                probs = probs[:vocab_size]
+            if mems.shape[2] < self.num_beams:  # First token
+                probs = probs[..., :vocab_size]
             next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices  # [2*nb]
         else:
             next_tokens = torch.multinomial(
                 probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
             )  # [2*nb]
-        next_token_scores = next_token_scores[next_tokens]
-        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=0)
-        next_tokens = next_tokens[_indices]
+        next_token_scores = next_token_scores[torch.arange(batch_size).unsqueeze(1), next_tokens]
+        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+        next_tokens = next_tokens[torch.arange(batch_size).unsqueeze(1), _indices]
 
         next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
         next_tokens = next_tokens % vocab_size
 
         # select out end beams or continue beams
-        if mems.shape[1] < batch_size:
-            mems = mems.expand(-1, batch_size, -1, -1)
-        beam_continue = []
-        scores_continue = []
-        bans_continue = []
-        mems_contiue = []
-        for i in range(len(next_tokens)):
-            beam = torch.cat((tokens[next_indices[i]], next_tokens[i : i + 1]))
-            if int(next_tokens[i]) in self.end_tokens:
-                self._add_end_beams(next_token_scores[i], beam)
-            elif len(beam_continue) < self.num_beams:
-                beam_continue.append(beam)
-                mems_contiue.append(mems[:, next_indices[i]])
-                # update caches
-                scores_continue.append(next_token_scores[i])
-                if self.ngram > 0:
-                    bans = self.cached_beam_ngram_bans[next_indices[i]].copy()
-                    ngram_prefix = tuple(tokens[next_indices[i], -(self.ngram - 1) :].tolist())  # TODO ngram=1
-                    bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[i],)
-                    bans_continue.append(bans)
-            else:
-                break
-        tokens = torch.stack(beam_continue)
-        mems = torch.stack(mems_contiue, dim=1)
-        self.cached_beam_scores = torch.tensor(scores_continue, device=logits.device)
-        self.cached_beam_ngram_bans = bans_continue
+        if mems.shape[2] < self.num_beams:
+            mems = mems.expand(-1, batch_size, self.num_beams, -1, -1)
+        beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
+        for batch_idx in range(batch_size):
+            beam_continue = []
+            scores_continue = []
+            bans_continue = []
+            mems_continue = []
+            for i in range(len(next_tokens[batch_idx])):
+                beam = torch.cat((tokens[batch_idx, next_indices[batch_idx, i]], next_tokens[batch_idx, i : i + 1]))
+                if not self._is_done[batch_idx] and int(next_tokens[batch_idx, i]) in self.end_tokens:
+                    self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
+                elif len(beam_continue) < self.num_beams:
+                    beam_continue.append(beam)
+                    mems_continue.append(mems[:, batch_idx, next_indices[batch_idx, i]])
+                    # update caches
+                    scores_continue.append(next_token_scores[batch_idx, i])
+                    if self.ngram > 0:
+                        bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
+                        # TODO ngram=1
+                        ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
+                        bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
+                        bans_continue.append(bans)
+                else:
+                    break
+            beam_continue_batch.append(torch.stack(beam_continue))
+            mems_continue_batch.append(torch.stack(mems_continue, dim=1))
+            score_continue_batch.append(scores_continue)
+            self.cached_beam_ngram_bans[batch_idx] = bans_continue
+        tokens = torch.stack(beam_continue_batch)
+        mems = torch.stack(mems_continue_batch)
+        self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
         self.length_generated += 1
-
-        if (
-            len(self.end_beams) == self.num_beams
-            and self.end_beams_penalized_scores[-1]
-            >= self.cached_beam_scores.max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
-        ):  # We're done if none of current tokens will better than the worst in end_beams
-            self.is_done = True
+        for batch_idx in range(self.batch_size):
+            if batch_idx >= batch_size:
+                self._is_done[batch_idx] = True
+            elif (
+                len(self.end_beams[batch_idx]) == self.num_beams
+                and self.end_beams_penalized_scores[batch_idx][-1]
+                >= self.cached_beam_scores[batch_idx].max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
+            ):  # We're done if none of the current tokens will be better than the worst in end_beams
+                self._is_done[batch_idx] = True
 
         return tokens, mems
 
     def finalize(self, tokens, mems):
         if self.consider_end:
-            for i in range(tokens.shape[0]):
-                self._add_end_beams(self.cached_beam_scores[i], tokens[i])
+            for batch_idx in range(tokens.shape[0]):
+                if not self._is_done[batch_idx]:
+                    for i in range(tokens.shape[1]):
+                        self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
             mems = None
             ret = self.end_beams
         else:
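
To make the contract between `batch_filling_sequence` and a strategy concrete, here is a minimal self-contained driver sketch using `BaseStrategy`, with random logits standing in for the model; vocabulary size, shapes, and the end token are invented:

    import torch
    from generation import BaseStrategy

    vocab_size, batch_size, end_token = 10, 2, 3
    strategy = BaseStrategy(batch_size=batch_size, top_k=1,
                            end_tokens=[end_token])

    tokens = torch.tensor([[5, 6], [7, 8]])  # [batch, seq]: context so far
    mems = None
    for _ in range(32):  # cap the generated length
        logits = torch.randn(batch_size, vocab_size)  # stand-in for the model
        tokens, mems = strategy.forward(logits, tokens, mems)
        if strategy.is_done:  # every sample has emitted an end token
            break
    tokens, mems = strategy.finalize(tokens, mems)
    print(tokens)  # generated ids; -1 fills steps after a sample finished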