@@ -18,14 +18,19 @@ from .utils import get_tokenized_input
 
 
 def pad_batch(tokens, position_ids, attention_mask, max_seq_length):
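+    # Right-pad tokens, position_ids and the attention mask out to max_seq_length
+    # so variable-length samples can be stacked into one batch.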
+    pad_length = max_seq_length - len(tokens)
     attention_mask = np.pad(
         attention_mask,
-        pad_width=((0, max_seq_length - len(tokens)),),
+        pad_width=((0, pad_length),),
         mode="constant",
         constant_values=0,
     )
-    tokens = np.concatenate((tokens, np.zeros(max_seq_length - len(tokens), dtype=np.int64)))
-    position_ids = np.concatenate((position_ids, np.zeros(max_seq_length - len(position_ids), dtype=np.int64)))
+    tokens = np.concatenate((tokens, np.zeros(pad_length, dtype=np.int64)))
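+    # position_ids is 2D ([2, seq_len]), so pad along the last axis; repeating the
+    # final position rather than zero-filling means padded steps don't reuse position 0.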
+    position_ids = np.concatenate((position_ids, position_ids[..., -1:].repeat(pad_length, -1)), axis=-1)
     return tokens, position_ids, attention_mask
 
 
@@ -166,6 +167,67 @@ class GenerationTaskDataset(EvaluationDataset):
         )
 
 
+class SmallGenerationTaskDataset(GenerationTaskDataset):
+    config: GenerationTaskConfig
+
+    @staticmethod
+    def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
+        tokenizer = get_tokenizer()
+
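+        # Special-token ids from the tokenizer: sop opens the span to be generated,
+        # ENC is the CLS token; [gMASK] (task mask) marks free-form generation, [MASK] a blank.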
+        sop_id = tokenizer.get_command("sop")
+        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
+        cls_id = tokenizer.get_command("ENC")
+        eos_id = tokenizer.get_command("eos")
+
+        token = np.array(text, dtype=np.int64)
+
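+        # Blank filling: the input text already contains a mask token to fill in place.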
+        blank_filling = mask_id in text
+        if blank_filling:
+            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
+            assert not use_task_mask, "Blank filling doesn't support task mask"
+            mask_position = text.index(mask_id) + 1
+            context_length = len(token) + 2
+            token = np.concatenate(([cls_id], token, [eos_id, sop_id]))
+        else:
+            if unidirectional:
+                mask_position = 1
+                context_length = 3
+                token = np.concatenate(([cls_id, mask_id, eos_id, sop_id], token))
+            else:
+                mask_position = len(token) + 1
+                context_length = len(token) + 3
+                token = np.concatenate(([cls_id], token, [mask_id, eos_id, sop_id]))
+        prefix_length = len(token) - context_length
+
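+        # 2D position ids: row 0 holds absolute positions, frozen at mask_position for
+        # the span being generated; row 1 counts positions inside that span.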
+        position_id = [list(range(context_length)) + [mask_position] * prefix_length,
+                       [0] * context_length + list(range(1, prefix_length + 1))]
+        position_id = np.array(position_id, dtype=np.int64)
+
+        target_position_id = [[mask_position] * max_gen_length,
+                              list(range(prefix_length + 1, prefix_length + max_gen_length + 1))]
+        target_position_id = np.array(target_position_id, dtype=np.int64)
+
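+        # Causal (lower-triangular) mask; with bidirectional context, all tokens
+        # before the final sop attend to each other.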
+        attention_mask = np.tril(np.ones((len(token), len(token)), dtype=np.int64))
+        if not unidirectional:
+            attention_mask[: len(token) - 1, : len(token) - 1] = 1
+
+        item = {
+            "token": token,
+            "position_id": position_id,
+            "target_position_id": target_position_id,
+            "attention_mask": attention_mask,
+            "context_length": context_length,
+        }
+        return item
+
+
 class MultiChoiceTaskDataset(EvaluationDataset):
     config: MultiChoiceTaskConfig
 