@@ -32,7 +32,7 @@ def isEnglish(s):
 def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
     context_length = seq.shape[1]
-    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
+    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)
     attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
     attention_mask.tril_()
     attention_mask[..., : context_length - 1] = 1
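The context lines in this hunk build GLM-style prefix attention: a lower-triangular causal mask whose first context_length - 1 columns are re-opened so every position attends to the whole prompt bidirectionally, while generated positions stay causal. A minimal standalone sketch of the same construction, with toy sizes and hypothetical token ids, not part of the patch itself:

import torch

seq = torch.tensor([[101, 102, 103]])  # toy prompt, context_length = 3
max_gen_length = 2
tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)

attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
attention_mask.tril_()                       # causal mask over generated positions
attention_mask[..., : seq.shape[1] - 1] = 1  # prompt columns (up to the mask slot) fully visible
print(attention_mask[0].int())
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])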
@@ -50,10 +50,14 @@ def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
 def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
     # add MASK
-    generation_mask = "[MASK]" if "[MASK]" in raw_text else "[gMASK]"
-    use_gmask = "[MASK]" not in raw_text
-
-    mask_pattern = r"\[g?MASK\]"
+    generation_mask = "[gMASK]"
+    if "[MASK]" in raw_text:
+        generation_mask = "[MASK]"
+    elif "[sMASK]" in raw_text:
+        generation_mask = "[sMASK]"
+    use_gmask = "[MASK]" not in raw_text and "[sMASK]" not in raw_text
+
+    mask_pattern = r"\[[sg]?MASK\]"
     text_list = re.split(mask_pattern, raw_text)
     pattern_list = re.compile(mask_pattern).findall(raw_text)
     seq = []
@@ -158,8 +162,9 @@ def main(args):
     end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]

     if args.sampling_strategy == "BaseStrategy":
-        strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
-                                end_tokens=end_tokens)
+        strategy = BaseStrategy(
+            batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens
+        )
     elif args.sampling_strategy == "BeamSearchStrategy":
         strategy = BeamSearchStrategy(
             1,
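The second hunk widens mask handling from [MASK]/[gMASK] to also cover [sMASK], and the regex r"\[[sg]?MASK\]" now matches all three tokens. The selection logic can be exercised on its own; this is a standalone sketch (the helper name pick_generation_mask is hypothetical, not from the patch):

import re

def pick_generation_mask(raw_text: str):
    # Mirror of the patched fill_blanks() selection: prefer [MASK], then
    # [sMASK]; fall back to [gMASK] when neither appears in the prompt.
    generation_mask = "[gMASK]"
    if "[MASK]" in raw_text:
        generation_mask = "[MASK]"
    elif "[sMASK]" in raw_text:
        generation_mask = "[sMASK]"
    use_gmask = "[MASK]" not in raw_text and "[sMASK]" not in raw_text
    # Widened pattern splits on [MASK], [sMASK], and [gMASK].
    mask_pattern = r"\[[sg]?MASK\]"
    return generation_mask, use_gmask, re.split(mask_pattern, raw_text)

# Example: an [sMASK] prompt no longer falls through to [gMASK].
print(pick_generation_mask("Today is a [sMASK] day."))
# -> ('[sMASK]', False, ['Today is a ', ' day.'])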