@@ -168,14 +168,19 @@ class GenerationTaskDataset(EvaluationDataset):


 class SmallGenerationTaskDataset(GenerationTaskDataset):
-    config: GenerationTaskConfig
+    def process_single_item(self, item):
+        text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
+        if len(text) + self.config.max_gen_length + 3 > self.config.max_seq_length:
+            text_length = self.config.max_seq_length - self.config.max_gen_length - 3
+            text = text[len(text) - text_length : len(text)]
+        return {"text": text, "targets": targets}

     @staticmethod
     def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
         tokenizer = get_tokenizer()

         sop_id = tokenizer.get_command("sop")
-        mask_id = tokenizer.get_command("[gMASK]").Id if use_task_mask else tokenizer.get_command("[MASK]").Id
+        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
         cls_id = tokenizer.get_command("ENC")
         eos_id = tokenizer.get_command("eos")

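The literal 3 in the override above reserves room for the special tokens that build_generation_sample places around the context; the slice then keeps only the rightmost tokens of an over-long input. A minimal sketch of the arithmetic, with max_seq_length / max_gen_length values that are illustrative assumptions rather than task defaults:

    # Sketch of the truncation above; the config values are made up.
    max_seq_length, max_gen_length = 512, 128
    text = list(range(600))  # stand-in for an over-long tokenized input

    # 3 positions stay reserved for special tokens around the context.
    if len(text) + max_gen_length + 3 > max_seq_length:
        text_length = max_seq_length - max_gen_length - 3
        text = text[len(text) - text_length : len(text)]

    assert len(text) == 381  # 512 - 128 - 3: only the rightmost tokens survive

A num_special_tokens() hook like the one the multiple-choice classes gain below could replace the bare 3 here as well.
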
@@ -232,6 +237,10 @@ class MultiChoiceTaskDataset(EvaluationDataset):
     def has_collate_fn(self) -> bool:
         return True

+    @staticmethod
+    def num_special_tokens():
+        return 2
+
     @staticmethod
     def collate_fn(samples):
         TILE = 32
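
The new num_special_tokens() hook lets a subclass declare how many positions the sample builder spends on special tokens, so the shared truncation logic in process_single_item (next hunks) stays correct for both variants. A toy sketch of the pattern; the class names and the token accounting in the comments are illustrative assumptions:

    class Base:
        @staticmethod
        def num_special_tokens():
            return 2  # e.g. a leading ENC token plus a trailing eos

    class Small(Base):
        @staticmethod
        def num_special_tokens():
            return 3  # e.g. ENC + eos plus an appended [MASK]

    def text_budget(max_seq_length, tgt_seq_length, cls):
        # Room left for the context once targets and specials are counted.
        return max_seq_length - tgt_seq_length - cls.num_special_tokens()

    print(text_budget(512, 10, Base))   # 500
    print(text_budget(512, 10, Small))  # 499
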
@@ -263,7 +272,9 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         }

     def process_single_item(self, item):
-        text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]
+        text = get_tokenized_input(item, "inputs", no_tokenized=self.config.no_tokenized)
+        choices = get_tokenized_input(item, "choices", no_tokenized=self.config.no_tokenized)
+        label = item["label"]

         tgt_seq_length = sum([len(choice) for choice in choices])
         if tgt_seq_length == len(choices):
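
The no_tokenized keyword is new in this call; the diff does not show get_tokenized_input itself. One plausible, purely hypothetical reading, assuming the flag marks datasets whose items carry raw strings rather than token ids:

    # Hypothetical sketch only; the real helper lives elsewhere in this
    # repo and its behavior may differ.
    def get_tokenized_input(item, key, no_tokenized=False):
        value = item[key]
        if no_tokenized:
            return get_tokenizer().tokenize(value)  # tokenize raw text on the fly
        return value  # already a list of token ids
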
@@ -271,9 +282,9 @@ class MultiChoiceTaskDataset(EvaluationDataset):
             tgt_seq_length = 1

         assert tgt_seq_length < self.config.max_seq_length
-        if len(text) + tgt_seq_length + 2 > self.config.max_seq_length:
-            text_length = self.config.max_seq_length - tgt_seq_length - 2
-            text = text[len(text) - text_length : len(text)]
+        if len(text) + tgt_seq_length + self.num_special_tokens() > self.config.max_seq_length:
+            text_length = self.config.max_seq_length - tgt_seq_length - self.num_special_tokens()
+            text = text[len(text) - text_length : len(text)]

         assert not (
             self.mask_id in text and self.config.use_multitask_encoding
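
A small nit on the rewritten slice: text[len(text) - text_length : len(text)] keeps the tail of the context, and is equivalent to the shorter negative-index form whenever text_length > 0 (at text_length == 0, text[-0:] would instead return the whole list):

    text = text[-text_length:]  # equivalent for text_length > 0
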
@@ -354,6 +365,63 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         )


+class SmallMultiChoiceTaskDataset(MultiChoiceTaskDataset):
+    @staticmethod
+    def num_special_tokens():
+        return 3
+
+    @staticmethod
+    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False):
+        tokenizer = get_tokenizer()
+        cls_id = tokenizer.get_command("ENC")
+        eos_id = tokenizer.get_command("eos")
+        sop_id = tokenizer.get_command("sop")
+        mask_id = tokenizer.get_command("[MASK]")
+        blank_filling = mask_id in text
+        if not blank_filling:
+            text = text + [mask_id]
+        text = [cls_id] + text + [eos_id]
+
+        token = np.array(text, dtype=np.int64)
+        target = np.array(text, dtype=np.int64)
+        position_id = np.arange(len(text), dtype=np.int64)
+        block_position_id = np.zeros(len(text), dtype=np.int64)
+        mask_position = text.index(mask_id)
+        choice_target_id = []
+
+
+        division = len(token)
+        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
+
+        for choice in choices:
+            position_id = np.concatenate((position_id, [mask_position] * len(choice)))
+            block_position_id = np.concatenate((block_position_id, range(1, 1 + len(choice))))
+            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
+            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
+            token = np.concatenate((token, [sop_id], choice[:-1]))
+            target = np.concatenate((target, choice))
+
+            if is_single_token:
+                break
+
+        attention_mask = block_diag(*attention_mask)
+        attention_mask[: len(token), :division] = 1
+
+        if is_single_token:
+            choices = np.array(choices, dtype=np.int64).squeeze().tolist()
+
+        position_id = np.stack((position_id, block_position_id), axis=0)
+
+        item = {
+            "token": token,
+            "position_id": position_id,
+            "attention_mask": attention_mask,
+            "choices": choices,
+            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
+        }
+        return item
+
+
 class LanguageModelTaskDataset(EvaluationDataset):
     config: LanguageModelTaskConfig

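For reviewers unfamiliar with GLM-style multiple-choice scoring: the loop in build_multiple_choice_sample stacks one causal block per choice on top of a bidirectional context block, then block_diag stitches them into a single mask. A self-contained toy illustration with a 4-token context and two 2-token choices, assuming block_diag here is scipy.linalg.block_diag:

    import numpy as np
    from scipy.linalg import block_diag  # assumption: the module's block_diag

    division = 4  # length of the bidirectional context
    blocks = [np.ones((division, division), dtype=np.int64)]
    for choice_len in (2, 2):
        # each choice attends causally within its own block
        blocks.append(np.tril(np.ones((choice_len, choice_len), dtype=np.int64)))

    mask = block_diag(*blocks)
    mask[:, :division] = 1  # every position may also attend to the shared context
    print(mask)

Choices never attend to each other, which is what allows all of them to be scored in a single forward pass. Incidentally, the target array built in the loop is never returned in item; if that is intentional, it could be dropped.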