|
@@ -371,41 +371,62 @@ class SmallMultiChoiceTaskDataset(MultiChoiceTaskDataset):
|
|
return 3
|
|
return 3
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
- def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False):
|
|
|
|
|
|
+ def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False,
|
|
|
|
+ unidirectional=False):
|
|
tokenizer = get_tokenizer()
|
|
tokenizer = get_tokenizer()
|
|
cls_id = tokenizer.get_command("ENC")
|
|
cls_id = tokenizer.get_command("ENC")
|
|
eos_id = tokenizer.get_command("eos")
|
|
eos_id = tokenizer.get_command("eos")
|
|
sop_id = tokenizer.get_command("sop")
|
|
sop_id = tokenizer.get_command("sop")
|
|
- mask_id = tokenizer.get_command("[MASK]")
|
|
|
|
|
|
+ mask_id = tokenizer.get_command("[MASK]") if not unidirectional else tokenizer.get_command("[gMASK]")
|
|
blank_filling = mask_id in text
|
|
blank_filling = mask_id in text
|
|
- if not blank_filling:
|
|
|
|
- text = text + [mask_id]
|
|
|
|
- text = [cls_id] + text + [eos_id]
|
|
|
|
|
|
+ text_length = len(text)
|
|
|
|
+ last_token = text[-1]
|
|
|
|
+
|
|
|
|
+ if unidirectional:
|
|
|
|
+ assert not blank_filling
|
|
|
|
+ text = [cls_id, mask_id, eos_id, sop_id] + text[:-1]
|
|
|
|
+ position_id = np.array(list(range(3)) + [1] * text_length, dtype=np.int64)
|
|
|
|
+ block_position_id = np.array([0] * 3 + list(range(1, text_length + 1)), dtype=np.int64)
|
|
|
|
+ else:
|
|
|
|
+ if not blank_filling:
|
|
|
|
+ text = text + [mask_id]
|
|
|
|
+ text = [cls_id] + text + [eos_id]
|
|
|
|
+ position_id = np.arange(len(text), dtype=np.int64)
|
|
|
|
+ block_position_id = np.zeros(len(text), dtype=np.int64)
|
|
|
|
|
|
token = np.array(text, dtype=np.int64)
|
|
token = np.array(text, dtype=np.int64)
|
|
target = np.array(text, dtype=np.int64)
|
|
target = np.array(text, dtype=np.int64)
|
|
- position_id = np.arange(len(text), dtype=np.int64)
|
|
|
|
- block_position_id = np.zeros(len(text), dtype=np.int64)
|
|
|
|
mask_position = text.index(mask_id)
|
|
mask_position = text.index(mask_id)
|
|
choice_target_id = []
|
|
choice_target_id = []
|
|
|
|
|
|
-
|
|
|
|
division = len(token)
|
|
division = len(token)
|
|
- attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
|
|
|
|
|
|
+ if unidirectional:
|
|
|
|
+ attention_mask = [np.tril(np.ones((len(token), len(token)), dtype=np.int64))]
|
|
|
|
+ attention_mask[0][:3, :3] = 1
|
|
|
|
+ else:
|
|
|
|
+ attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
|
|
|
|
|
|
for choice in choices:
|
|
for choice in choices:
|
|
- position_id = np.concatenate((position_id, [mask_position] * len(choice)))
|
|
|
|
- block_position_id = np.concatenate((block_position_id, range(1, 1 + len(choice))))
|
|
|
|
choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
|
|
choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
|
|
attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
|
|
attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
|
|
- token = np.concatenate((token, [sop_id], choice[:-1]))
|
|
|
|
|
|
+ position_id = np.concatenate((position_id, [mask_position] * len(choice)))
|
|
|
|
+ if unidirectional:
|
|
|
|
+ block_position_id = np.concatenate(
|
|
|
|
+ (block_position_id, range(1 + text_length, 1 + text_length + len(choice))))
|
|
|
|
+ token = np.concatenate((token, [last_token], choice[:-1]))
|
|
|
|
+ else:
|
|
|
|
+ block_position_id = np.concatenate((block_position_id, range(1, 1 + len(choice))))
|
|
|
|
+ token = np.concatenate((token, [sop_id], choice[:-1]))
|
|
target = np.concatenate((target, choice))
|
|
target = np.concatenate((target, choice))
|
|
|
|
|
|
if is_single_token:
|
|
if is_single_token:
|
|
break
|
|
break
|
|
|
|
|
|
attention_mask = block_diag(*attention_mask)
|
|
attention_mask = block_diag(*attention_mask)
|
|
- attention_mask[: len(token), :division] = 1
|
|
|
|
|
|
+ if unidirectional:
|
|
|
|
+ attention_mask[division:, :division] = 1
|
|
|
|
+ else:
|
|
|
|
+ attention_mask[: len(token), :division] = 1
|
|
|
|
|
|
if is_single_token:
|
|
if is_single_token:
|
|
choices = np.array(choices, dtype=np.int64).squeeze().tolist()
|
|
choices = np.array(choices, dtype=np.int64).squeeze().tolist()
|
|
@@ -421,6 +442,16 @@ class SmallMultiChoiceTaskDataset(MultiChoiceTaskDataset):
|
|
}
|
|
}
|
|
return item
|
|
return item
|
|
|
|
|
|
|
|
+ def __getitem__(self, idx):
|
|
|
|
+ item = self.data[idx]
|
|
|
|
+ return self.build_multiple_choice_sample(
|
|
|
|
+ item["text"],
|
|
|
|
+ item["choices"],
|
|
|
|
+ is_single_token=self.is_single_token,
|
|
|
|
+ unified_multitask_encoding=self.config.use_multitask_encoding,
|
|
|
|
+ unidirectional=self.config.unidirectional
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
|
|
class LanguageModelTaskDataset(EvaluationDataset):
|
|
class LanguageModelTaskDataset(EvaluationDataset):
|
|
config: LanguageModelTaskConfig
|
|
config: LanguageModelTaskConfig
|