@@ -161,10 +161,9 @@ def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], L
     return answers, answers_with_style, blanks


-
 def generate_continually(func, raw_text):
     if not raw_text:
-        return 'Input should not be empty!'
+        return "Input should not be empty!"
     try:
         start_time = time.time()
         answer = func(raw_text)
@@ -173,10 +172,12 @@ def generate_continually(func, raw_text):
         return answer
     except (ValueError, FileNotFoundError) as e:
         print(e)
-        return 'Error!'
+        return "Error!"
+


 strategy = None

+
 def main(args):
     model, tokenizer = initialize_model_and_tokenizer(args)
@@ -195,16 +196,55 @@ def main(args):
         return answers[0]

-
-    def predict(text, seed=1234, out_seq_length=200, min_gen_length=20, sampling_strategy='BaseStrategy',
-                num_beams=4, length_penalty=0.9, no_repeat_ngram_size=3,
-                temperature=1, topk=1, topp=1):
+    def predict(
+        text,
+        seed=1234,
+        out_seq_length=200,
+        min_gen_length=20,
+        sampling_strategy="BaseStrategy",
+        num_beams=4,
+        length_penalty=0.9,
+        no_repeat_ngram_size=3,
+        temperature=1,
+        topk=1,
+        topp=1,
+    ):
         global strategy

         if torch.distributed.get_rank() == 0:
-            print('info', [text, seed, out_seq_length, min_gen_length, sampling_strategy, num_beams, length_penalty, no_repeat_ngram_size, temperature, topk, topp])
-            dist.broadcast_object_list([text, seed, out_seq_length, min_gen_length, sampling_strategy, num_beams, length_penalty, no_repeat_ngram_size, temperature, topk, topp], src=0)
+            print(
+                "info",
+                [
+                    text,
+                    seed,
+                    out_seq_length,
+                    min_gen_length,
+                    sampling_strategy,
+                    num_beams,
+                    length_penalty,
+                    no_repeat_ngram_size,
+                    temperature,
+                    topk,
+                    topp,
+                ],
+            )
+            dist.broadcast_object_list(
+                [
+                    text,
+                    seed,
+                    out_seq_length,
+                    min_gen_length,
+                    sampling_strategy,
+                    num_beams,
+                    length_penalty,
+                    no_repeat_ngram_size,
+                    temperature,
+                    topk,
+                    topp,
+                ],
+                src=0,
+            )

         args.seed = seed
         args.out_seq_length = out_seq_length
@@ -237,11 +277,11 @@ def main(args):
         return generate_continually(process, text)

     if torch.distributed.get_rank() == 0:
-        en_fil = ['The Starry Night is an oil-on-canvas painting by [MASK] in June 1889.']
-        en_gen = ['Eight planets in solar system are [gMASK]']
-        ch_fil = ['凯旋门位于意大利米兰市古城堡旁。1807年为纪念[MASK]而建,门高25米,顶上矗立两武士青铜古兵车铸像。']
-        ch_gen = ['三亚位于海南岛的最南端,是中国最南部的热带滨海旅游城市 [gMASK]']
-        en_to_ch = ['Pencil in Chinese is [MASK].']
+        en_fil = ["The Starry Night is an oil-on-canvas painting by [MASK] in June 1889."]
+        en_gen = ["Eight planets in solar system are [gMASK]"]
+        ch_fil = ["凯旋门位于意大利米兰市古城堡旁。1807年为纪念[MASK]而建,门高25米,顶上矗立两武士青铜古兵车铸像。"]
+        ch_gen = ["三亚位于海南岛的最南端,是中国最南部的热带滨海旅游城市 [gMASK]"]
+        en_to_ch = ["Pencil in Chinese is [MASK]."]
         ch_to_en = ['"我思故我在"的英文是"[MASK]"。']

         examples = [en_fil, en_gen, ch_fil, ch_gen, en_to_ch, ch_to_en]
@@ -253,28 +293,36 @@ def main(args):
                 GLM-130B uses two different mask tokens: `[MASK]` for short blank filling and `[gMASK]` for left-to-right long text generation. When the input does not contain any MASK token, `[gMASK]` will be automatically appended to the end of the text. We recommend that you use `[MASK]` to try text fill-in-the-blank to reduce wait time (ideally within seconds without queuing).

                 Note: We suspect that there is a bug in the current FasterTransformer INT4 implementation that leads to gaps in generations compared to the FP16 model (e.g. more repititions), which we are troubleshooting, and the current model output is **for reference only**
-                """)
+                """
+            )

             with gr.Row():
                 with gr.Column():
-                    model_input = gr.Textbox(lines=7, placeholder='Input something in English or Chinese', label='Input')
+                    model_input = gr.Textbox(
+                        lines=7, placeholder="Input something in English or Chinese", label="Input"
+                    )
                     with gr.Row():
                         gen = gr.Button("Generate")
                         clr = gr.Button("Clear")
-
-                outputs = gr.Textbox(lines=7, label='Output')
-
+
+                outputs = gr.Textbox(lines=7, label="Output")
+
             gr.Markdown(
                 """
                 Generation Parameter
-                """)
+                """
+            )
             with gr.Row():
                 with gr.Column():
-                    seed = gr.Slider(maximum=100000, value=1234, step=1, label='Seed')
-                    out_seq_length = gr.Slider(maximum=512, value=128, minimum=32, step=1, label='Output Sequence Length')
+                    seed = gr.Slider(maximum=100000, value=1234, step=1, label="Seed")
+                    out_seq_length = gr.Slider(
+                        maximum=512, value=128, minimum=32, step=1, label="Output Sequence Length"
+                    )
                 with gr.Column():
-                    min_gen_length = gr.Slider(maximum=64, value=0, step=1, label='Min Generate Length')
-                    sampling_strategy = gr.Radio(choices=['BeamSearchStrategy', 'BaseStrategy'], value='BaseStrategy', label='Search Strategy')
+                    min_gen_length = gr.Slider(maximum=64, value=0, step=1, label="Min Generate Length")
+                    sampling_strategy = gr.Radio(
+                        choices=["BeamSearchStrategy", "BaseStrategy"], value="BaseStrategy", label="Search Strategy"
+                    )

             with gr.Row():
                 with gr.Column():
@@ -282,32 +330,49 @@ def main(args):
                     gr.Markdown(
                         """
                         BeamSearchStrategy
-                        """)
-                    num_beams = gr.Slider(maximum=4, value=2, minimum=1, step=1, label='Number of Beams')
-                    length_penalty = gr.Slider(maximum=1, value=1, minimum=0, label='Length Penalty')
-                    no_repeat_ngram_size = gr.Slider(maximum=5, value=3, minimum=1, step=1, label='No Repeat Ngram Size')
+                        """
+                    )
+                    num_beams = gr.Slider(maximum=4, value=2, minimum=1, step=1, label="Number of Beams")
+                    length_penalty = gr.Slider(maximum=1, value=1, minimum=0, label="Length Penalty")
+                    no_repeat_ngram_size = gr.Slider(
+                        maximum=5, value=3, minimum=1, step=1, label="No Repeat Ngram Size"
+                    )
                 with gr.Column():
                     # base search
                     gr.Markdown(
                         """
                         BaseStrategy
-                        """)
-                    temperature = gr.Slider(maximum=1, value=0.7, minimum=0, label='Temperature')
-                    topk = gr.Slider(maximum=40, value=0, minimum=0, step=1, label='Top K')
-                    topp = gr.Slider(maximum=1, value=0.7, minimum=0, label='Top P')
-
-            inputs = [model_input, seed, out_seq_length, min_gen_length, sampling_strategy, num_beams, length_penalty, no_repeat_ngram_size, temperature, topk, topp]
+                        """
+                    )
+                    temperature = gr.Slider(maximum=1, value=0.7, minimum=0, label="Temperature")
+                    topk = gr.Slider(maximum=40, value=0, minimum=0, step=1, label="Top K")
+                    topp = gr.Slider(maximum=1, value=0.7, minimum=0, label="Top P")
+
+            inputs = [
+                model_input,
+                seed,
+                out_seq_length,
+                min_gen_length,
+                sampling_strategy,
+                num_beams,
+                length_penalty,
+                no_repeat_ngram_size,
+                temperature,
+                topk,
+                topp,
+            ]
             gen.click(fn=predict, inputs=inputs, outputs=outputs)
             clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=model_input)
-
+
             gr_examples = gr.Examples(examples=examples, inputs=model_input)
-
+
             gr.Markdown(
                 """
                 Disclaimer inspired from [BLOOM](https://huggingface.co/spaces/bigscience/bloom-book)

                 GLM-130B was trained on web-crawled data, so it's hard to predict how GLM-130B will respond to particular prompts; harmful or otherwise offensive content may occur without warning. We prohibit users from knowingly generating or allowing others to knowingly generate harmful content, including Hateful, Harassment, Violence, Adult, Political, Deception, etc.
-                """)
+                """
+            )

         demo.launch(share=True)
     else:
@@ -315,11 +380,33 @@ def main(args):
             info = [None, None, None, None, None, None, None, None, None, None, None]
             dist.broadcast_object_list(info, src=0)
-            text, seed, out_seq_length, min_gen_length, sampling_strategy, num_beams, length_penalty, no_repeat_ngram_size, temperature, topk, topp = info
-
-            predict(text, seed, out_seq_length, min_gen_length, sampling_strategy,
-                    num_beams, length_penalty, no_repeat_ngram_size,
-                    temperature, topk, topp)
+            (
+                text,
+                seed,
+                out_seq_length,
+                min_gen_length,
+                sampling_strategy,
+                num_beams,
+                length_penalty,
+                no_repeat_ngram_size,
+                temperature,
+                topk,
+                topp,
+            ) = info
+
+            predict(
+                text,
+                seed,
+                out_seq_length,
+                min_gen_length,
+                sampling_strategy,
+                num_beams,
+                length_penalty,
+                no_repeat_ngram_size,
+                temperature,
+                topk,
+                topp,
+            )


 if __name__ == "__main__":
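
For context on the pattern these hunks reformat: `predict` runs as the Gradio callback on rank 0 only, and rank 0 re-broadcasts every UI parameter with `torch.distributed.broadcast_object_list` so the other model-parallel ranks can execute the same generation call. The sketch below is a minimal, self-contained illustration of that handshake, not the demo's actual code; the file name, `run_generation`, and the hard-coded request values are hypothetical. Launch with e.g. `torchrun --nproc_per_node=2 broadcast_sketch.py`.

```python
import torch.distributed as dist


def run_generation(text, seed, out_seq_length):
    # Stand-in for the real model call; every rank must reach it together.
    print(f"[rank {dist.get_rank()}] generating for {text!r} (seed={seed}, len={out_seq_length})")


def main():
    # torchrun sets RANK/WORLD_SIZE/MASTER_ADDR, so the default env:// init works.
    dist.init_process_group(backend="gloo")  # "nccl" when each rank holds a GPU shard

    if dist.get_rank() == 0:
        # Rank 0 would receive these values from the web UI; hard-coded here.
        request = ["Eight planets in solar system are [gMASK]", 1234, 128]
        dist.broadcast_object_list(request, src=0)
    else:
        # Other ranks block here until rank 0 broadcasts a request,
        # which fills this placeholder list in place.
        request = [None, None, None]
        dist.broadcast_object_list(request, src=0)

    run_generation(*request)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

In the demo itself the non-zero ranks wrap the receive-and-generate step in a loop so every button click is served, which is the branch the final hunk reformats.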