generate.py

import os
import torch
import stat
import re
from functools import partial
from typing import List, Tuple
from SwissArmyTransformer import mpu
from evaluation.model import batch_filling_sequence
from generation import BeamSearchStrategy, BaseStrategy
from SwissArmyTransformer.generation.utils import timed_name, generate_continually
from initialize import initialize, initialize_model_and_tokenizer
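
# Blank-filling generation driver for a GLM-style model: prompts may contain [MASK]
# and/or [gMASK] tokens, each of which is filled autoregressively; if a prompt contains
# no mask at all, a [gMASK] is appended for open-ended continuation.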


def add_generation_specific_args(parser):
    parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy", help="Type of sampling strategy.")
    parser.add_argument("--min-gen-length", type=int, default=0, help="The minimum length each blank should generate.")
    parser.add_argument(
        "--print-all-beams", action="store_true", help="Print all output generated by beam search strategy."
    )
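    # Example: passing --sampling-strategy BeamSearchStrategy together with
    # --print-all-beams switches generation from top-k sampling to beam search
    # and prints every returned beam (see main()).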


def isEnglish(s):
    try:
        s.encode(encoding="utf-8").decode("ascii")
    except UnicodeDecodeError:
        return False
    else:
        return True
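
# Note: isEnglish() only checks whether the text is pure ASCII; fill_blanks() uses it
# to decide whether to pad the styled output with spaces around each filled blank.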


def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
    context_length = seq.shape[1]
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)
    # Bidirectional attention over the context, causal attention over the generated part.
    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
    attention_mask.tril_()
    attention_mask[..., : context_length - 1] = 1
    attention_mask.unsqueeze_(1)
    attention_mask = (attention_mask < 0.5).bool()  # True marks positions that may NOT be attended to

    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
    if not gmask:
        # For [MASK] infilling, every generated token shares the position of the mask.
        position_ids[context_length - 1 :] = mask_position
    position_ids = position_ids.unsqueeze(0)

    return tokens, attention_mask, position_ids
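
# A small sketch of what the helper returns (shapes only), assuming a (1, 5) context,
# mask_position=2 and max_gen_length=3:
#
#   toy_seq = torch.arange(5).unsqueeze(0)                      # (1, 5)
#   toks, mask, pos = get_masks_and_position_ids(
#       toy_seq, mask_position=2, max_gen_length=3, gmask=False
#   )
#   toks.shape == (1, 8)        # context padded with -1 up to context + max_gen_length
#   mask.shape == (1, 1, 8, 8)  # boolean; bidirectional over the first 4 columns, causal after
#   pos.shape  == (1, 8)        # positions 4..7 are all pinned to mask_position (2)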


def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
    # add MASK
    generation_mask = "[MASK]" if "[MASK]" in raw_text else "[gMASK]"
    use_gmask = "[MASK]" not in raw_text

    # split the prompt on [MASK] / [gMASK] and interleave text tokens with mask command tokens
    mask_pattern = r"\[g?MASK\]"
    text_list = re.split(mask_pattern, raw_text)
    pattern_list = re.compile(mask_pattern).findall(raw_text)
    seq = []
    for i in range(len(pattern_list)):
        pattern = pattern_list[i]
        sub_text = text_list[i]
        seq.extend(tokenizer.tokenize(sub_text))
        seq.append(tokenizer.get_command(pattern))
    seq.extend(tokenizer.tokenize(text_list[-1]))

    if "MASK]" not in raw_text:
        seq += [tokenizer.get_command(generation_mask)]
        raw_text += " " + generation_mask
    if not raw_text.endswith("MASK]"):
        seq = seq + [tokenizer.get_command("eos")]
    if mpu.get_model_parallel_rank() == 0:
        print("\nInput: {}\n".format(raw_text))
    if len(seq) > args.max_sequence_length:  # args is the module-level namespace set in __main__
        raise ValueError("text too long.")

    # generation
    is_english = isEnglish(raw_text)
    output_list = [seq]
    num_output = args.num_beams if args.sampling_strategy == "BeamSearchStrategy" else 1
    last_pos, answers, answers_with_style, blanks = (
        [0] * num_output,
        ["" for _ in range(num_output)],
        ["" for _ in range(num_output)],
        [[] for _ in range(num_output)],
    )

    # continually detect and fill the first mask position until none is left
    while True:
        seq = output_list[0]

        # detect mask position
        mask_token = tokenizer.get_command(generation_mask)
        if mask_token not in seq:
            break
        mask_position = seq.index(mask_token)

        output_list = []

        # append the start-of-piece (sop) token and generate a continuation for this mask
        input_seq = torch.cuda.LongTensor(
            [seq + [tokenizer.get_command("sop")]],
            device=args.device,
        )
        output, _ = batch_filling_sequence(
            model,
            input_seq,
            torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
            strategy=strategy,
            get_masks_and_position_ids=partial(
                get_masks_and_position_ids,
                mask_position=mask_position,
                max_gen_length=args.out_seq_length - input_seq.shape[-1],
                gmask=use_gmask,
            ),
        )
        if isinstance(output, torch.Tensor):  # different strategies
            output = list(output)
        else:
            output = output[0]
        output_list.extend(output)

        # clip -1s and fill back generated things into seq
        for i in range(len(output_list)):
            output = output_list[i].tolist()
            try:
                unfinished = output.index(-1)
            except ValueError:
                unfinished = len(output)
            if output[unfinished - 1] in strategy.end_tokens:
                unfinished -= 1
            bog = output.index(tokenizer.get_command("sop"))

            prefix = tokenizer.detokenize(output[last_pos[i] : mask_position])
            blank = tokenizer.detokenize(output[bog + 1 : unfinished])
            # underline the filled blank (and colour it green for [MASK] infilling)
            answers_with_style[i] += (
                prefix
                + (" " if is_english else "")
                + ("\033[4m" if use_gmask else "\x1b[0;32m\033[4m")
                + blank
                + ("\033[0m" if use_gmask else "\033[0m\x1b[0m")
                + (" " if is_english else "")
            )
            blanks[i].append(blank)
            last_pos[i] = mask_position + unfinished - (bog + 1)
            # splice the generated span back into the sequence in place of the mask
            output_list[i] = output[:mask_position] + output[bog + 1 : unfinished] + output[mask_position + 1 : bog]

    for i, output in enumerate(output_list):
        if output[-1] == tokenizer.get_command("eos"):
            output = output[:-1]
        answers_with_style[i] += tokenizer.detokenize(output[last_pos[i] :])
        answers[i] = tokenizer.detokenize(output)

    return answers, answers_with_style, blanks
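
# Sketch of how a prompt is segmented before tokenization (pure `re` behaviour):
#
#   raw_text = "Ng is an adjunct professor at [MASK] (formerly associate professor)."
#   re.split(r"\[g?MASK\]", raw_text)
#   # -> ["Ng is an adjunct professor at ", " (formerly associate professor)."]
#   re.compile(r"\[g?MASK\]").findall(raw_text)
#   # -> ["[MASK]"]
#
# Each text piece is tokenized and each matched mask is mapped to its command token;
# fill_blanks() then loops until no generation mask remains in the sequence.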


def main(args):
    model, tokenizer = initialize_model_and_tokenizer(args)
    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]

    if args.sampling_strategy == "BaseStrategy":
        strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
    elif args.sampling_strategy == "BeamSearchStrategy":
        strategy = BeamSearchStrategy(
            1,
            args.num_beams,
            length_penalty=args.length_penalty,
            consider_end=True,
            end_tokens=end_tokens,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            min_gen_length=args.min_gen_length,
        )
    else:
        raise ValueError(f"unknown strategy {args.sampling_strategy}")
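
    # BaseStrategy does temperature / top-k sampling with a single output, while
    # BeamSearchStrategy keeps args.num_beams hypotheses (batch size fixed to 1 here).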

    def process(raw_text):
        if args.with_id:
            query_id, raw_text = raw_text.split("\t")

        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy)

        # save
        if args.with_id:
            full_path = os.path.join(args.output_path, query_id + ".txt")
        else:
            prefix = raw_text.replace("/", "")[:20]
            full_path = timed_name(prefix, ".txt", args.output_path)

        if mpu.get_model_parallel_rank() == 0:
            if args.print_all_beams and len(answers) > 1:
                for idx, answer_with_style in enumerate(answers_with_style):
                    print(f"Output beam {idx}:", answer_with_style)  # print every beam
                    if len(answer_with_style) > 120:
                        print("")  # blank line between long beams for readability
            else:
                print("Output:", answers_with_style[0])  # print the first.
            with open(full_path, "w", encoding="utf-8") as fout:
                for answer in answers:
                    fout.write(answer + "\n")
            os.chmod(full_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)

    os.makedirs(args.output_path, exist_ok=True)
    generate_continually(process, args.input_source)
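

# Note: besides the flags registered in add_generation_specific_args, this script reads
# a number of options that are expected to come from initialize() and its underlying
# argument parser, e.g. args.temperature, args.top_k, args.num_beams, args.length_penalty,
# args.no_repeat_ngram_size, args.out_seq_length, args.max_sequence_length, args.with_id,
# args.input_source, args.output_path and args.device.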


if __name__ == "__main__":
    args = initialize(extra_args_provider=add_generation_specific_args)

    with torch.no_grad():
        main(args)