import numpy as np
import torch
from typing import List, Union

from scipy.linalg import block_diag
from SwissArmyTransformer.generation.autoregressive_sampling import update_mems, get_masks_and_position_ids_default
from SwissArmyTransformer.mpu import vocab_parallel_cross_entropy
from SwissArmyTransformer import get_tokenizer


def batch_filling_sequence(
    model,
    seqs,
    context_lengths,
    strategy,
    max_memory_length=100000,
    get_masks_and_position_ids=get_masks_and_position_ids_default,
    mems=None,
    **kw_args
):
- """
- seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
- mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
- cache, should be first mems.shape[1] parts of context_tokens.
- mems are the first-level citizens here, but we don't assume what is memorized.
- input mems are used when multi-phase generation.
- """
    assert len(seqs.shape) == 2

    # building the initial tokens, attention_mask, and position_ids
    batch_size, context_length = seqs.shape
    seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
    tokens = seqs[..., :context_length]
    if attention_mask.dtype != torch.bool:
        attention_mask = attention_mask.type_as(next(model.parameters()))  # if fp16

    # initialize generation
    counter = context_length - 1  # Last fixed index is ``counter''
    index = 0 if mems is None else mems.shape[2]  # Next forward starting index, also the length of cache.
    num_beams = 1

    # step-by-step generation
    while counter < seqs.shape[1] - 1:
        # Now, we want to generate seq[counter + 1],
        # token[:, index: counter+1] needs forwarding.
        # forward
        tokens = tokens.reshape(batch_size * num_beams, -1)
        mems = (
            mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1])
            if mems is not None
            else None
        )
        logits, *output_per_layers = model(
            tokens[:, index:],
            position_ids[..., index : counter + 1],
            attention_mask[..., index : counter + 1, : counter + 1],  # TODO memlen
            mems=mems,
            **kw_args
        )
        mem_kv = [o["mem_kv"] for o in output_per_layers]
        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
        if counter == context_length - 1:
            logits = logits[torch.arange(batch_size), context_lengths - 1]
        else:
            logits = logits[:, -1]
        counter += 1
        index = counter
        # if torch.distributed.get_rank() == 0:
        #     print(f"counter: {counter}: logits: {logits.float().abs().mean()}")

        # sampling
        logits = logits.reshape(batch_size, num_beams, -1)
        tokens = tokens.reshape(batch_size, num_beams, -1)
        mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
        tokens, mems = strategy.forward(logits, tokens, mems)
        if len(tokens.shape) == 3 and num_beams == 1:
            num_beams = tokens.shape[1]
            position_ids = (
                position_ids.unsqueeze(1)
                .expand((batch_size, num_beams) + position_ids.shape[1:])
                .reshape((batch_size * num_beams,) + position_ids.shape[1:])
            )
            attention_mask_shape = attention_mask.shape[-3:]
            attention_mask = (
                attention_mask.unsqueeze(1)
                .expand(batch_size, num_beams, -1, -1, -1)
                .reshape(batch_size * num_beams, *attention_mask_shape)
            )
        if strategy.is_done:
            break
    return strategy.finalize(tokens, mems)
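

# A minimal sketch of the strategy interface that `batch_filling_sequence` relies on: the loop above
# only calls `strategy.forward(logits, tokens, mems)`, reads `strategy.is_done`, and finally calls
# `strategy.finalize(tokens, mems)`. The greedy class below is a hypothetical stand-in for
# illustration only; real runs pass the sampling / beam-search strategies shipped with the repo.
class _GreedyStrategySketch:
    def __init__(self, end_tokens=()):
        self.end_tokens = list(end_tokens)
        self.is_done = False

    def forward(self, logits, tokens, mems):
        # logits: [batch_size, num_beams, vocab]; append the argmax token to every beam.
        next_token = logits.argmax(dim=-1, keepdim=True)  # [batch_size, num_beams, 1]
        if self.end_tokens and all(int(t) in self.end_tokens for t in next_token.reshape(-1)):
            self.is_done = True
        return torch.cat((tokens, next_token), dim=-1), mems

    def finalize(self, tokens, mems):
        return tokens, mems
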

class ModelForEvaluation(torch.nn.Module):
    def __init__(self, model, position_encoding_2d):
        super().__init__()

        self.model = model
        self.position_encoding_2d = position_encoding_2d
        self.device = next(self.model.parameters()).device

    @staticmethod
    def process_data(batch, device):
        return (
            batch["tokens"].to(device=device).long(),
            batch["position_ids"].to(device=device).long(),
            batch["attention_mask"].to(device=device).bool().unsqueeze(1),
        )

    def build_multiple_choice_sample(
        self,
        text,
        choices,
        is_single_token,
        unified_multitask_encoding=False,
        unidirectional=False,
        use_task_mask=False,
    ):
        tokenizer = get_tokenizer()

        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        target = np.array(text, dtype=np.int64)
        position_id = np.arange(len(text), dtype=np.int64)
        block_position_id = np.zeros(len(text), dtype=np.int64)
        choice_target_id = []

        blank_filling = mask_id in text
        if not blank_filling:
            if unidirectional:
                assert use_task_mask, "Unidirectional attention only support gMASK"
                token = np.concatenate(([mask_id, sop_id], token[:-1]))
                target = np.concatenate(([mask_id, sop_id], target[:-1]))
                position_id = np.zeros(len(token), dtype=np.int64)
                if self.position_encoding_2d:
                    block_position_id = np.arange(len(token), dtype=np.int64)
                mask_position = len(token)
            else:
                mask_position = len(token)
                token = np.concatenate((token, [mask_id]))
                target = np.concatenate((target, [mask_id]))
                position_id = np.arange(len(token), dtype=np.int64)
                if self.position_encoding_2d:
                    block_position_id = np.zeros(len(token), dtype=np.int64)
        else:
            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
            assert not use_task_mask, "Blank filling only support MASK"
            mask_position = text.index(mask_id)

        division = len(token)
        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
        if unidirectional:
            attention_mask[0] = np.tril(attention_mask[0])

        for choice in choices:
            if not choice:
                choice = [tokenizer.get_command("eop")]

            target = np.concatenate((target, choice))
            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))

            if unidirectional:
                if self.position_encoding_2d:
                    position_id = np.concatenate((position_id, [0] * len(choice)))
                    block_position_id = np.concatenate(
                        (block_position_id, np.arange(mask_position, mask_position + len(choice), dtype=np.int64))
                    )
                else:
                    position_id = np.concatenate(
                        (
                            position_id,
                            np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
                        )
                    )
                token = np.concatenate((token, [text[-1]], choice[:-1]))
            else:
                if self.position_encoding_2d:
                    position_id = np.concatenate((position_id, [mask_position] * len(choice)))
                    block_position_id = np.concatenate(
                        (block_position_id, np.arange(1, 1 + len(choice), dtype=np.int64))
                    )
                else:
                    position_id = np.concatenate(
                        (
                            position_id,
                            [mask_position] * len(choice)
                            if (blank_filling or not unified_multitask_encoding) and not use_task_mask
                            else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
                        )
                    )
                token = np.concatenate((token, [sop_id], choice[:-1]))

            if is_single_token:
                break

        attention_mask = block_diag(*attention_mask)
        attention_mask[division:, :division] = 1

        if is_single_token:
            choices = np.array(choices, dtype=np.int64).squeeze().tolist()

        if self.position_encoding_2d:
            position_id = np.stack((position_id, block_position_id), axis=0)

        item = {
            "token": token,
            "position_id": position_id,
            "attention_mask": attention_mask,
            "choices": choices,
            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
        }
        return item
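
    @staticmethod
    def _example_choice_attention_mask():
        """A toy sketch (not used by evaluation) of the mask layout built above: the shared context
        attends bidirectionally to itself, each choice continuation gets its own causal block via
        scipy's block_diag, and every choice position can additionally see the whole context
        (the lower-left block set to 1). The sizes below are made up for illustration."""
        context_len, choice_lens = 4, [2, 3]
        blocks = [np.ones((context_len, context_len), dtype=np.int64)]
        blocks += [np.tril(np.ones((n, n), dtype=np.int64)) for n in choice_lens]
        mask = block_diag(*blocks)
        mask[context_len:, :context_len] = 1  # choices see the full context, but not each other
        return mask  # shape [9, 9]: 4 context rows + 2 + 3 choice rows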

    def cond_log_prob(self, batch) -> List[List[float]]:
        """
        @return: Conditional log probability of each option
        """
        tokens, position_ids, attention_mask = self.process_data(batch, self.device)
        choices_batch, choice_target_ids_batch = batch["choices"], batch["choice_target_ids"]
        is_single_token = batch["is_single_token"]

        self.model.eval()
        with torch.no_grad():
            logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None)
            logits_batch = torch.nn.functional.log_softmax(logits, dim=-1)

        # output: [b, sq, vocab]
        log_probs = []

        # if torch.distributed.get_rank() == 0:
        #     import pdb
        #
        #     pdb.set_trace()
        # torch.distributed.barrier()
        if is_single_token:  # Single token
            for logits, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
                log_probs.append(logits[choice_target_ids[0], choices].tolist())
        else:  # Multi token
            for output, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
                log_probs_single = []
                for choice, choice_target_id in zip(choices, choice_target_ids):
                    tmp = output[choice_target_id, choice]
                    log_probs_single.append(tmp.sum().tolist())
                log_probs.append(log_probs_single)
        return log_probs
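
    @staticmethod
    def _example_choice_scoring():
        """A toy sketch of the scoring rule used in `cond_log_prob` above: a multi-token choice is
        scored by summing the log-probabilities of its tokens at the positions recorded in
        `choice_target_ids`. The vocabulary size and values below are made up for illustration."""
        log_probs = torch.nn.functional.log_softmax(torch.randn(6, 10), dim=-1)  # [seq_len, vocab]
        choice = [3, 7]            # token ids of one choice
        choice_target_id = [4, 5]  # positions whose logits predict those tokens
        score = log_probs[choice_target_id, choice].sum().item()
        return score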

    def build_generation_sample(self, text, max_gen_length, use_task_mask, unidirectional):
        tokenizer = get_tokenizer()

        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")

        token = np.array(text, dtype=np.int64)
        position_id = np.arange(len(text), dtype=np.int64)
        block_position_id = np.zeros(len(text), dtype=np.int64)
        target_position_id = np.zeros(len(text), dtype=np.int64)
        target_block_position_id = np.zeros(len(text), dtype=np.int64)

        blank_filling = mask_id in text
        if unidirectional:
            assert use_task_mask, "Unidirectional attention only support gMASK"
            assert not blank_filling, "Unidirectional attention doesn't support blank filling"
            token = np.concatenate(([mask_id, sop_id], token))
            if self.position_encoding_2d:
                position_id = np.zeros(len(token), dtype=np.int64)
                target_position_id = np.zeros(max_gen_length, dtype=np.int64)
                block_position_id = np.arange(len(token), dtype=np.int64)
                target_block_position_id = np.arange(len(token), len(token) + max_gen_length, dtype=np.int64)
            else:
                position_id = np.arange(len(token), dtype=np.int64)
                target_position_id = np.arange(len(token), len(token) + max_gen_length, dtype=np.int64)
        else:
            if not blank_filling:
                mask_position = len(token)
                token = np.concatenate((token, [mask_id, sop_id]))
            else:
                assert not use_task_mask, "Blank filling only support MASK"
                mask_position = text.index(mask_id)
                token = np.concatenate((token, [sop_id]))
            position_id = np.concatenate((np.arange(len(token) - 1, dtype=np.int64), [mask_position]))
            target_position_id = np.full(max_gen_length, mask_position, dtype=np.int64)
            if self.position_encoding_2d:
                block_position_id = np.zeros(len(token), dtype=np.int64)
                target_block_position_id = np.arange(1, max_gen_length + 1, dtype=np.int64)
            elif use_task_mask:
                position_id = np.arange(len(token), dtype=np.int64)
                target_position_id = np.arange(len(token), len(token) + max_gen_length, dtype=np.int64)

        context_length = len(token)
        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
        if not unidirectional:
            attention_mask[: context_length - 1, : context_length - 1] = 1

        if self.position_encoding_2d:
            position_id = np.stack((position_id, block_position_id), axis=0)
            target_position_id = np.stack((target_position_id, target_block_position_id), axis=0)

        item = {
            "token": token,
            "position_id": position_id,
            "target_position_id": target_position_id,
            "attention_mask": attention_mask,
            "context_length": context_length,
        }
        return item
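
    @staticmethod
    def _example_2d_position_ids():
        """A toy sketch of the 2D positional encoding assembled above when `position_encoding_2d`
        is on: the first row pins the generated span to the mask position, the second row counts
        1, 2, 3, ... inside the generated block. The lengths below are made up for illustration."""
        context_len, max_gen_length, mask_position = 5, 3, 2
        position_id = np.concatenate((np.arange(context_len, dtype=np.int64), [mask_position]))
        target_position_id = np.full(max_gen_length, mask_position, dtype=np.int64)
        block_position_id = np.zeros(context_len + 1, dtype=np.int64)
        target_block_position_id = np.arange(1, max_gen_length + 1, dtype=np.int64)
        position_ids = np.stack(
            (np.concatenate((position_id, target_position_id)),
             np.concatenate((block_position_id, target_block_position_id))),
            axis=0,
        )
        return position_ids  # shape [2, context_len + 1 + max_gen_length]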

    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[List[int]], List[List[List[int]]]]:
        """
        @return: A list of generated token sequences per sample (all beams if `return_all_beams`), sorted by score in descending order
        """
        seqs = sample["tokens"].to(device=self.device).long()
        context_lengths = sample["context_length"].long()

        def get_masks_and_position_ids(seq):
            batch_size = seq.shape[0]
            max_gen_length = sample["target_position_ids"].shape[-1]
            tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)
            position_ids = torch.cat((sample["position_ids"], sample["target_position_ids"]), dim=-1)
            position_ids = position_ids.to(device=self.device).long()
            attention_mask = sample["attention_mask"].to(device=self.device)
            context_mask = (
                attention_mask[torch.arange(batch_size), context_lengths - 1].unsqueeze(1).repeat(1, max_gen_length, 1)
            )
            causal_mask = torch.tril(context_mask.new_ones((batch_size, max_gen_length, max_gen_length))) < 0.5
            generation_mask = torch.cat((context_mask, causal_mask), dim=-1)
            attention_mask = torch.nn.functional.pad(attention_mask, (0, max_gen_length), mode="constant", value=1)
            attention_mask = torch.cat((attention_mask, generation_mask), dim=1)
            attention_mask = attention_mask.bool().unsqueeze(1)
            return tokens, attention_mask, position_ids

        self.model.eval()
        with torch.no_grad():
            output = batch_filling_sequence(
                self.model,
                seqs,
                context_lengths,
                get_masks_and_position_ids=get_masks_and_position_ids,
                strategy=strategy,
            )[0]

        if isinstance(output, torch.Tensor):  # different strategies
            output = output.tolist()

        output_targets = []
        context_length = seqs.shape[1]
        for lines in output:
            lines = lines.tolist() if isinstance(lines, torch.Tensor) else lines
            output_target = []
            if not isinstance(lines, list):
                lines = [lines]
            for line in lines:
                unfinished = line.index(-1) if -1 in line else len(line)
                if line[unfinished - 1] in strategy.end_tokens:
                    unfinished -= 1
                line = line[context_length:unfinished]
                output_target.append(line)
            if not return_all_beams:
                output_targets.append(output_target[0])
            else:
                output_targets.append(output_target)
        return output_targets
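
    @staticmethod
    def _example_trim_generated_line():
        """A toy walk-through of the post-processing at the end of `generate_text` above: cut the
        generated sequence at the first -1 padding slot, drop a trailing end token, and strip the
        context prefix. The ids below are made up; 9 plays the role of an end token."""
        line = [11, 12, 13, 21, 22, 9, -1, -1]  # 3 context tokens + generated tail
        context_length, end_tokens = 3, {9}
        unfinished = line.index(-1) if -1 in line else len(line)
        if line[unfinished - 1] in end_tokens:
            unfinished -= 1
        return line[context_length:unfinished]  # -> [21, 22]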

    def build_language_model_sample(
        self,
        tokens: List[int],
        is_first_segment: bool,
        max_seq_length: int,
        generation_length: int,
        unidirectional: bool,
        use_gmask: bool,
    ):
        tokenizer = get_tokenizer()
        sop_id = tokenizer.get_command("sop")
        mask_id = tokenizer.get_command("[gMASK]") if use_gmask else tokenizer.get_command("[MASK]")

        if is_first_segment or unidirectional:
            prompt, text = [], tokens
        else:
            prompt_length = max_seq_length - 1 - generation_length
            prompt, text = tokens[:prompt_length], tokens[prompt_length:]

        seq_length = len(prompt) + len(text) + 1
        attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
        attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1

        gen_length = min(len(text), generation_length)
        position_id = np.arange(0, seq_length, dtype=np.int64)
        if self.position_encoding_2d:
            position_id = np.concatenate(
                (np.arange(0, seq_length - gen_length, dtype=np.int64), [seq_length - gen_length - 1] * gen_length)
            )
            block_position_id = np.concatenate(
                ([0] * (seq_length - gen_length - 1), np.arange(0, gen_length + 1, dtype=np.int64))
            )
            position_id = np.stack((position_id, block_position_id), axis=0)

        return {
            "tokens": np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64),
            "targets": np.array(prompt + [mask_id] + text, dtype=np.int64),
            "position_ids": position_id,
            "attention_mask": attention_mask < 0.5,
            "loss_masks": np.array(
                [0] * (seq_length - gen_length) + [1] * gen_length,
                dtype=np.int64,
            ),
        }
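
    @staticmethod
    def _example_lm_segment_layout():
        """A toy sketch of the layout produced by `build_language_model_sample` above: the model
        input is prompt + [MASK] + <sop> + text[:-1], the targets are prompt + [MASK] + text, and
        only the trailing `gen_length` positions contribute to the loss. Ids below are made up;
        -100 and -200 stand in for the [MASK] and <sop> command ids."""
        prompt, text = [1, 2], [3, 4, 5]
        mask_id, sop_id = -100, -200
        gen_length = len(text)                    # here generation_length >= len(text)
        seq_length = len(prompt) + len(text) + 1  # one extra slot for the <sop>/[MASK] shift
        tokens = np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64)
        targets = np.array(prompt + [mask_id] + text, dtype=np.int64)
        loss_masks = np.array([0] * (seq_length - gen_length) + [1] * gen_length, dtype=np.int64)
        return tokens, targets, loss_masks        # each has length seq_length == 6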

    def calculate_loss(self, batch) -> List[float]:
        tokens, position_ids, attention_mask = self.process_data(batch, self.device)
        targets, loss_masks = (
            batch["targets"].to(device=self.device).long(),
            batch["loss_masks"].to(device=self.device).long(),
        )

        original_parallel_output = self.model.transformer.parallel_output
        self.model.transformer.parallel_output = True
        self.model.eval()
        with torch.no_grad():
            logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None)
            losses = vocab_parallel_cross_entropy(logits.contiguous().float(), targets)
            loss = torch.sum(losses * loss_masks, dim=-1)

        self.model.transformer.parallel_output = original_parallel_output
        return loss.tolist()
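

# A toy sketch (assumed numbers) of the reduction at the end of `calculate_loss`: per-token
# cross-entropy losses are zeroed outside the scored span by `loss_masks` and summed per sample,
# so each returned float is the total negative log-likelihood of that sample's scored tokens.
def _example_masked_loss_sum():
    losses = torch.tensor([[0.5, 1.0, 2.0, 3.0]])            # [batch=1, seq_len=4] per-token losses
    loss_masks = torch.tensor([[0, 0, 1, 1]])                # only the last two tokens are scored
    return torch.sum(losses * loss_masks, dim=-1).tolist()   # -> [5.0]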