@@ -248,6 +248,68 @@ class ModelForEvaluation(torch.nn.Module):
                 log_probs.append(log_probs_single)
         return log_probs

+    def build_generation_sample(self, text, max_gen_length, use_task_mask, unidirectional):
+        tokenizer = get_tokenizer()
+
+        sop_id = tokenizer.get_command("sop")
+        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
+
+        token = np.array(text, dtype=np.int64)
+        position_id = np.arange(len(text), dtype=np.int64)
+        block_position_id = np.zeros(len(text), dtype=np.int64)
+        target_position_id = np.zeros(len(text), dtype=np.int64)
+        target_block_position_id = np.zeros(len(text), dtype=np.int64)
+
+        blank_filling = mask_id in text
+
+        if unidirectional:
+            assert use_task_mask, "Unidirectional attention only support gMASK"
+            assert not blank_filling, "Unidirectional attention doesn't support blank filling"
+            token = np.concatenate(([mask_id, sop_id], token))
+            if self.position_encoding_2d:
+                position_id = np.zeros(len(token), dtype=np.int64)
+                target_position_id = np.zeros(max_gen_length, dtype=np.int64)
+                block_position_id = np.arange(len(token), dtype=np.int64)
+                target_block_position_id = np.arange(len(token), len(token) + max_gen_length, dtype=np.int64)
+            else:
+                position_id = np.arange(len(token), dtype=np.int64)
+                target_position_id = np.arange(len(token), len(token) + max_gen_length, dtype=np.int64)
+        else:
+            if not blank_filling:
+                mask_position = len(token)
+                token = np.concatenate((token, [mask_id, sop_id]))
+            else:
+                assert not use_task_mask, "Blank filling only support MASK"
+                mask_position = text.index(mask_id)
+                token = np.concatenate((token, [sop_id]))
+
+            position_id = np.concatenate((np.arange(len(token) - 1, dtype=np.int64), [mask_position]))
+            target_position_id = np.full(max_gen_length, mask_position, dtype=np.int64)
+            if self.position_encoding_2d:
+                block_position_id = np.zeros(len(token), dtype=np.int64)
+                target_block_position_id = np.arange(1, max_gen_length + 1, dtype=np.int64)
+            elif use_task_mask:
+                position_id = np.arange(len(token), dtype=np.int64)
+                target_position_id = np.arange(len(token), len(token) + max_gen_length, dtype=np.int64)
+
+        context_length = len(token)
+        attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
+        if not unidirectional:
+            attention_mask[: context_length - 1, : context_length - 1] = 1
+
+        if self.position_encoding_2d:
+            position_id = np.stack((position_id, block_position_id), axis=0)
+            target_position_id = np.stack((target_position_id, target_block_position_id), axis=0)
+
+        item = {
+            "token": token,
+            "position_id": position_id,
+            "target_position_id": target_position_id,
+            "attention_mask": attention_mask,
+            "context_length": context_length,
+        }
+        return item
+
     def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[List[int]], List[List[List[int]]]]:
         """
         @return: A list of text model generated, sorted by score in descending order
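
For context, the sketch below reconstructs, outside the model class, what `build_generation_sample` returns for the common bidirectional `[gMASK]` case with 2-D position encoding. The token ids (`MASK_ID`, `SOP_ID`, and the context ids) are made-up placeholders rather than values from `get_tokenizer()`, so treat it as an illustration of the resulting shapes, not as code from this patch.

```python
import numpy as np

MASK_ID, SOP_ID = 150001, 150005   # hypothetical ids; the real ones come from get_tokenizer()
text = [5, 6, 7]                   # hypothetical context token ids
max_gen_length = 4

# Context + [gMASK] + sop, as in the "not blank_filling" branch above.
mask_position = len(text)
token = np.array(text + [MASK_ID, SOP_ID], dtype=np.int64)

# First position dimension: 0..n-1 over the prompt, with sop reusing the mask position;
# every generated token keeps pointing at the mask position.
position_id = np.concatenate((np.arange(len(token) - 1, dtype=np.int64), [mask_position]))
target_position_id = np.full(max_gen_length, mask_position, dtype=np.int64)

# Second (block) dimension: 0 for the prompt, 1..max_gen_length for generated tokens.
block_position_id = np.zeros(len(token), dtype=np.int64)
target_block_position_id = np.arange(1, max_gen_length + 1, dtype=np.int64)

# Causal mask first, then everything before sop is made mutually visible (bidirectional prompt).
context_length = len(token)
attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
attention_mask[: context_length - 1, : context_length - 1] = 1

sample = {
    "token": token,
    "position_id": np.stack((position_id, block_position_id), axis=0),                       # (2, 5)
    "target_position_id": np.stack((target_position_id, target_block_position_id), axis=0),  # (2, 4)
    "attention_mask": attention_mask,
    "context_length": context_length,
}
print(sample["position_id"])   # [[0 1 2 3 3]
                               #  [0 0 0 0 0]]
```

This dict is the `sample` that `generate_text` consumes; per its docstring it returns the generated token sequences sorted by score in descending order, and the `List[List[List[int]]]` variant of the return type presumably corresponds to `return_all_beams=True`.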