# server.py
import re
import time
from functools import partial
from typing import List, Tuple

import torch
import torch.distributed as dist
import gradio as gr

from SwissArmyTransformer import mpu
from evaluation.model import batch_filling_sequence
from generation import BeamSearchStrategy, BaseStrategy
from initialize import initialize, initialize_model_and_tokenizer


def add_generation_specific_args(parser):
    parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy", help="Type of sampling strategy.")
    parser.add_argument("--min-gen-length", type=int, default=0, help="The minimum length each blank should generate.")
    parser.add_argument(
        "--print-all-beams", action="store_true", help="Print all outputs generated by the beam search strategy."
    )


def isEnglish(s):
    """Heuristic language check: ASCII-only text is treated as English.

    Only used to decide whether the filled blanks need surrounding spaces."""
    try:
        s.encode(encoding="utf-8").decode("ascii")
    except UnicodeDecodeError:
        return False
    else:
        return True


def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
    context_length = seq.shape[1]
    # Reserve room for generation; -1 marks not-yet-generated slots.
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)
    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
    attention_mask.tril_()
    # Everything before the final [sop] token is context and attends bidirectionally;
    # the generated part stays causal.
    attention_mask[..., : context_length - 1] = 1
    attention_mask.unsqueeze_(1)
    attention_mask = (attention_mask < 0.5).bool()

    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
    if not gmask:
        # For [MASK]/[sMASK] infilling, every generated token keeps the
        # position id of the mask it fills.
        position_ids[context_length - 1 :] = mask_position

    position_ids = position_ids.unsqueeze(0)
    return tokens, attention_mask, position_ids
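
# A minimal illustration (not executed by the server) of what the helper above
# returns. The token ids are hypothetical; the last context token stands in for
# the [sop] token that fill_blanks appends before calling the model.
#
#   seq = torch.tensor([[11, 22, 33, 44]])          # context_length = 4
#   tokens, mask, pos = get_masks_and_position_ids(
#       seq, mask_position=1, max_gen_length=2, gmask=False)
#   # tokens -> [[11, 22, 33, 44, -1, -1]]          (-1 = slots to generate)
#   # pos    -> [[0, 1, 2, 1, 1, 1]]                (generated tokens reuse the
#   #                                                mask's position id)
#   # mask   -> True where attention is blocked: the first 3 tokens attend
#   #           bidirectionally, everything from [sop] onwards is causal.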


def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
    # add MASK
    generation_mask = "[gMASK]"
    if "[MASK]" in raw_text:
        generation_mask = "[MASK]"
    elif "[sMASK]" in raw_text:
        generation_mask = "[sMASK]"
    use_gmask = "[MASK]" not in raw_text and "[sMASK]" not in raw_text

    mask_pattern = r"\[[sg]?MASK\]"
    text_list = re.split(mask_pattern, raw_text)
    pattern_list = re.compile(mask_pattern).findall(raw_text)
    seq = []
    for i in range(len(pattern_list)):
        pattern = pattern_list[i]
        sub_text = text_list[i]
        seq.extend(tokenizer.tokenize(sub_text))
        seq.append(tokenizer.get_command(pattern))

    seq.extend(tokenizer.tokenize(text_list[-1]))

    if "MASK]" not in raw_text:
        # No mask token in the input: append [gMASK] for open-ended generation.
        seq += [tokenizer.get_command(generation_mask)]
        raw_text += " " + generation_mask
    if not raw_text.endswith("MASK]"):
        seq = seq + [tokenizer.get_command("eos")]
    if mpu.get_model_parallel_rank() == 0:
        print("\nInput: {}\n".format(raw_text))
    if len(seq) > args.max_sequence_length:
        raise ValueError("Text is too long ({} > {} tokens).".format(len(seq), args.max_sequence_length))

    # generation
    is_english = isEnglish(raw_text)
    output_list = [seq]
    num_output = args.num_beams if args.sampling_strategy == "BeamSearchStrategy" else 1
    last_pos, answers, answers_with_style, blanks = (
        [0] * num_output,
        ["" for _ in range(num_output)],
        ["" for _ in range(num_output)],
        [[] for _ in range(num_output)],
    )

    # continually detect the first mask position and fill the blanks one by one
    while True:
        seq = output_list[0]
        # detect mask position
        mask_token = tokenizer.get_command(generation_mask)
        if mask_token not in seq:
            break
        mask_position = seq.index(mask_token)

        output_list = []

        input_seq = torch.cuda.LongTensor(
            [seq + [tokenizer.get_command("sop")]],
            device=args.device,
        )
        output, _ = batch_filling_sequence(
            model,
            input_seq,
            torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
            strategy=strategy,
            get_masks_and_position_ids=partial(
                get_masks_and_position_ids,
                mask_position=mask_position,
                max_gen_length=args.out_seq_length,
                gmask=use_gmask,
            ),
        )
        if isinstance(output, torch.Tensor):  # different strategies return different types
            output = output.tolist()
        output = output[0]  # batch_size = 1
        output_list.extend(output)

        # clip -1s and fill the generated tokens back into seq
        for i in range(len(output_list)):
            output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
            try:
                unfinished = output.index(-1)
            except ValueError:
                unfinished = len(output)
            if output[unfinished - 1] in strategy.end_tokens:
                unfinished -= 1
            bog = output.index(tokenizer.get_command("sop"))  # beginning of generation

            prefix = tokenizer.detokenize(output[last_pos[i] : mask_position])
            blank = tokenizer.detokenize(output[bog + 1 : unfinished])
            # Underline the fill (and color it green for [MASK]/[sMASK]) via ANSI codes.
            answers_with_style[i] += (
                prefix
                + (" " if is_english else "")
                + ("\033[4m" if use_gmask else "\x1b[0;32m\033[4m")
                + blank
                + ("\033[0m" if use_gmask else "\033[0m\x1b[0m")
                + (" " if is_english else "")
            )
            blanks[i].append(blank)
            last_pos[i] = mask_position + unfinished - (bog + 1)
            # Splice the generated span into the sequence in place of the mask.
            output_list[i] = output[:mask_position] + output[bog + 1 : unfinished] + output[mask_position + 1 : bog]

    for i, output in enumerate(output_list):
        if output[-1] == tokenizer.get_command("eos"):
            output = output[:-1]
        answers_with_style[i] += tokenizer.detokenize(output[last_pos[i] :])
        answers[i] = tokenizer.detokenize(output)

    return answers, answers_with_style, blanks
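
# Illustrative usage (one of the demo prompts from below; actual outputs depend
# on the checkpoint and sampling strategy):
#
#   answers, styled, blanks = fill_blanks(
#       "The Starry Night is an oil-on-canvas painting by [MASK] in June 1889.",
#       model, tokenizer, strategy,
#   )
#   # answers[0]: the prompt with the blank filled in
#   # styled[0]:  the same text with the fill underlined via ANSI escape codes
#   # blanks[0]:  just the generated span(s), one entry per mask in the prompt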


def generate_continually(func, raw_text):
    if not raw_text:
        return "Input should not be empty!"
    try:
        start_time = time.time()
        answer = func(raw_text)
        if torch.distributed.get_rank() == 0:
            print("\nTime taken: {:.2f}s\n".format(time.time() - start_time), flush=True)
        return answer
    except (ValueError, FileNotFoundError) as e:
        print(e)
        return "Error!"


strategy = None


def main(args):
    model, tokenizer = initialize_model_and_tokenizer(args)
    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]

    def process(raw_text):
        global strategy
        if args.with_id:
            query_id, raw_text = raw_text.split("\t")
        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy)
        if torch.distributed.get_rank() == 0:
            print(answers)
        return answers[0]

    def predict(
        text,
        seed=1234,
        out_seq_length=200,
        min_gen_length=20,
        sampling_strategy="BaseStrategy",
        num_beams=4,
        length_penalty=0.9,
        no_repeat_ngram_size=3,
        temperature=1,
        topk=1,
        topp=1,
    ):
        global strategy
        if torch.distributed.get_rank() == 0:
            print(
                "info",
                [
                    text,
                    seed,
                    out_seq_length,
                    min_gen_length,
                    sampling_strategy,
                    num_beams,
                    length_penalty,
                    no_repeat_ngram_size,
                    temperature,
                    topk,
                    topp,
                ],
            )
            # Only rank 0 sends; the worker ranks receive these parameters in
            # their broadcast loop at the bottom of main() and then call
            # predict() themselves with the received values.
            dist.broadcast_object_list(
                [
                    text,
                    seed,
                    out_seq_length,
                    min_gen_length,
                    sampling_strategy,
                    num_beams,
                    length_penalty,
                    no_repeat_ngram_size,
                    temperature,
                    topk,
                    topp,
                ],
                src=0,
            )

        args.seed = seed
        args.out_seq_length = out_seq_length
        args.min_gen_length = min_gen_length
        args.sampling_strategy = sampling_strategy
        args.num_beams = num_beams
        args.length_penalty = length_penalty
        args.no_repeat_ngram_size = no_repeat_ngram_size
        args.temperature = temperature
        args.top_k = topk
        args.top_p = topp

        if args.sampling_strategy == "BaseStrategy":
            strategy = BaseStrategy(
                batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens
            )
        elif args.sampling_strategy == "BeamSearchStrategy":
            strategy = BeamSearchStrategy(
                batch_size=1,
                num_beams=args.num_beams,
                length_penalty=args.length_penalty,
                consider_end=True,
                end_tokens=end_tokens,
                no_repeat_ngram_size=args.no_repeat_ngram_size,
                min_gen_length=args.min_gen_length,
            )
        else:
            raise ValueError(f"unknown strategy {args.sampling_strategy}")

        return generate_continually(process, text)
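
    # Every rank must run predict() with identical parameters so that the
    # model-parallel collectives inside batch_filling_sequence stay in
    # lockstep: rank 0 reads them from the Gradio UI and broadcasts them,
    # while all other ranks block in the receive loop below.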
    if torch.distributed.get_rank() == 0:
        # Example prompts: English/Chinese fill-in-the-blank and free generation.
        en_fil = ["The Starry Night is an oil-on-canvas painting by [MASK] in June 1889."]
        en_gen = ["Eight planets in solar system are [gMASK]"]
        ch_fil = ["凯旋门位于意大利米兰市古城堡旁。1807年为纪念[MASK]而建,门高25米,顶上矗立两武士青铜古兵车铸像。"]
        ch_gen = ["三亚位于海南岛的最南端,是中国最南部的热带滨海旅游城市 [gMASK]"]
        en_to_ch = ["Pencil in Chinese is [MASK]."]
        ch_to_en = ['"我思故我在"的英文是"[MASK]"。']

        examples = [en_fil, en_gen, ch_fil, ch_gen, en_to_ch, ch_to_en]

        with gr.Blocks() as demo:
            gr.Markdown(
                """
                An Open Bilingual Pre-Trained Model. [Visit our github repo](https://github.com/THUDM/GLM-130B)

                GLM-130B uses two different mask tokens: `[MASK]` for short blank filling and `[gMASK]` for left-to-right long text generation. When the input contains no MASK token, `[gMASK]` is automatically appended to the end of the text. We recommend that you try `[MASK]` fill-in-the-blank first, to reduce wait time (ideally within seconds, without queuing).

                Note: We suspect that there is a bug in the current FasterTransformer INT4 implementation that leads to gaps in generation quality compared to the FP16 model (e.g. more repetitions), which we are troubleshooting; the current model output is **for reference only**.
                """
            )

            with gr.Row():
                with gr.Column():
                    model_input = gr.Textbox(
                        lines=7, placeholder="Input something in English or Chinese", label="Input"
                    )
                    with gr.Row():
                        gen = gr.Button("Generate")
                        clr = gr.Button("Clear")
                outputs = gr.Textbox(lines=7, label="Output")

            gr.Markdown(
                """
                Generation Parameters
                """
            )
            with gr.Row():
                with gr.Column():
                    seed = gr.Slider(maximum=100000, value=1234, step=1, label="Seed")
                    out_seq_length = gr.Slider(
                        maximum=512, value=128, minimum=32, step=1, label="Output Sequence Length"
                    )
                with gr.Column():
                    min_gen_length = gr.Slider(maximum=64, value=0, step=1, label="Min Generate Length")
                    sampling_strategy = gr.Radio(
                        choices=["BeamSearchStrategy", "BaseStrategy"], value="BaseStrategy", label="Search Strategy"
                    )

            with gr.Row():
                with gr.Column():
                    # beam search
                    gr.Markdown(
                        """
                        BeamSearchStrategy
                        """
                    )
                    num_beams = gr.Slider(maximum=4, value=2, minimum=1, step=1, label="Number of Beams")
                    length_penalty = gr.Slider(maximum=1, value=1, minimum=0, label="Length Penalty")
                    no_repeat_ngram_size = gr.Slider(
                        maximum=5, value=3, minimum=1, step=1, label="No Repeat Ngram Size"
                    )
                with gr.Column():
                    # base search
                    gr.Markdown(
                        """
                        BaseStrategy
                        """
                    )
                    temperature = gr.Slider(maximum=1, value=0.7, minimum=0, label="Temperature")
                    topk = gr.Slider(maximum=40, value=0, minimum=0, step=1, label="Top K")
                    topp = gr.Slider(maximum=1, value=0.7, minimum=0, label="Top P")

            inputs = [
                model_input,
                seed,
                out_seq_length,
                min_gen_length,
                sampling_strategy,
                num_beams,
                length_penalty,
                no_repeat_ngram_size,
                temperature,
                topk,
                topp,
            ]
            gen.click(fn=predict, inputs=inputs, outputs=outputs)
            clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=model_input)

            gr_examples = gr.Examples(examples=examples, inputs=model_input)

            gr.Markdown(
                """
                Disclaimer inspired by [BLOOM](https://huggingface.co/spaces/bigscience/bloom-book)

                GLM-130B was trained on web-crawled data, so it is hard to predict how GLM-130B will respond to particular prompts; harmful or otherwise offensive content may occur without warning. We prohibit users from knowingly generating, or allowing others to knowingly generate, harmful content, including hateful, harassing, violent, adult, political, or deceptive content.
                """
            )
        demo.launch(share=True)
    else:
        # Worker ranks: block until rank 0 broadcasts a new set of generation
        # parameters, then run the same predict() call in lockstep.
        while True:
            info = [None] * 11  # placeholders for the 11 broadcast parameters
            dist.broadcast_object_list(info, src=0)
            (
                text,
                seed,
                out_seq_length,
                min_gen_length,
                sampling_strategy,
                num_beams,
                length_penalty,
                no_repeat_ngram_size,
                temperature,
                topk,
                topp,
            ) = info
            predict(
                text,
                seed,
                out_seq_length,
                min_gen_length,
                sampling_strategy,
                num_beams,
                length_penalty,
                no_repeat_ngram_size,
                temperature,
                topk,
                topp,
            )


if __name__ == "__main__":
    args = initialize(extra_args_provider=add_generation_specific_args)
    with torch.no_grad():
        main(args)
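
# Note (assumption about deployment, not part of the original logic): this
# script expects torch.distributed to be initialized by initialize() with one
# process per model-parallel GPU, e.g. via torchrun; only rank 0 then serves
# the Gradio interface, while the remaining ranks participate in generation.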