@@ -6,7 +6,7 @@ import re
 from functools import partial
 from typing import List, Tuple
 
-from SwissArmyTransformer import mpu
+from SwissArmyTransformer import mpu, get_args
 from evaluation.model import batch_filling_sequence
 from generation import BeamSearchStrategy, BaseStrategy
 from SwissArmyTransformer.generation.utils import timed_name, generate_continually
@@ -59,7 +59,7 @@ def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False,
     return tokens, attention_mask, position_ids
 
 
-def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
+def fill_blanks(raw_text: str, model, tokenizer, strategy, args) -> Tuple[List[str], List[str], List[List[str]]]:
     # add MASK
     generation_mask = "[gMASK]"
     if "[MASK]" in raw_text:
@@ -168,7 +168,9 @@ def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], L
     return answers, answers_with_style, blanks
 
 
-def main(args):
+def main():
+    args = initialize(extra_args_provider=add_generation_specific_args)
+
     model, tokenizer = initialize_model_and_tokenizer(args)
 
     end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
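
[Note] initialize(extra_args_provider=add_generation_specific_args) follows the Megatron-style convention where the provider receives the argparse parser and registers script-specific flags. A hypothetical sketch of such a provider (the --with-id flag is assumed from the args.with_id usage in the next hunk, not taken from this patch):

    def add_generation_specific_args(parser):
        # Hypothetical sketch; the real provider is defined elsewhere in the repo.
        group = parser.add_argument_group("generation")
        group.add_argument("--with-id", action="store_true",
                           help="Each input line is formatted as '<query_id>\t<raw_text>'.")
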
@@ -194,7 +196,7 @@ def main(args):
         if args.with_id:
             query_id, raw_text = raw_text.split("\t")
 
-        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy)
+        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy, args)
 
         # save
         if args.with_id:
@@ -221,7 +223,5 @@ def main(args):
 
 
 if __name__ == "__main__":
-    args = initialize(extra_args_provider=add_generation_specific_args)
-
     with torch.no_grad():
-        main(args)
+        main()
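
Taken together, the patch makes main() own its argument parsing instead of receiving args from the __main__ guard, and threads args explicitly into fill_blanks. A minimal sketch of the resulting entrypoint shape, using only names visible in the diff above (bodies elided):

    import torch

    def main():
        # Argument parsing now happens inside main(), not under __main__.
        args = initialize(extra_args_provider=add_generation_specific_args)
        model, tokenizer = initialize_model_and_tokenizer(args)
        ...  # build the decoding strategy, then run generation over the input source

    if __name__ == "__main__":
        # Inference-only script: run everything with autograd disabled.
        with torch.no_grad():
            main()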