generate.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. import os
  2. import torch
  3. import stat
  4. import re
  5. from functools import partial
  6. from typing import List, Tuple
  7. from SwissArmyTransformer import mpu, get_args
  8. from evaluation.model import batch_filling_sequence
  9. from generation import BeamSearchStrategy, BaseStrategy
  10. from SwissArmyTransformer.generation.utils import timed_name, generate_continually
  11. from initialize import initialize, initialize_model_and_tokenizer
  12. def add_generation_specific_args(parser):
  13. parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy", help="Type of sampling strategy.")
  14. parser.add_argument("--min-gen-length", type=int, default=0, help="The minimum length each blank should generate.")
  15. parser.add_argument(
  16. "--print-all-beams", action="store_true", help="Print all output generated by beam search strategy."
  17. )
  18. def isEnglish(s):
  19. try:
  20. s.encode(encoding="utf-8").decode("ascii")
  21. except UnicodeDecodeError:
  22. return False
  23. else:
  24. return True
def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False, position_encoding_2d=False):
    """Pad ``seq`` for generation and build its attention mask and position ids.

    Args:
        seq: token-id tensor; ``seq.shape[1]`` is taken as the context length
            and padding is applied to the last dim. Presumably shaped
            (1, context_length) -- TODO confirm against the caller.
        mask_position: index of the mask token being filled. Positions from the
            last context token onward reuse it as their position id (always in
            2-D mode; in 1-D mode only when ``gmask`` is False).
        max_gen_length: number of -1 placeholder slots appended after the context.
        gmask: 1-D mode only -- when True, position ids keep counting up through
            the generated region instead of being pinned to ``mask_position``.
        position_encoding_2d: when True, stack a second row of "block" position
            ids under the regular ones.

    Returns:
        ``(tokens, attention_mask, position_ids)`` where L = ``tokens.shape[-1]``:
        - tokens: ``seq`` right-padded with ``max_gen_length`` values of -1.
        - attention_mask: bool tensor of shape (1, 1, L, L), True wherever the
          allowed pattern built below is 0 (presumably "do not attend" for the
          model -- confirm).
        - position_ids: shape (1, L) in 1-D mode, (1, 2, L) in 2-D mode.
    """
    context_length = seq.shape[1]
    # Append max_gen_length slots of -1 as placeholders for generated tokens.
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)
    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
    # Start from a lower-triangular (causal) pattern...
    attention_mask.tril_()
    # ...then make the first context_length - 1 columns visible to every row, so
    # the prompt is seen bidirectionally. The final context token (presumably the
    # sop the caller appends) stays causal.
    attention_mask[..., : context_length - 1] = 1
    attention_mask.unsqueeze_(1)
    # Invert: True where the pattern above is 0. (`.bool()` is redundant; the
    # comparison already yields a bool tensor.)
    attention_mask = (attention_mask < 0.5).bool()
    if position_encoding_2d:
        position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
        # From the last context token onward, every position reports the mask's position.
        position_ids[context_length - 1 :] = mask_position
        # Block ids: 0 for the first context_length - 2 tokens, then 0, 1, 2, ...
        # starting at index context_length - 2.
        block_position_ids = torch.cat(
            (
                torch.zeros(context_length - 2, dtype=torch.long, device=tokens.device),
                torch.arange(tokens.shape[-1] - (context_length - 2), dtype=torch.long, device=tokens.device),
            )
        )
        position_ids = torch.vstack((position_ids, block_position_ids))
    else:
        position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
        # [MASK]/[sMASK]: generated tokens share the mask's position id;
        # [gMASK] keeps the plain increasing positions.
        if not gmask:
            position_ids[context_length - 1 :] = mask_position
    # Add the leading batch dimension.
    position_ids = position_ids.unsqueeze(0)
    return tokens, attention_mask, position_ids
  49. def fill_blanks(raw_text: str, model, tokenizer, strategy, args) -> Tuple[List[str], List[str], List[List[str]]]:
  50. # add MASK
  51. generation_mask = "[gMASK]"
  52. if "[MASK]" in raw_text:
  53. generation_mask = "[MASK]"
  54. elif "[sMASK]" in raw_text:
  55. generation_mask = "[sMASK]"
  56. use_gmask = "[MASK]" not in raw_text and "[sMASK]" not in raw_text
  57. mask_pattern = r"\[[sg]?MASK\]"
  58. text_list = re.split(mask_pattern, raw_text)
  59. pattern_list = re.compile(mask_pattern).findall(raw_text)
  60. seq = []
  61. for i in range(len(pattern_list)):
  62. pattern = pattern_list[i]
  63. sub_text = text_list[i]
  64. seq.extend(tokenizer.tokenize(sub_text))
  65. seq.append(tokenizer.get_command(pattern))
  66. seq.extend(tokenizer.tokenize(text_list[-1]))
  67. if "MASK]" not in raw_text:
  68. seq += [tokenizer.get_command(generation_mask)]
  69. raw_text += " " + generation_mask
  70. if not raw_text.endswith("MASK]"):
  71. seq = seq + [tokenizer.get_command("eos")]
  72. if mpu.get_model_parallel_rank() == 0:
  73. print("\nInput: {}\n".format(raw_text))
  74. if len(seq) > args.max_sequence_length:
  75. raise ValueError("text too long.")
  76. # generation
  77. is_english = isEnglish(raw_text)
  78. output_list = [seq]
  79. num_output = args.num_beams if args.sampling_strategy == "BeamSearchStrategy" else 1
  80. last_pos, answers, answers_with_style, blanks = (
  81. [0] * num_output,
  82. ["" for _ in range(num_output)],
  83. ["" for _ in range(num_output)],
  84. [[] for _ in range(num_output)],
  85. )
  86. # continually detect the first mark position
  87. while True:
  88. seq = output_list[0]
  89. # detect mask position
  90. mask_token = tokenizer.get_command(generation_mask)
  91. if mask_token not in seq:
  92. break
  93. mask_position = seq.index(mask_token)
  94. output_list = []
  95. input_seq = torch.cuda.LongTensor(
  96. [seq + [tokenizer.get_command("sop")]],
  97. device=args.device,
  98. )
  99. output, _ = batch_filling_sequence(
  100. model,
  101. input_seq,
  102. torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
  103. strategy=strategy,
  104. get_masks_and_position_ids=partial(
  105. get_masks_and_position_ids,
  106. mask_position=mask_position,
  107. max_gen_length=args.out_seq_length,
  108. gmask=use_gmask,
  109. position_encoding_2d=args.position_encoding_2d,
  110. ),
  111. )
  112. if isinstance(output, torch.Tensor): # different strategies
  113. output = output.tolist()
  114. output = output[0] # batch_size = 1
  115. output_list.extend(output)
  116. # clip -1s and fill back generated things into seq
  117. for i in range(len(output_list)):
  118. output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
  119. try:
  120. unfinished = output.index(-1)
  121. except ValueError:
  122. unfinished = len(output)
  123. if output[unfinished - 1] in strategy.end_tokens:
  124. unfinished -= 1
  125. bog = output.index(tokenizer.get_command("sop"))
  126. prefix = tokenizer.detokenize(output[last_pos[i] : mask_position])
  127. blank = tokenizer.detokenize(output[bog + 1 : unfinished])
  128. answers_with_style[i] += (
  129. prefix
  130. + (" " if is_english else "")
  131. + ("\033[4m" if use_gmask else "\x1b[0;32m\033[4m")
  132. + blank
  133. + ("\033[0m" if use_gmask else "\033[0m\x1b[0m")
  134. + (" " if is_english else "")
  135. )
  136. blanks[i].append(blank)
  137. last_pos[i] = mask_position + unfinished - (bog + 1)
  138. output_list[i] = output[:mask_position] + output[bog + 1 : unfinished] + output[mask_position + 1 : bog]
  139. for i, output in enumerate(output_list):
  140. if output[-1] == tokenizer.get_command("eos"):
  141. output = output[:-1]
  142. answers_with_style[i] += tokenizer.detokenize(output[last_pos[i] :])
  143. answers[i] = tokenizer.detokenize(output)
  144. return answers, answers_with_style, blanks
  145. def main():
  146. args = initialize(extra_args_provider=add_generation_specific_args)
  147. model, tokenizer = initialize_model_and_tokenizer(args)
  148. end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
  149. if args.sampling_strategy == "BaseStrategy":
  150. strategy = BaseStrategy(
  151. batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens
  152. )
  153. elif args.sampling_strategy == "BeamSearchStrategy":
  154. strategy = BeamSearchStrategy(
  155. 1,
  156. args.num_beams,
  157. length_penalty=args.length_penalty,
  158. consider_end=True,
  159. end_tokens=end_tokens,
  160. no_repeat_ngram_size=args.no_repeat_ngram_size,
  161. min_gen_length=args.min_gen_length,
  162. )
  163. else:
  164. raise ValueError(f"unknown strategy {args.sampling_strategy}")
  165. def process(raw_text):
  166. if args.with_id:
  167. query_id, raw_text = raw_text.split("\t")
  168. answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy, args)
  169. # save
  170. if args.with_id:
  171. full_path = os.path.join(args.output_path, query_id + ".txt")
  172. else:
  173. prefix = raw_text.replace("/", "")[:20]
  174. full_path = timed_name(prefix, ".txt", args.output_path)
  175. if mpu.get_model_parallel_rank() == 0:
  176. if args.print_all_beams and len(answers) > 1:
  177. for idx, answer_with_style in enumerate(answers_with_style):
  178. print(f"Output beam {idx}:", answer_with_style) # print the first.
  179. if len(answer_with_style) > 120:
  180. print("")
  181. else:
  182. print(f"Output:", answers_with_style[0]) # print the first.
  183. with open(full_path, "w", encoding="utf-8") as fout:
  184. for answer in answers:
  185. fout.write(answer + "\n")
  186. os.chmod(full_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)
  187. os.makedirs(args.output_path, exist_ok=True)
  188. generate_continually(process, args.input_source)
  189. if __name__ == "__main__":
  190. with torch.no_grad():
  191. main()