Przeglądaj źródła

Add [sMASK] support for generation

Sengxian 2 lat temu
rodzic
commit
ae68f7c0d2
1 zmieniony plik: 12 dodań i 7 usunięć
  1. 12 7
      generate.py

+ 12 - 7
generate.py

@@ -32,7 +32,7 @@ def isEnglish(s):
 
 def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
     context_length = seq.shape[1]
-    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
+    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)
     attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
     attention_mask.tril_()
     attention_mask[..., : context_length - 1] = 1
@@ -50,10 +50,14 @@ def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
 
 def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
     # add MASK
-    generation_mask = "[MASK]" if "[MASK]" in raw_text else "[gMASK]"
-    use_gmask = "[MASK]" not in raw_text
-
-    mask_pattern = r"\[g?MASK\]"
+    generation_mask = "[gMASK]"
+    if "[MASK]" in raw_text:
+        generation_mask = "[MASK]"
+    elif "[sMASK]" in raw_text:
+        generation_mask = "[sMASK]"
+    use_gmask = "[MASK]" not in raw_text and "[sMASK]" not in raw_text
+
+    mask_pattern = r"\[[sg]?MASK\]"
     text_list = re.split(mask_pattern, raw_text)
     pattern_list = re.compile(mask_pattern).findall(raw_text)
     seq = []
@@ -158,8 +162,9 @@ def main(args):
     end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
 
     if args.sampling_strategy == "BaseStrategy":
-        strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
-                                end_tokens=end_tokens)
+        strategy = BaseStrategy(
+            batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens
+        )
     elif args.sampling_strategy == "BeamSearchStrategy":
         strategy = BeamSearchStrategy(
             1,