generate.py

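"""Blank-filling text generation script.

Reads prompts containing [MASK] / [gMASK] placeholders from ``args.input_source``, fills each
blank with the model returned by ``initialize_model_and_tokenizer`` using either top-k/temperature
sampling (``BaseStrategy``) or beam search (``BeamSearchStrategy``), and writes the completed
texts under ``args.output_path``. (Docstring added for readability; it only summarizes the code
below.)
"""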
import os
import torch
import stat
import re

from functools import partial
from typing import List, Tuple

from SwissArmyTransformer import mpu
from evaluation.model import batch_filling_sequence
from generation import BeamSearchStrategy, BaseStrategy
from SwissArmyTransformer.generation.utils import timed_name, generate_continually
from initialize import initialize, initialize_model_and_tokenizer

def add_generation_specific_args(parser):
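    """Register generation-specific command-line flags (sampling strategy, minimum
    generation length, beam printing) on the argument parser."""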
    parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy", help="Type of sampling strategy.")
    parser.add_argument("--min-gen-length", type=int, default=0, help="The minimum length each blank should generate.")
    parser.add_argument(
        "--print-all-beams", action="store_true", help="Print all output generated by beam search strategy."
    )

def isEnglish(s):
    try:
        s.encode(encoding="utf-8").decode("ascii")
    except UnicodeDecodeError:
        return False
    else:
        return True

def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
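    """Build padded tokens, the attention mask, and position ids for blank filling.

    ``seq`` is right-padded with ``max_gen_length`` placeholder tokens (-1). The attention mask
    lets every token see the first ``context_length - 1`` context positions bidirectionally and
    applies a causal (lower-triangular) pattern elsewhere. When ``gmask`` is False (a [MASK]
    blank), position ids after the context are clamped to ``mask_position``; for [gMASK] they
    keep increasing.
    """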
    context_length = seq.shape[1]
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
    attention_mask.tril_()
    attention_mask[..., : context_length - 1] = 1
    attention_mask.unsqueeze_(1)
    attention_mask = (attention_mask < 0.5).bool()

    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
    if not gmask:
        position_ids[context_length - 1 :] = mask_position
    position_ids = position_ids.unsqueeze(0)

    return tokens, attention_mask, position_ids

def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
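    """Fill every [MASK]/[gMASK] blank in ``raw_text``.

    The text is tokenized with the blanks kept as special tokens; the first remaining blank is
    then repeatedly located and filled by ``batch_filling_sequence`` until none is left. Returns
    ``(answers, answers_with_style, blanks)``: the completed texts, the same texts with generated
    spans highlighted via ANSI escapes, and the raw generated spans per output. Relies on the
    module-level ``args`` parsed in ``__main__``.
    """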
    # add MASK
    generation_mask = "[MASK]" if "[MASK]" in raw_text else "[gMASK]"
    use_gmask = "[MASK]" not in raw_text

    mask_pattern = r"\[g?MASK\]"
    text_list = re.split(mask_pattern, raw_text)
    pattern_list = re.compile(mask_pattern).findall(raw_text)
    seq = []
    for i in range(len(pattern_list)):
        pattern = pattern_list[i]
        sub_text = text_list[i]
        seq.extend(tokenizer.tokenize(sub_text))
        seq.append(tokenizer.get_command(pattern))
    seq.extend(tokenizer.tokenize(text_list[-1]))

    if "MASK]" not in raw_text:
        seq += [tokenizer.get_command(generation_mask)]
        raw_text += " " + generation_mask
    if not raw_text.endswith("MASK]"):
        seq = seq + [tokenizer.get_command("eos")]
    if mpu.get_model_parallel_rank() == 0:
        print("\nInput: {}\n".format(raw_text))
    if len(seq) > args.max_sequence_length:
        raise ValueError("text too long.")

    # generation
    is_english = isEnglish(raw_text)
    output_list = [seq]
    num_output = args.num_beams if args.sampling_strategy == "BeamSearchStrategy" else 1
    last_pos, answers, answers_with_style, blanks = (
        [0] * num_output,
        ["" for _ in range(num_output)],
        ["" for _ in range(num_output)],
        [[] for _ in range(num_output)],
    )

    # repeatedly detect and fill the first remaining mask position
    while True:
        seq = output_list[0]
        # detect mask position
        mask_token = tokenizer.get_command(generation_mask)
        if mask_token not in seq:
            break
        mask_position = seq.index(mask_token)

        output_list = []
        input_seq = torch.cuda.LongTensor(
            [seq + [tokenizer.get_command("sop")]],
            device=args.device,
        )
        output, _ = batch_filling_sequence(
            model,
            input_seq,
            torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
            strategy=strategy,
            get_masks_and_position_ids=partial(
                get_masks_and_position_ids,
                mask_position=mask_position,
                max_gen_length=args.out_seq_length - input_seq.shape[-1],
                gmask=use_gmask,
            ),
        )
        if isinstance(output, torch.Tensor):  # different strategies return different types
            output = list(output)
        else:
            output = output[0]
        output_list.extend(output)

        # strip the -1 padding and splice the generated tokens back into seq
        for i in range(len(output_list)):
            output = output_list[i].tolist()
            try:
                unfinished = output.index(-1)
            except ValueError:
                unfinished = len(output)
            if output[unfinished - 1] in strategy.end_tokens:
                unfinished -= 1
            bog = output.index(tokenizer.get_command("sop"))

            prefix = tokenizer.detokenize(output[last_pos[i] : mask_position])
            blank = tokenizer.detokenize(output[bog + 1 : unfinished])
            answers_with_style[i] += (
                prefix
                + (" " if is_english else "")
                + ("\033[4m" if use_gmask else "\x1b[0;32m\033[4m")
                + blank
                + ("\033[0m" if use_gmask else "\033[0m\x1b[0m")
                + (" " if is_english else "")
            )
            blanks[i].append(blank)
            last_pos[i] = mask_position + unfinished - (bog + 1)
            output_list[i] = output[:mask_position] + output[bog + 1 : unfinished] + output[mask_position + 1 : bog]

    for i, output in enumerate(output_list):
        if output[-1] == tokenizer.get_command("eos"):
            output = output[:-1]
        answers_with_style[i] += tokenizer.detokenize(output[last_pos[i] :])
        answers[i] = tokenizer.detokenize(output)

    return answers, answers_with_style, blanks

def main(args):
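    """Set up the model, tokenizer, and sampling strategy, then fill blanks for every prompt
    read from ``args.input_source`` and write the answers under ``args.output_path``."""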
    model, tokenizer = initialize_model_and_tokenizer(args)

    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]

    if args.sampling_strategy == "BaseStrategy":
        strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
    elif args.sampling_strategy == "BeamSearchStrategy":
        strategy = BeamSearchStrategy(
            1,
            args.num_beams,
            length_penalty=args.length_penalty,
            consider_end=True,
            end_tokens=end_tokens,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            min_gen_length=args.min_gen_length,
        )
    else:
        raise ValueError(f"unknown strategy {args.sampling_strategy}")

    def process(raw_text):
        if args.with_id:
            query_id, raw_text = raw_text.split("\t")
        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy)

        # save
        if args.with_id:
            full_path = os.path.join(args.output_path, query_id + ".txt")
        else:
            prefix = raw_text.replace("/", "")[:20]
            full_path = timed_name(prefix, ".txt", args.output_path)
        if mpu.get_model_parallel_rank() == 0:
            if args.print_all_beams and len(answers) > 1:
                for idx, answer_with_style in enumerate(answers_with_style):
                    print(f"Output beam {idx}:", answer_with_style)  # print every beam
                    if len(answer_with_style) > 120:
                        print("")
            else:
                print("Output:", answers_with_style[0])  # print only the first output
            with open(full_path, "w", encoding="utf-8") as fout:
                for answer in answers:
                    fout.write(answer + "\n")

            os.chmod(full_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)

    os.makedirs(args.output_path, exist_ok=True)
    generate_continually(process, args.input_source)

if __name__ == "__main__":
    args = initialize(extra_args_provider=add_generation_specific_args)
    with torch.no_grad():
        main(args)
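
# Example invocation (illustrative sketch only): apart from --sampling-strategy,
# --min-gen-length, and --print-all-beams, the flag names below are assumptions inferred from
# the args.* attributes used above, and the actual launcher depends on how initialize() sets up
# distributed/model-parallel execution in the surrounding repository.
#
#   python generate.py \
#       --input-source interactive \
#       --output-path ./samples \
#       --out-seq-length 256 \
#       --sampling-strategy BeamSearchStrategy \
#       --num-beams 4 \
#       --print-all-beams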