server.py

import time

import torch
import torch.distributed as dist
import gradio as gr

from generation import BeamSearchStrategy, BaseStrategy
from initialize import initialize, initialize_model_and_tokenizer
from generate import add_generation_specific_args, fill_blanks
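
# Gradio demo server for GLM-130B: rank 0 serves the web UI and broadcasts each
# request's parameters; every other rank waits in a broadcast loop and replays
# predict() so that all distributed ranks run generation together.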


def generate_continually(func, raw_text):
    if not raw_text:
        return "Input should not be empty!"
    try:
        start_time = time.time()
        answer = func(raw_text)
        if torch.distributed.get_rank() == 0:
            print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
        return answer
    except (ValueError, FileNotFoundError) as e:
        print(e)
        return "Error!"
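

# Decoding strategy shared across requests; rebuilt by predict() on every call.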
strategy = None


def main(args):
    model, tokenizer = initialize_model_and_tokenizer(args)
    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]

    def process(raw_text):
        global strategy
        if args.with_id:
            query_id, raw_text = raw_text.split("\t")
        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy, args)
        if torch.distributed.get_rank() == 0:
            print(answers)
        return answers[0]
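
    # Invoked by Gradio on rank 0; the other ranks replay the same call (see
    # the while-loop below) once the parameters have been broadcast.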
    def predict(
        text,
        seed=1234,
        out_seq_length=200,
        min_gen_length=20,
        sampling_strategy="BaseStrategy",
        num_beams=4,
        length_penalty=0.9,
        no_repeat_ngram_size=3,
        temperature=1,
        topk=1,
        topp=1,
    ):
        global strategy
        if torch.distributed.get_rank() == 0:
            print(
                "info",
                [
                    text,
                    seed,
                    out_seq_length,
                    min_gen_length,
                    sampling_strategy,
                    num_beams,
                    length_penalty,
                    no_repeat_ngram_size,
                    temperature,
                    topk,
                    topp,
                ],
            )
            dist.broadcast_object_list(
                [
                    text,
                    seed,
                    out_seq_length,
                    min_gen_length,
                    sampling_strategy,
                    num_beams,
                    length_penalty,
                    no_repeat_ngram_size,
                    temperature,
                    topk,
                    topp,
                ],
                src=0,
            )
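        # Mirror the request parameters into args so fill_blanks() sees them.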
        args.seed = seed
        args.out_seq_length = out_seq_length
        args.min_gen_length = min_gen_length
        args.sampling_strategy = sampling_strategy
        args.num_beams = num_beams
        args.length_penalty = length_penalty
        args.no_repeat_ngram_size = no_repeat_ngram_size
        args.temperature = temperature
        args.top_k = topk
        args.top_p = topp
        if args.sampling_strategy == "BaseStrategy":
            strategy = BaseStrategy(
                batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens
            )
        elif args.sampling_strategy == "BeamSearchStrategy":
            strategy = BeamSearchStrategy(
                batch_size=1,
                num_beams=args.num_beams,
                length_penalty=args.length_penalty,
                consider_end=True,
                end_tokens=end_tokens,
                no_repeat_ngram_size=args.no_repeat_ngram_size,
                min_gen_length=args.min_gen_length,
            )
        else:
            raise ValueError(f"unknown strategy {args.sampling_strategy}")
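        # generate_continually() adds timing and error handling around process().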
        return generate_continually(process, text)

    if torch.distributed.get_rank() == 0:
        # Example prompts: *_fil are [MASK] blank-filling, *_gen are [gMASK]
        # long-text generation, in English (en_*) and Chinese (ch_*).
        en_fil = ["The Starry Night is an oil-on-canvas painting by [MASK] in June 1889."]
        en_gen = ["Eight planets in solar system are [gMASK]"]
        # "The triumphal arch stands beside the old castle in Milan, Italy. It was
        # built in 1807 to commemorate [MASK]; the gate is 25 meters high, topped
        # by bronze statues of two warriors and an ancient chariot."
        ch_fil = ["凯旋门位于意大利米兰市古城堡旁。1807年为纪念[MASK]而建,门高25米,顶上矗立两武士青铜古兵车铸像。"]
        # "Sanya lies at the southernmost tip of Hainan Island and is China's
        # southernmost tropical seaside tourist city [gMASK]"
        ch_gen = ["三亚位于海南岛的最南端,是中国最南部的热带滨海旅游城市 [gMASK]"]
        en_to_ch = ["Pencil in Chinese is [MASK]."]
        # 'The English for "我思故我在" ("I think, therefore I am") is "[MASK]".'
        ch_to_en = ['"我思故我在"的英文是"[MASK]"。']
        examples = [en_fil, en_gen, ch_fil, ch_gen, en_to_ch, ch_to_en]

        with gr.Blocks() as demo:
            gr.Markdown(
                """
                An Open Bilingual Pre-Trained Model. [Visit our github repo](https://github.com/THUDM/GLM-130B)

                GLM-130B uses two different mask tokens: `[MASK]` for short blank filling and `[gMASK]` for left-to-right long text generation. When the input does not contain any MASK token, `[gMASK]` will be automatically appended to the end of the text. We recommend that you use `[MASK]` to try text fill-in-the-blank to reduce wait time (ideally within seconds without queuing).
                """
            )
            with gr.Row():
                with gr.Column():
                    model_input = gr.Textbox(
                        lines=7, placeholder="Input something in English or Chinese", label="Input"
                    )
                    with gr.Row():
                        gen = gr.Button("Generate")
                        clr = gr.Button("Clear")
                outputs = gr.Textbox(lines=7, label="Output")
            gr.Markdown(
                """
                Generation Parameters
                """
            )
            with gr.Row():
                with gr.Column():
                    seed = gr.Slider(maximum=100000, value=1234, step=1, label="Seed")
                    out_seq_length = gr.Slider(
                        maximum=512, value=128, minimum=32, step=1, label="Output Sequence Length"
                    )
                with gr.Column():
                    min_gen_length = gr.Slider(maximum=64, value=0, step=1, label="Min Generate Length")
                    sampling_strategy = gr.Radio(
                        choices=["BeamSearchStrategy", "BaseStrategy"], value="BaseStrategy", label="Search Strategy"
                    )
            with gr.Row():
                with gr.Column():
                    # beam search
                    gr.Markdown(
                        """
                        BeamSearchStrategy
                        """
                    )
                    num_beams = gr.Slider(maximum=4, value=2, minimum=1, step=1, label="Number of Beams")
                    length_penalty = gr.Slider(maximum=1, value=1, minimum=0, label="Length Penalty")
                    no_repeat_ngram_size = gr.Slider(
                        maximum=5, value=3, minimum=1, step=1, label="No Repeat Ngram Size"
                    )
                with gr.Column():
                    # base search
                    gr.Markdown(
                        """
                        BaseStrategy
                        """
                    )
                    temperature = gr.Slider(maximum=1, value=1.0, minimum=0, label="Temperature")
                    topk = gr.Slider(maximum=40, value=0, minimum=0, step=1, label="Top K")
                    topp = gr.Slider(maximum=1, value=0.7, minimum=0, label="Top P")
            inputs = [
                model_input,
                seed,
                out_seq_length,
                min_gen_length,
                sampling_strategy,
                num_beams,
                length_penalty,
                no_repeat_ngram_size,
                temperature,
                topk,
                topp,
            ]
            gen.click(fn=predict, inputs=inputs, outputs=outputs)
            clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=model_input)
            gr_examples = gr.Examples(examples=examples, inputs=model_input)
        demo.launch(share=True)
    else:
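        # Worker ranks: block until rank 0 broadcasts a request, then replay
        # predict() locally so every rank runs the model in lockstep.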
        while True:
            info = [None] * 11  # placeholders for the 11 broadcast request fields
            dist.broadcast_object_list(info, src=0)
            (
                text,
                seed,
                out_seq_length,
                min_gen_length,
                sampling_strategy,
                num_beams,
                length_penalty,
                no_repeat_ngram_size,
                temperature,
                topk,
                topp,
            ) = info
            predict(
                text,
                seed,
                out_seq_length,
                min_gen_length,
                sampling_strategy,
                num_beams,
                length_penalty,
                no_repeat_ngram_size,
                temperature,
                topk,
                topp,
            )


if __name__ == "__main__":
    args = initialize(extra_args_provider=add_generation_specific_args)
    with torch.no_grad():
        main(args)