@@ -12,7 +12,6 @@ from itertools import accumulate
 from bisect import bisect_right
 
 from SwissArmyTransformer import get_tokenizer
-from SwissArmyTransformer.mpu import get_model_parallel_rank
 
 from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig, LanguageModelTaskConfig
 from .utils import get_tokenized_input
@@ -58,7 +57,6 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
     def has_collate_fn(self) -> bool:
         return False
 
-    @staticmethod
     def collate_fn(self, samples):
         return None
 
@@ -66,10 +64,10 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
         with open(os.path.join(path), "r", encoding="utf-8") as file:
             for line in file:
                 item = json.loads(line)
-                self.data.append(self.process_single_item(item))
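+                # process_single_item may now return several samples for one raw item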
+                self.data.extend(self.process_single_item(item))
 
     @abstractmethod
-    def process_single_item(self, item) -> dict:
+    def process_single_item(self, item, **kwargs) -> List[dict]:
         pass
 
     def __len__(self):
@@ -79,19 +77,18 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
 class GenerationTaskDataset(EvaluationDataset):
     config: GenerationTaskConfig
 
-    def process_single_item(self, item):
+    def process_single_item(self, item, **kwargs):
         text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
         if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
             text_length = self.config.max_seq_length - self.config.max_gen_length - 2
             text = text[len(text) - text_length : len(text)]
-        return {"text": text, "targets": targets}
+        return [{"text": text, "targets": targets, **kwargs}]
 
     @property
     def has_collate_fn(self) -> bool:
         return True
 
-    @staticmethod
-    def collate_fn(samples):
+    def collate_fn(self, samples):
         TILE = 32
         length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE
 
@@ -105,8 +102,8 @@ class GenerationTaskDataset(EvaluationDataset):
             token_batch.append(token)
             position_id_batch.append(position_id)
             attention_mask_batch.append(attention_mask)
-            context_length_batch.append(sample["context_length"])
-            target_position_id_batch.append(sample["target_position_id"])
+            context_length_batch.append(sample['context_length'])
+            target_position_id_batch.append(sample['target_position_id'])
         return {
             "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
             "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
@@ -141,7 +138,7 @@ class GenerationTaskDataset(EvaluationDataset):
         position_id = np.arange(0, context_length, dtype=np.int64)
         target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
         if not use_task_mask:
-            position_id[context_length - 1 :] = mask_position
+            position_id[context_length - 1:] = mask_position
             target_position_id[:] = mask_position
 
         attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
@@ -159,12 +156,13 @@ class GenerationTaskDataset(EvaluationDataset):
 
     def __getitem__(self, idx):
         item = self.data[idx]
-        return self.build_generation_sample(
+        sample = self.build_generation_sample(
             item["text"],
             max_gen_length=self.config.max_gen_length,
             use_task_mask=self.config.use_task_mask,
             unidirectional=self.config.unidirectional,
         )
+        return sample
 
 
 class MultiChoiceTaskDataset(EvaluationDataset):
@@ -178,15 +176,13 @@ class MultiChoiceTaskDataset(EvaluationDataset):
     def has_collate_fn(self) -> bool:
         return True
 
-    @staticmethod
-    def collate_fn(samples):
+    def collate_fn(self, samples):
         TILE = 32
         length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE
 
         token_batch, position_id_batch, attention_mask_batch = [], [], []
         choices_batch, choice_target_ids_batch = [], []
 
-        is_single_token = True
         for sample in samples:
             token, position_id, attention_mask = pad_batch(
                 sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
@@ -196,8 +192,6 @@ class MultiChoiceTaskDataset(EvaluationDataset):
             attention_mask_batch.append(attention_mask)
             choices_batch.append(sample["choices"])
             choice_target_ids_batch.append(sample["choice_target_ids"])
-            if isinstance(sample["choice_target_ids"], list):
-                is_single_token = False
 
         return {
             "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
@@ -205,10 +199,10 @@ class MultiChoiceTaskDataset(EvaluationDataset):
             "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
             "choices": choices_batch,
             "choice_target_ids": choice_target_ids_batch,
-            "is_single_token": is_single_token,
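+            # is_single_token is now tracked on the dataset instance (set in process_single_item)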
+            "is_single_token": self.is_single_token,
         }
 
-    def process_single_item(self, item):
+    def process_single_item(self, item, **kwargs):
         text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]
 
         tgt_seq_length = sum([len(choice) for choice in choices])
@@ -228,16 +222,16 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         if tgt_seq_length != 1:
             self.is_single_token = False
 
-        return {
+        return [{
             "text": text,
             "choices": choices,
             "label": label,
-        }
+            **kwargs
+        }]
 
     @staticmethod
-    def build_multiple_choice_sample(
-        text, choices, is_single_token, unified_multitask_encoding=False, use_task_mask=False
-    ):
+    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False,
+                                     unidirectional=False, use_task_mask=False):
         tokenizer = get_tokenizer()
 
         sop_id = tokenizer.get_command("sop")
@@ -250,48 +244,51 @@ class MultiChoiceTaskDataset(EvaluationDataset):
 
         blank_filling = mask_id in text
         if not blank_filling:
-            mask_position = len(token)
-            token = np.concatenate((token, [mask_id]))
-            target = np.concatenate((target, [mask_id]))
-            position_id = np.concatenate((position_id, [mask_position]))
+            if unidirectional:
+                assert use_task_mask
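+                # unidirectional scoring: prepend the mask and sop tokens, drop the last context
+                # token (it is re-added before each choice below), and use purely causal positions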
+                token = np.concatenate(([mask_id, sop_id], token[:-1]))
+                target = np.concatenate(([mask_id, sop_id], target[:-1]))
+                position_id = np.arange(len(token), dtype=np.int64)
+                mask_position = len(token)
+            else:
+                mask_position = len(token)
+                token = np.concatenate((token, [mask_id]))
+                target = np.concatenate((target, [mask_id]))
+                position_id = np.concatenate((position_id, [mask_position]))
         else:
+            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
+            assert not use_task_mask, "Blank filling doesn't support task mask"
             mask_position = text.index(mask_id)
 
         division = len(token)
         attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
+        if unidirectional:
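+            # make the context block causal (lower-triangular) instead of fully bidirectional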
+            attention_mask[0] = np.tril(attention_mask[0])
 
         for choice in choices:
-            if len(choice) == 0:
-                if get_model_parallel_rank() == 0:
-                    print("Empty choice found")
-                choice = [0]
-            if use_task_mask == False:
-                position_id = np.concatenate(
-                    (
-                        position_id,
-                        [mask_position] * len(choice)
-                        if blank_filling or not unified_multitask_encoding
-                        else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
-                    )
-                )
-            else:
-                position_id = np.concatenate(
-                    (
-                        position_id,
-                        np.arange(division, division + len(choice), dtype=np.int64),
-                    )
+            if not choice:
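+                # empty choices fall back to the eop token instead of a bare 0 id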
+                choice = [tokenizer.get_command('eop')]
+            position_id = np.concatenate(
+                (
+                    position_id,
+                    [mask_position] * len(choice)
+                    if (blank_filling or not unified_multitask_encoding) and not use_task_mask
+                    else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
                 )
-
+            )
             choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
             attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
-            token = np.concatenate((token, [sop_id], choice[:-1]))
+            if unidirectional:
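+                # the choice span is fed the last context token (dropped from the prefix above) rather than sop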
+                token = np.concatenate((token, [text[-1]], choice[:-1]))
+            else:
+                token = np.concatenate((token, [sop_id], choice[:-1]))
             target = np.concatenate((target, choice))
 
             if is_single_token:
                 break
 
         attention_mask = block_diag(*attention_mask)
-        attention_mask[: len(token), :division] = 1
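+        # every token after the context (all choice blocks) attends to the full context block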
+        attention_mask[division:, :division] = 1
 
         if is_single_token:
             choices = np.array(choices, dtype=np.int64).squeeze().tolist()
@@ -307,13 +304,15 @@ class MultiChoiceTaskDataset(EvaluationDataset):
 
     def __getitem__(self, idx):
         item = self.data[idx]
-        return self.build_multiple_choice_sample(
+        sample = self.build_multiple_choice_sample(
            item["text"],
            item["choices"],
            is_single_token=self.is_single_token,
            unified_multitask_encoding=self.config.use_multitask_encoding,
+            unidirectional=self.config.unidirectional,
            use_task_mask=self.config.use_task_mask,
        )
+        return sample
 
 
 class LanguageModelTaskDataset(EvaluationDataset):