@@ -1,9 +1,12 @@
+import numpy as np
 import torch
 
 from typing import List, Union
+from scipy.linalg import block_diag
 
 from SwissArmyTransformer.generation.autoregressive_sampling import update_mems, get_masks_and_position_ids_default
 from SwissArmyTransformer.mpu import vocab_parallel_cross_entropy
+from SwissArmyTransformer import get_tokenizer
 
 
 def batch_filling_sequence(
@@ -71,7 +74,9 @@ def batch_filling_sequence(
         if len(tokens.shape) == 3 and num_beams == 1:
             num_beams = tokens.shape[1]
             position_ids = (
-                position_ids.unsqueeze(1).expand(batch_size, num_beams, -1).reshape(batch_size * num_beams, -1)
+                position_ids.unsqueeze(1)
+                .expand((batch_size, num_beams) + position_ids.shape[1:])
+                .reshape((batch_size * num_beams,) + position_ids.shape[1:])
             )
             attention_mask_shape = attention_mask.shape[-3:]
             attention_mask = (
@@ -85,10 +90,11 @@ def batch_filling_sequence(
 
 
 class ModelForEvaluation(torch.nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, position_encoding_2d):
         super().__init__()
 
         self.model = model
+        self.position_encoding_2d = position_encoding_2d
         self.device = next(self.model.parameters()).device
 
     @staticmethod
@@ -99,6 +105,115 @@ class ModelForEvaluation(torch.nn.Module):
             batch["attention_mask"].to(device=device).bool().unsqueeze(1),
         )
 
+    def build_multiple_choice_sample(
+        self,
+        text,
+        choices,
+        is_single_token,
+        unified_multitask_encoding=False,
+        unidirectional=False,
+        use_task_mask=False,
+    ):
+        tokenizer = get_tokenizer()
+
+        sop_id = tokenizer.get_command("sop")
+        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
+
+        token = np.array(text, dtype=np.int64)
+        target = np.array(text, dtype=np.int64)
+        position_id = np.arange(len(text), dtype=np.int64)
+        block_position_id = np.zeros(len(text), dtype=np.int64)
+        choice_target_id = []
+
+        blank_filling = mask_id in text
+        if not blank_filling:
+            if unidirectional:
+                assert use_task_mask, "Unidirectional attention only support gMASK"
+                token = np.concatenate(([mask_id, sop_id], token[:-1]))
+                target = np.concatenate(([mask_id, sop_id], target[:-1]))
+                position_id = np.zeros(len(token), dtype=np.int64)
+                if self.position_encoding_2d:
+                    block_position_id = np.arange(len(token), dtype=np.int64)
+                mask_position = len(token)
+            else:
+                mask_position = len(token)
+                token = np.concatenate((token, [mask_id]))
+                target = np.concatenate((target, [mask_id]))
+                position_id = np.arange(len(token), dtype=np.int64)
+                if self.position_encoding_2d:
+                    block_position_id = np.zeros(len(token), dtype=np.int64)
+        else:
+            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
+            assert not use_task_mask, "Blank filling only support MASK"
+            mask_position = text.index(mask_id)
+
+        division = len(token)
+        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
+        if unidirectional:
+            attention_mask[0] = np.tril(attention_mask[0])
+
+        for choice in choices:
+            if not choice:
+                choice = [tokenizer.get_command("eop")]
+
+            target = np.concatenate((target, choice))
+            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
+            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
+
+            if unidirectional:
+                if self.position_encoding_2d:
+                    position_id = np.concatenate((position_id, [0] * len(choice)))
+                    block_position_id = np.concatenate(
+                        (block_position_id, np.arange(mask_position, mask_position + len(choice), dtype=np.int64))
+                    )
+                else:
+                    position_id = np.concatenate(
+                        (
+                            position_id,
+                            np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
+                        )
+                    )
+
+                token = np.concatenate((token, [text[-1]], choice[:-1]))
+            else:
+                if self.position_encoding_2d:
+                    position_id = np.concatenate((position_id, [mask_position] * len(choice)))
+                    block_position_id = np.concatenate(
+                        (block_position_id, np.arange(1, 1 + len(choice), dtype=np.int64))
+                    )
+                else:
+                    position_id = np.concatenate(
+                        (
+                            position_id,
+                            [mask_position] * len(choice)
+                            if (blank_filling or not unified_multitask_encoding) and not use_task_mask
+                            else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
+                        )
+                    )
+
+                token = np.concatenate((token, [sop_id], choice[:-1]))
+
+            if is_single_token:
+                break
+
+        attention_mask = block_diag(*attention_mask)
+        attention_mask[division:, :division] = 1
+
+        if is_single_token:
+            choices = np.array(choices, dtype=np.int64).squeeze().tolist()
+
+        if self.position_encoding_2d:
+            position_id = np.stack((position_id, block_position_id), axis=0)
+
+        item = {
+            "token": token,
+            "position_id": position_id,
+            "attention_mask": attention_mask,
+            "choices": choices,
+            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
+        }
+        return item
+
     def cond_log_prob(self, batch) -> List[List[float]]:
         """
         @return: Conditional log probability of each option
@@ -115,6 +230,12 @@ class ModelForEvaluation(torch.nn.Module):
         # output: [b, sq, vocab]
         log_probs = []
 
+        # if torch.distributed.get_rank() == 0:
+        #     import pdb
+        #
+        #     pdb.set_trace()
+        # torch.distributed.barrier()
+
         if is_single_token:  # Single token
             for logits, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
                 log_probs.append(logits[choice_target_ids[0], choices].tolist())
@@ -184,6 +305,52 @@ class ModelForEvaluation(torch.nn.Module):
             output_targets.append(output_target)
         return output_targets
 
+    def build_language_model_sample(
+        self,
+        tokens: List[int],
+        is_first_segment: bool,
+        max_seq_length: int,
+        generation_length: int,
+        unidirectional: bool,
+        use_gmask: bool,
+    ):
+        tokenizer = get_tokenizer()
+        sop_id = tokenizer.get_command("sop")
+        mask_id = tokenizer.get_command("[gMASK]") if use_gmask else tokenizer.get_command("[MASK]")
+
+        if is_first_segment or unidirectional:
+            prompt, text = [], tokens
+        else:
+            prompt_length = max_seq_length - 1 - generation_length
+            prompt, text = tokens[:prompt_length], tokens[prompt_length:]
+
+        seq_length = len(prompt) + len(text) + 1
+        attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
+        attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1
+
+        gen_length = min(len(text), generation_length)
+
+        position_id = np.arange(0, seq_length, dtype=np.int64)
+        if self.position_encoding_2d:
+            position_id = np.concatenate(
+                (np.arange(0, seq_length - gen_length, dtype=np.int64), [seq_length - gen_length - 1] * gen_length)
+            )
+            block_position_id = np.concatenate(
+                ([0] * (seq_length - gen_length - 1), np.arange(0, gen_length + 1, dtype=np.int64))
+            )
+            position_id = np.stack((position_id, block_position_id), axis=0)
+
+        return {
+            "tokens": np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64),
+            "targets": np.array(prompt + [mask_id] + text, dtype=np.int64),
+            "position_ids": position_id,
+            "attention_mask": attention_mask < 0.5,
+            "loss_masks": np.array(
+                [0] * (seq_length - gen_length) + [1] * gen_length,
+                dtype=np.int64,
+            ),
+        }
+
     def calculate_loss(self, batch) -> List[float]:
         tokens, position_ids, attention_mask = self.process_data(batch, self.device)
         targets, loss_masks = (
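
The following is a minimal standalone sketch, not part of the patch above, showing with made-up token ids the attention-mask and 2D position-id layout that the new build_multiple_choice_sample assembles in the bidirectional blank-filling case: the context block attends to itself fully, each choice continuation is causal within itself and, after the attention_mask[division:, :division] = 1 fill, sees the whole context; for choice tokens the first position row repeats the [MASK] position while the second row counts 1..n inside each choice.

# Editorial sketch with toy ids (11, 12, 13 for context, 5 standing in for
# [MASK], 21/22 and 31 for two answer choices); none of these ids come from
# the patch, and choice_target_id bookkeeping is omitted for brevity.
import numpy as np
from scipy.linalg import block_diag

context = [11, 12, 13, 5]      # hypothetical prompt tokens; 5 plays the role of [MASK]
choices = [[21, 22], [31]]     # hypothetical tokenized answer options
mask_position = context.index(5)

division = len(context)
masks = [np.ones((division, division), dtype=np.int64)]  # context attends bidirectionally to itself
position_id = np.arange(division, dtype=np.int64)
block_position_id = np.zeros(division, dtype=np.int64)

for choice in choices:
    masks.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))  # causal inside a choice
    position_id = np.concatenate((position_id, [mask_position] * len(choice)))  # row 0: frozen at [MASK]
    block_position_id = np.concatenate((block_position_id, np.arange(1, 1 + len(choice))))  # row 1: 1..n

attention_mask = block_diag(*masks)        # choices never attend to each other
attention_mask[division:, :division] = 1   # but every choice token sees the full context
position_ids = np.stack((position_id, block_position_id), axis=0)

print(attention_mask)
print(position_ids)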