generate.py

import os
import torch
import stat
import re

from functools import partial
from typing import List, Tuple

from SwissArmyTransformer import mpu
from evaluation.model import batch_filling_sequence
from generation import BeamSearchStrategy, BaseStrategy
from SwissArmyTransformer.generation.utils import timed_name, generate_continually
from initialize import initialize, initialize_model_and_tokenizer


def add_generation_specific_args(parser):
    parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy", help="Type of sampling strategy.")
    parser.add_argument("--min-gen-length", type=int, default=0, help="The minimum length each blank should generate.")
    parser.add_argument(
        "--print-all-beams", action="store_true", help="Print all output generated by beam search strategy."
    )


def isEnglish(s):
    # Returns True only for pure-ASCII text; fill_blanks() uses this to decide
    # whether to insert spaces around each generated blank in the styled output.
    try:
        s.encode(encoding="utf-8").decode("ascii")
    except UnicodeDecodeError:
        return False
    else:
        return True


def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
    context_length = seq.shape[1]
    # Pad the context with -1 placeholders for the tokens still to be generated.
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)
    # Causal mask, except that every position may also attend to the first
    # context_length - 1 tokens (the bidirectional context).
    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
    attention_mask.tril_()
    attention_mask[..., : context_length - 1] = 1
    attention_mask.unsqueeze_(1)
    attention_mask = (attention_mask < 0.5).bool()

    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
    if not gmask:
        # For [MASK]/[sMASK] infilling, generated tokens reuse the mask token's position.
        position_ids[context_length - 1 :] = mask_position

    position_ids = position_ids.unsqueeze(0)

    return tokens, attention_mask, position_ids
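

# Shape sketch for get_masks_and_position_ids (illustrative numbers only; the real
# calls pass the tokenized context and the remaining generation budget):
#
#   ctx = torch.zeros((1, 10), dtype=torch.long, device="cuda")
#   toks, mask, pos = get_masks_and_position_ids(ctx, mask_position=4, max_gen_length=20)
#   # toks.shape == (1, 30)         context followed by twenty -1 placeholders
#   # mask.shape == (1, 1, 30, 30)  boolean attention mask
#   # pos.shape  == (1, 30)         pos[0, 9:] are all 4 when gmask=False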


def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
    # add MASK
    generation_mask = "[gMASK]"
    if "[MASK]" in raw_text:
        generation_mask = "[MASK]"
    elif "[sMASK]" in raw_text:
        generation_mask = "[sMASK]"
    use_gmask = "[MASK]" not in raw_text and "[sMASK]" not in raw_text

    mask_pattern = r"\[[sg]?MASK\]"
    text_list = re.split(mask_pattern, raw_text)
    pattern_list = re.compile(mask_pattern).findall(raw_text)
    seq = []
    for i in range(len(pattern_list)):
        pattern = pattern_list[i]
        sub_text = text_list[i]
        seq.extend(tokenizer.tokenize(sub_text))
        seq.append(tokenizer.get_command(pattern))

    seq.extend(tokenizer.tokenize(text_list[-1]))

    if "MASK]" not in raw_text:
        # No mask in the prompt: append [gMASK] so the model generates a continuation.
        seq += [tokenizer.get_command(generation_mask)]
        raw_text += " " + generation_mask
    if not raw_text.endswith("MASK]"):
        seq = seq + [tokenizer.get_command("eos")]
    if mpu.get_model_parallel_rank() == 0:
        print("\nInput: {}\n".format(raw_text))
    if len(seq) > args.max_sequence_length:  # `args` is the global set in __main__
        raise ValueError("text too long.")

    # generation
    is_english = isEnglish(raw_text)
    output_list = [seq]
    num_output = args.num_beams if args.sampling_strategy == "BeamSearchStrategy" else 1
    last_pos, answers, answers_with_style, blanks = (
        [0] * num_output,
        ["" for _ in range(num_output)],
        ["" for _ in range(num_output)],
        [[] for _ in range(num_output)],
    )

    # continually detect the first mask position
    while True:
        seq = output_list[0]
        # detect mask position
        mask_token = tokenizer.get_command(generation_mask)
        if mask_token not in seq:
            break
        mask_position = seq.index(mask_token)

        output_list = []

        input_seq = torch.cuda.LongTensor(
            [seq + [tokenizer.get_command("sop")]],
            device=args.device,
        )
        output, _ = batch_filling_sequence(
            model,
            input_seq,
            torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
            strategy=strategy,
            get_masks_and_position_ids=partial(
                get_masks_and_position_ids,
                mask_position=mask_position,
                max_gen_length=args.out_seq_length - input_seq.shape[-1],
                gmask=use_gmask,
            ),
        )
        if isinstance(output, torch.Tensor):  # different strategies
            output = output.tolist()
        output = output[0]  # batch_size = 1
        output_list.extend(output)

        # clip -1s and fill back generated things into seq
        for i in range(len(output_list)):
            output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
            try:
                unfinished = output.index(-1)
            except ValueError:
                unfinished = len(output)
            if output[unfinished - 1] in strategy.end_tokens:
                unfinished -= 1
            bog = output.index(tokenizer.get_command("sop"))

            prefix = tokenizer.detokenize(output[last_pos[i] : mask_position])
            blank = tokenizer.detokenize(output[bog + 1 : unfinished])
            answers_with_style[i] += (
                prefix
                + (" " if is_english else "")
                + ("\033[4m" if use_gmask else "\x1b[0;32m\033[4m")
                + blank
                + ("\033[0m" if use_gmask else "\033[0m\x1b[0m")
                + (" " if is_english else "")
            )
            blanks[i].append(blank)
            last_pos[i] = mask_position + unfinished - (bog + 1)
            # Splice the generated blank back in place of the mask token.
            output_list[i] = output[:mask_position] + output[bog + 1 : unfinished] + output[mask_position + 1 : bog]

    for i, output in enumerate(output_list):
        if output[-1] == tokenizer.get_command("eos"):
            output = output[:-1]
        answers_with_style[i] += tokenizer.detokenize(output[last_pos[i] :])
        answers[i] = tokenizer.detokenize(output)

    return answers, answers_with_style, blanks
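

# Minimal usage sketch (assumes `model`, `tokenizer`, and `strategy` are built the same
# way main() builds them; the prompt below is only an example):
#
#   answers, answers_with_style, blanks = fill_blanks(
#       "Ng is an adjunct professor at [MASK] (formerly associate professor).",
#       model, tokenizer, strategy,
#   )
#   # answers            -> one fully detokenized sequence per beam
#   # answers_with_style -> the same text with ANSI underline/color around each blank
#   # blanks             -> the generated text for each mask, per beam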


def main(args):
    model, tokenizer = initialize_model_and_tokenizer(args)

    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]

    if args.sampling_strategy == "BaseStrategy":
        strategy = BaseStrategy(
            batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens
        )
    elif args.sampling_strategy == "BeamSearchStrategy":
        strategy = BeamSearchStrategy(
            1,
            args.num_beams,
            length_penalty=args.length_penalty,
            consider_end=True,
            end_tokens=end_tokens,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            min_gen_length=args.min_gen_length,
        )
    else:
        raise ValueError(f"unknown strategy {args.sampling_strategy}")
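
    # Strategy behavior, briefly: BaseStrategy samples token by token with
    # temperature / top-k / top-p filtering, while BeamSearchStrategy keeps
    # `args.num_beams` hypotheses and applies the configured length penalty,
    # n-gram blocking, and minimum generation length. Both stop on `end_tokens`.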

    def process(raw_text):
        if args.with_id:
            query_id, raw_text = raw_text.split("\t")

        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy)

        # save
        if args.with_id:
            full_path = os.path.join(args.output_path, query_id + ".txt")
        else:
            prefix = raw_text.replace("/", "")[:20]
            full_path = timed_name(prefix, ".txt", args.output_path)
        if mpu.get_model_parallel_rank() == 0:
            if args.print_all_beams and len(answers) > 1:
                for idx, answer_with_style in enumerate(answers_with_style):
                    print(f"Output beam {idx}:", answer_with_style)  # print every beam
                    if len(answer_with_style) > 120:
                        print("")
            else:
                print("Output:", answers_with_style[0])  # print the first beam only
            with open(full_path, "w", encoding="utf-8") as fout:
                for answer in answers:
                    fout.write(answer + "\n")

            os.chmod(full_path, stat.S_IRWXO | stat.S_IRWXG | stat.S_IRWXU)

    os.makedirs(args.output_path, exist_ok=True)
    generate_continually(process, args.input_source)
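

# Note on inputs: each raw_text handed to process() is a single prompt; with
# --with-id it is expected to look like "<query_id>\t<prompt>" so the answer can be
# written to <query_id>.txt, otherwise the output filename is derived from the first
# twenty characters of the prompt via timed_name(). Prompts may contain [MASK],
# [sMASK], or [gMASK]; when none is present, fill_blanks() appends a [gMASK] and
# generates a free-form continuation.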


if __name__ == "__main__":
    args = initialize(extra_args_provider=add_generation_specific_args)

    with torch.no_grad():
        main(args)