generate.py

import os
import re
import stat
from functools import partial
from typing import List, Tuple

import torch

from SwissArmyTransformer import mpu
from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
from SwissArmyTransformer.generation.utils import timed_name, generate_continually

from generation import BeamSearchStrategy
from initialize import initialize, initialize_model_and_tokenizer
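

# Blank-filling generation script: reads prompts from args.input_source (via
# generate_continually), fills every [MASK]/[gMASK] span with model output, prints
# the filled text on model-parallel rank 0, and writes the answers to args.output_path.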
def add_generation_specific_args(parser):
    parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy", help="Type of sampling strategy.")
    parser.add_argument("--min-gen-length", type=int, default=0, help="The minimum length each blank should generate.")
    parser.add_argument(
        "--print-all-beams", action="store_true", help="Print all output generated by beam search strategy."
    )
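

# Crude language check: a prompt counts as English iff it is pure ASCII. The result
# only controls cosmetic spacing around the filled blanks in the printed output.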
def isEnglish(s):
    try:
        s.encode(encoding="utf-8").decode("ascii")
    except UnicodeDecodeError:
        return False
    else:
        return True
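

# Build the inputs expected by filling_sequence for GLM-style blank infilling:
# the attention mask is bidirectional over the context (the first context_length - 1
# positions) and causal over the generated part, and, for [MASK] (gmask=False),
# the position ids of the generated tokens are pinned to the mask position.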
def get_masks_and_position_ids(seq, mask_position, context_length, gmask=False):
    tokens = seq.unsqueeze(0)

    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
    attention_mask.tril_()
    attention_mask[..., : context_length - 1] = 1
    attention_mask.unsqueeze_(1)
    attention_mask = (attention_mask < 0.5).bool()

    position_ids = torch.arange(len(seq), dtype=torch.long, device=tokens.device)
    if not gmask:
        position_ids[context_length - 1 :] = mask_position
    position_ids = position_ids.unsqueeze(0)

    return tokens, attention_mask, position_ids
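

# Tokenize the prompt, then repeatedly locate the first remaining mask token and let
# the model fill it, splicing each generated span back into the sequence. Returns the
# plain answers, the answers with ANSI highlighting for the console, and the raw blanks.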
def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
    # NOTE: relies on the module-level `args` created under __main__.
    # add MASK
    generation_mask = "[MASK]" if "[MASK]" in raw_text else "[gMASK]"
    use_gmask = "[MASK]" not in raw_text

    mask_pattern = r"\[g?MASK\]"
    text_list = re.split(mask_pattern, raw_text)
    pattern_list = re.compile(mask_pattern).findall(raw_text)
    seq = []
    for i in range(len(pattern_list)):
        pattern = pattern_list[i]
        sub_text = text_list[i]
        seq.extend(tokenizer.tokenize(sub_text))
        seq.append(tokenizer.get_command(pattern))
    seq.extend(tokenizer.tokenize(text_list[-1]))

    if "MASK]" not in raw_text:
        seq += [tokenizer.get_command(generation_mask)]
        raw_text += " " + generation_mask
    if not raw_text.endswith("MASK]"):
        seq = seq + [tokenizer.get_command("eos")]
    if mpu.get_model_parallel_rank() == 0:
        print("\nInput: {}\n".format(raw_text))
    if len(seq) > args.max_sequence_length:
        raise ValueError("text too long.")

    # generation
    is_english = isEnglish(raw_text)
    output_list = [seq]
    num_output = args.num_beams if args.sampling_strategy == "BeamSearchStrategy" else 1
    last_pos, answers, answers_with_style, blanks = (
        [0] * num_output,
        ["" for _ in range(num_output)],
        ["" for _ in range(num_output)],
        [[] for _ in range(num_output)],
    )

    # continually detect the first mask position
    while True:
        seq = output_list[0]

        # detect mask position
        mask_token = tokenizer.get_command(generation_mask)
        if mask_token not in seq:
            break
        mask_position = seq.index(mask_token)

        output_list = []

        input_seq = torch.cuda.LongTensor(
            seq + [tokenizer.get_command("sop")] + [-1] * (args.out_seq_length - len(seq) - 1),
            device=args.device,
        )
        output, _ = filling_sequence(
            model,
            input_seq,
            batch_size=num_output,
            strategy=strategy,
            log_attention_weights=None,
            get_masks_and_position_ids=partial(
                get_masks_and_position_ids,
                mask_position=mask_position,
                context_length=len(seq) + 1,
                gmask=use_gmask,
            ),
        )
        if isinstance(output, torch.Tensor):  # different strategies return different containers
            output = list(output)

        output_list.extend(output)

        # clip the -1 padding and splice the generated tokens back into seq
        for i in range(len(output_list)):
            output = output_list[i].tolist()
            try:
                unfinished = output.index(-1)
            except ValueError:
                unfinished = len(output)
            if output[unfinished - 1] in strategy.end_tokens:
                unfinished -= 1
            bog = output.index(tokenizer.get_command("sop"))

            prefix = tokenizer.detokenize(output[last_pos[i] : mask_position])
            blank = tokenizer.detokenize(output[bog + 1 : unfinished])
            answers_with_style[i] += (
                prefix
                + (" " if is_english else "")
                + ("\033[4m" if use_gmask else "\x1b[0;32m\033[4m")
                + blank
                + ("\033[0m" if use_gmask else "\033[0m\x1b[0m")
                + (" " if is_english else "")
            )
            blanks[i].append(blank)
            last_pos[i] = mask_position + unfinished - (bog + 1)
            output_list[i] = output[:mask_position] + output[bog + 1 : unfinished] + output[mask_position + 1 : bog]

    for i, output in enumerate(output_list):
        if output[-1] == tokenizer.get_command("eos"):
            output = output[:-1]
        answers_with_style[i] += tokenizer.detokenize(output[last_pos[i] :])
        answers[i] = tokenizer.detokenize(output)

    return answers, answers_with_style, blanks
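

# Build the sampling strategy selected on the command line, then stream prompts from
# args.input_source through fill_blanks, one line at a time.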
def main(args):
    model, tokenizer = initialize_model_and_tokenizer(args)

    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]

    if args.sampling_strategy == "BaseStrategy":
        strategy = BaseStrategy(
            temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens
        )
    elif args.sampling_strategy == "BeamSearchStrategy":
        strategy = BeamSearchStrategy(
            args.num_beams,
            length_penalty=args.length_penalty,
            consider_end=True,
            end_tokens=end_tokens,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            min_gen_length=args.min_gen_length,
        )
    else:
        raise ValueError(f"unknown strategy {args.sampling_strategy}")
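
    # Handle one input line: optionally split an "<id>\t<prompt>" pair, fill the blanks,
    # print the result on rank 0, and write the plain answers to a .txt file.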
    def process(raw_text):
        if args.with_id:
            query_id, raw_text = raw_text.split("\t")
        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy)

        # save
        if args.with_id:
            full_path = os.path.join(args.output_path, query_id + ".txt")
        else:
            prefix = raw_text.replace("/", "")[:20]
            full_path = timed_name(prefix, ".txt", args.output_path)

        if mpu.get_model_parallel_rank() == 0:
            if args.print_all_beams and len(answers) > 1:
                for idx, answer_with_style in enumerate(answers_with_style):
                    print(f"Output beam {idx}:", answer_with_style)  # print every beam
                    if len(answer_with_style) > 120:
                        print("")
            else:
                print("Output:", answers_with_style[0])  # print the first beam only
            with open(full_path, "w", encoding="utf-8") as fout:
                for answer in answers:
                    fout.write(answer + "\n")
            os.chmod(full_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)

    os.makedirs(args.output_path, exist_ok=True)
    generate_continually(process, args.input_source)
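

# Example invocation (a sketch; the exact launcher and the model/checkpoint flags are
# defined by initialize.py and the repo's launch scripts, not in this file):
#   python generate.py --input-source prompts.txt --output-path ./samples \
#       --sampling-strategy BeamSearchStrategy --print-all-beams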
if __name__ == "__main__":
    args = initialize(extra_args_provider=add_generation_specific_args)

    with torch.no_grad():
        main(args)