import time

import torch
import torch.distributed as dist
import gradio as gr

from generation import BeamSearchStrategy, BaseStrategy
from initialize import initialize, initialize_model_and_tokenizer
from generate import add_generation_specific_args, fill_blanks
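
# Wrap a single generation call: reject empty input, print the elapsed time on
# rank 0, and turn ValueError/FileNotFoundError into an error message for the UI.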
def generate_continually(func, raw_text):
    if not raw_text:
        return "Input should not be empty!"
    try:
        start_time = time.time()
        answer = func(raw_text)
        if torch.distributed.get_rank() == 0:
            print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
        return answer
    except (ValueError, FileNotFoundError) as e:
        print(e)
        return "Error!"
strategy = None

def main(args):
    model, tokenizer = initialize_model_and_tokenizer(args)
    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
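    # Fill in the [MASK]/[gMASK] blanks for one prompt with the current strategy
    # and return the top answer; only rank 0 prints the full candidate list.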
    def process(raw_text):
        global strategy
        if args.with_id:
            query_id, raw_text = raw_text.split("\t")
        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy, args)
        if torch.distributed.get_rank() == 0:
            print(answers)
        return answers[0]
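    # Gradio callback invoked by the "Generate" button on rank 0; the other
    # ranks call it with the broadcast parameters so generation stays in sync.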
    def predict(
        text,
        seed=1234,
        out_seq_length=200,
        min_gen_length=20,
        sampling_strategy="BaseStrategy",
        num_beams=4,
        length_penalty=0.9,
        no_repeat_ngram_size=3,
        temperature=1,
        topk=1,
        topp=1,
    ):
        global strategy
        if torch.distributed.get_rank() == 0:
            print(
                "info",
                [
                    text,
                    seed,
                    out_seq_length,
                    min_gen_length,
                    sampling_strategy,
                    num_beams,
                    length_penalty,
                    no_repeat_ngram_size,
                    temperature,
                    topk,
                    topp,
                ],
            )
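        # Broadcast the request parameters from rank 0 so every model-parallel
        # rank decodes with identical settings.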
        dist.broadcast_object_list(
            [
                text,
                seed,
                out_seq_length,
                min_gen_length,
                sampling_strategy,
                num_beams,
                length_penalty,
                no_repeat_ngram_size,
                temperature,
                topk,
                topp,
            ],
            src=0,
        )
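        # Copy the UI values onto args; fill_blanks and the strategies read them from there.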
        args.seed = seed
        args.out_seq_length = out_seq_length
        args.min_gen_length = min_gen_length
        args.sampling_strategy = sampling_strategy
        args.num_beams = num_beams
        args.length_penalty = length_penalty
        args.no_repeat_ngram_size = no_repeat_ngram_size
        args.temperature = temperature
        args.top_k = topk
        args.top_p = topp
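        # Rebuild the decoding strategy for this request: temperature/top-k/top-p
        # sampling (BaseStrategy) or beam search (BeamSearchStrategy).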
        if args.sampling_strategy == "BaseStrategy":
            strategy = BaseStrategy(
                batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens
            )
        elif args.sampling_strategy == "BeamSearchStrategy":
            strategy = BeamSearchStrategy(
                batch_size=1,
                num_beams=args.num_beams,
                length_penalty=args.length_penalty,
                consider_end=True,
                end_tokens=end_tokens,
                no_repeat_ngram_size=args.no_repeat_ngram_size,
                min_gen_length=args.min_gen_length,
            )
        else:
            raise ValueError(f"unknown strategy {args.sampling_strategy}")

        return generate_continually(process, text)
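    # Rank 0 builds and serves the Gradio UI; all other ranks drop into the
    # broadcast loop in the else branch below.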
    if torch.distributed.get_rank() == 0:
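        # Example prompts: [MASK] for short blank filling, [gMASK] for open-ended
        # generation, with English and Chinese variants for the bilingual model.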
        en_fil = ["The Starry Night is an oil-on-canvas painting by [MASK] in June 1889."]
        en_gen = ["Eight planets in solar system are [gMASK]"]
        ch_fil = ["凯旋门位于意大利米兰市古城堡旁。1807年为纪念[MASK]而建,门高25米,顶上矗立两武士青铜古兵车铸像。"]
        ch_gen = ["三亚位于海南岛的最南端,是中国最南部的热带滨海旅游城市 [gMASK]"]
        en_to_ch = ["Pencil in Chinese is [MASK]."]
        ch_to_en = ['"我思故我在"的英文是"[MASK]"。']
        examples = [en_fil, en_gen, ch_fil, ch_gen, en_to_ch, ch_to_en]
        with gr.Blocks() as demo:
            gr.Markdown(
                """
                An Open Bilingual Pre-Trained Model. [Visit our GitHub repo](https://github.com/THUDM/GLM-130B)

                GLM-130B uses two different mask tokens: `[MASK]` for short blank filling and `[gMASK]` for left-to-right long text generation. If the input contains no mask token, `[gMASK]` is automatically appended to the end of the text. We recommend starting with `[MASK]` fill-in-the-blank prompts to keep wait times short (ideally a few seconds, without queuing).
                """
            )
            with gr.Row():
                with gr.Column():
                    model_input = gr.Textbox(
                        lines=7, placeholder="Input something in English or Chinese", label="Input"
                    )
                    with gr.Row():
                        gen = gr.Button("Generate")
                        clr = gr.Button("Clear")
                outputs = gr.Textbox(lines=7, label="Output")
            gr.Markdown(
                """
                Generation Parameters
                """
            )
            with gr.Row():
                with gr.Column():
                    seed = gr.Slider(maximum=100000, value=1234, step=1, label="Seed")
                    out_seq_length = gr.Slider(
                        maximum=512, value=128, minimum=32, step=1, label="Output Sequence Length"
                    )
                with gr.Column():
                    min_gen_length = gr.Slider(maximum=64, value=0, step=1, label="Min Generate Length")
                    sampling_strategy = gr.Radio(
                        choices=["BeamSearchStrategy", "BaseStrategy"], value="BaseStrategy", label="Search Strategy"
                    )
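            # Strategy-specific knobs: beam-search parameters in the left column,
            # sampling parameters in the right column.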
            with gr.Row():
                with gr.Column():
                    # beam search
                    gr.Markdown(
                        """
                        BeamSearchStrategy
                        """
                    )
                    num_beams = gr.Slider(maximum=4, value=2, minimum=1, step=1, label="Number of Beams")
                    length_penalty = gr.Slider(maximum=1, value=1, minimum=0, label="Length Penalty")
                    no_repeat_ngram_size = gr.Slider(
                        maximum=5, value=3, minimum=1, step=1, label="No Repeat Ngram Size"
                    )
                with gr.Column():
                    # base search
                    gr.Markdown(
                        """
                        BaseStrategy
                        """
                    )
                    temperature = gr.Slider(maximum=1, value=1.0, minimum=0, label="Temperature")
                    topk = gr.Slider(maximum=40, value=0, minimum=0, step=1, label="Top K")
                    topp = gr.Slider(maximum=1, value=0.7, minimum=0, label="Top P")

            inputs = [
                model_input,
                seed,
                out_seq_length,
                min_gen_length,
                sampling_strategy,
                num_beams,
                length_penalty,
                no_repeat_ngram_size,
                temperature,
                topk,
                topp,
            ]
            gen.click(fn=predict, inputs=inputs, outputs=outputs)
            clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=model_input)
            gr_examples = gr.Examples(examples=examples, inputs=model_input)
        demo.launch(share=True)
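    # Worker ranks: block on the broadcast from rank 0, then run the same
    # predict() call so model-parallel generation stays synchronized.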
    else:
        while True:
            info = [None, None, None, None, None, None, None, None, None, None, None]
            dist.broadcast_object_list(info, src=0)
            (
                text,
                seed,
                out_seq_length,
                min_gen_length,
                sampling_strategy,
                num_beams,
                length_penalty,
                no_repeat_ngram_size,
                temperature,
                topk,
                topp,
            ) = info
            predict(
                text,
                seed,
                out_seq_length,
                min_gen_length,
                sampling_strategy,
                num_beams,
                length_penalty,
                no_repeat_ngram_size,
                temperature,
                topk,
                topp,
            )
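
# Entry point: parse the generation-specific command-line arguments, then run
# the demo with gradients disabled (inference only).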
if __name__ == "__main__":
    args = initialize(extra_args_provider=add_generation_specific_args)
    with torch.no_grad():
        main(args)