# generate.py
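"""Blank-filling generation script.

Reads prompts that may contain [MASK] (short blank), [sMASK] (sentence blank),
or [gMASK] (left-to-right generation) tokens, fills each blank with the model,
prints the styled result on model-parallel rank 0, and writes the raw answers
to `args.output_path`.
"""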

import os
import torch
import stat
import re

from functools import partial
from typing import List, Tuple

from SwissArmyTransformer import mpu
from evaluation.model import batch_filling_sequence
from generation import BeamSearchStrategy, BaseStrategy
from SwissArmyTransformer.generation.utils import timed_name, generate_continually
from initialize import initialize, initialize_model_and_tokenizer


def add_generation_specific_args(parser):
    parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy", help="Type of sampling strategy.")
    parser.add_argument("--min-gen-length", type=int, default=0, help="The minimum length each blank should generate.")
    parser.add_argument(
        "--print-all-beams", action="store_true", help="Print all output generated by beam search strategy."
    )


def isEnglish(s):
    """Heuristic: treat the text as English iff it is pure ASCII."""
    try:
        s.encode(encoding="utf-8").decode("ascii")
    except UnicodeDecodeError:
        return False
    else:
        return True
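
# For example (illustrative): isEnglish("hello") -> True, isEnglish("你好") -> False.
# The result only decides whether spaces are inserted around filled blanks below.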


def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False, position_encoding_2d=False):
    context_length = seq.shape[1]
    # Reserve room for generation; -1 marks positions not yet generated.
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)
    # Causal mask, except that the context (minus its last token) is fully visible.
    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
    attention_mask.tril_()
    attention_mask[..., : context_length - 1] = 1
    attention_mask.unsqueeze_(1)
    attention_mask = (attention_mask < 0.5).bool()

    if position_encoding_2d:
        # First row: absolute positions, frozen at mask_position inside the blank.
        position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
        position_ids[context_length - 1 :] = mask_position
        # Second row: position within the blank, counting up from zero.
        block_position_ids = torch.cat(
            (
                torch.zeros(context_length - 2, dtype=torch.long, device=tokens.device),
                torch.arange(tokens.shape[-1] - (context_length - 2), dtype=torch.long, device=tokens.device),
            )
        )
        position_ids = torch.vstack((position_ids, block_position_ids))
    else:
        position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
        if not gmask:
            # For [MASK]/[sMASK], generated tokens reuse the position of the mask.
            position_ids[context_length - 1 :] = mask_position

    position_ids = position_ids.unsqueeze(0)
    return tokens, attention_mask, position_ids
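
# A minimal shape sketch (illustrative toy values, runs on CPU):
#   >>> seq = torch.zeros((1, 4), dtype=torch.long)
#   >>> tokens, mask, pos = get_masks_and_position_ids(seq, mask_position=1, max_gen_length=3)
#   >>> tokens.shape, mask.shape, pos.shape
#   (torch.Size([1, 7]), torch.Size([1, 1, 7, 7]), torch.Size([1, 7]))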


def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
    # Decide which mask token drives generation; [gMASK] is the default.
    generation_mask = "[gMASK]"
    if "[MASK]" in raw_text:
        generation_mask = "[MASK]"
    elif "[sMASK]" in raw_text:
        generation_mask = "[sMASK]"
    use_gmask = "[MASK]" not in raw_text and "[sMASK]" not in raw_text

    # Tokenize the text around each mask and splice in the mask commands.
    mask_pattern = r"\[[sg]?MASK\]"
    text_list = re.split(mask_pattern, raw_text)
    pattern_list = re.compile(mask_pattern).findall(raw_text)
    seq = []
    for i in range(len(pattern_list)):
        pattern = pattern_list[i]
        sub_text = text_list[i]
        seq.extend(tokenizer.tokenize(sub_text))
        seq.append(tokenizer.get_command(pattern))
    seq.extend(tokenizer.tokenize(text_list[-1]))

    # No mask at all: append [gMASK] for open-ended left-to-right generation.
    if "MASK]" not in raw_text:
        seq += [tokenizer.get_command(generation_mask)]
        raw_text += " " + generation_mask
    if not raw_text.endswith("MASK]"):
        seq = seq + [tokenizer.get_command("eos")]
    if mpu.get_model_parallel_rank() == 0:
        print("\nInput: {}\n".format(raw_text))
    if len(seq) > args.max_sequence_length:  # note: relies on the module-level `args`
        raise ValueError("text too long.")
    # generation
    is_english = isEnglish(raw_text)
    output_list = [seq]
    num_output = args.num_beams if args.sampling_strategy == "BeamSearchStrategy" else 1
    last_pos, answers, answers_with_style, blanks = (
        [0] * num_output,
        ["" for _ in range(num_output)],
        ["" for _ in range(num_output)],
        [[] for _ in range(num_output)],
    )
    # Continually detect the first remaining mask position and fill it.
    while True:
        seq = output_list[0]
        # detect mask position
        mask_token = tokenizer.get_command(generation_mask)
        if mask_token not in seq:
            break
        mask_position = seq.index(mask_token)

        output_list = []

        # Append the "sop" (start-of-piece) token and run the filling loop.
        input_seq = torch.cuda.LongTensor(
            [seq + [tokenizer.get_command("sop")]],
            device=args.device,
        )
        output, _ = batch_filling_sequence(
            model,
            input_seq,
            torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
            strategy=strategy,
            get_masks_and_position_ids=partial(
                get_masks_and_position_ids,
                mask_position=mask_position,
                max_gen_length=args.out_seq_length,
                gmask=use_gmask,
                position_encoding_2d=args.position_encoding_2d,
            ),
        )
        if isinstance(output, torch.Tensor):  # different strategies return tensor or list
            output = output.tolist()
        output = output[0]  # batch_size = 1
        output_list.extend(output)
        # Clip the -1 padding and splice the generated tokens back into seq.
        for i in range(len(output_list)):
            output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
            try:
                unfinished = output.index(-1)
            except ValueError:
                unfinished = len(output)
            if output[unfinished - 1] in strategy.end_tokens:
                unfinished -= 1
            bog = output.index(tokenizer.get_command("sop"))  # beginning of generation

            prefix = tokenizer.detokenize(output[last_pos[i] : mask_position])
            blank = tokenizer.detokenize(output[bog + 1 : unfinished])
            # Style the filled blank: underline for [gMASK], green underline otherwise.
            answers_with_style[i] += (
                prefix
                + (" " if is_english else "")
                + ("\033[4m" if use_gmask else "\x1b[0;32m\033[4m")
                + blank
                + ("\033[0m" if use_gmask else "\033[0m\x1b[0m")
                + (" " if is_english else "")
            )
            blanks[i].append(blank)
            last_pos[i] = mask_position + unfinished - (bog + 1)
            # Replace the mask with the generated span for the next iteration.
            output_list[i] = output[:mask_position] + output[bog + 1 : unfinished] + output[mask_position + 1 : bog]

    for i, output in enumerate(output_list):
        if output[-1] == tokenizer.get_command("eos"):
            output = output[:-1]
        answers_with_style[i] += tokenizer.detokenize(output[last_pos[i] :])
        answers[i] = tokenizer.detokenize(output)

    return answers, answers_with_style, blanks
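
# Illustrative prompt shapes accepted by fill_blanks (examples, not from this file):
#   "Beijing is the [MASK] of China."            -> short in-text blank
#   "The sky is blue. [sMASK] Grass is green."   -> sentence-level blank
#   "Ask a question about physics:"              -> no mask; [gMASK] is appended
#                                                   and the model continues rightward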


def main(args):
    model, tokenizer = initialize_model_and_tokenizer(args)

    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]

    if args.sampling_strategy == "BaseStrategy":
        strategy = BaseStrategy(
            batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens
        )
    elif args.sampling_strategy == "BeamSearchStrategy":
        strategy = BeamSearchStrategy(
            1,  # batch_size
            args.num_beams,
            length_penalty=args.length_penalty,
            consider_end=True,
            end_tokens=end_tokens,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            min_gen_length=args.min_gen_length,
        )
    else:
        raise ValueError(f"unknown strategy {args.sampling_strategy}")
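
    # Illustrative flag combinations (flag names inferred from the args.* fields
    # used above; the full parser is defined elsewhere in the repo):
    #   --sampling-strategy BaseStrategy --temperature 0.9 --top-k 40 --top-p 0.7
    #   --sampling-strategy BeamSearchStrategy --num-beams 4 --length-penalty 1.0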

    def process(raw_text):
        if args.with_id:
            query_id, raw_text = raw_text.split("\t", 1)  # split only on the first tab

        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy)

        # Save the answers, one file per query.
        if args.with_id:
            full_path = os.path.join(args.output_path, query_id + ".txt")
        else:
            prefix = raw_text.replace("/", "")[:20]
            full_path = timed_name(prefix, ".txt", args.output_path)
        if mpu.get_model_parallel_rank() == 0:
            if args.print_all_beams and len(answers) > 1:
                for idx, answer_with_style in enumerate(answers_with_style):
                    print(f"Output beam {idx}:", answer_with_style)
                    if len(answer_with_style) > 120:
                        print("")
            else:
                print("Output:", answers_with_style[0])  # print the first beam only
            with open(full_path, "w", encoding="utf-8") as fout:
                for answer in answers:
                    fout.write(answer + "\n")
            os.chmod(full_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)

    os.makedirs(args.output_path, exist_ok=True)
    generate_continually(process, args.input_source)
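

# Note (behavior of SwissArmyTransformer's generate_continually, assumed here):
# prompts are taken one per line from `args.input_source` (or typed interactively);
# with --with-id each line is expected to look like "<query_id>\t<raw text>".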


if __name__ == "__main__":
    args = initialize(extra_args_provider=add_generation_specific_args)
    with torch.no_grad():
        main(args)