
Fix gradio server

Sengxian 2 years ago
parent
commit
e260237347
2 changed files with 9 additions and 19 deletions
  1. generate.py  +7 -7
  2. server.py    +2 -12

+ 7 - 7
generate.py

@@ -6,7 +6,7 @@ import re
 from functools import partial
 from typing import List, Tuple
 
-from SwissArmyTransformer import mpu
+from SwissArmyTransformer import mpu, get_args
 from evaluation.model import batch_filling_sequence
 from generation import BeamSearchStrategy, BaseStrategy
 from SwissArmyTransformer.generation.utils import timed_name, generate_continually
@@ -59,7 +59,7 @@ def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False,
     return tokens, attention_mask, position_ids
 
 
-def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
+def fill_blanks(raw_text: str, model, tokenizer, strategy, args) -> Tuple[List[str], List[str], List[List[str]]]:
     # add MASK
     generation_mask = "[gMASK]"
     if "[MASK]" in raw_text:
@@ -168,7 +168,9 @@ def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
     return answers, answers_with_style, blanks
 
 
-def main(args):
+def main():
+    args = initialize(extra_args_provider=add_generation_specific_args)
+
     model, tokenizer = initialize_model_and_tokenizer(args)
 
     end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
@@ -194,7 +196,7 @@ def main(args):
         if args.with_id:
             query_id, raw_text = raw_text.split("\t")
 
-        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy)
+        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy, args)
 
         # save
         if args.with_id:
@@ -221,7 +223,5 @@ def main(args):
 
 
 if __name__ == "__main__":
-    args = initialize(extra_args_provider=add_generation_specific_args)
-
     with torch.no_grad():
-        main(args)
+        main()
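
The net effect of the generate.py change is a calling-convention fix: main() now builds args itself (previously it was created in the __main__ guard and passed in), and fill_blanks receives args as an explicit parameter. Below is a minimal, self-contained sketch of that pattern with toy stand-ins, not GLM-130B code; the real fill_blanks also takes model, tokenizer and strategy.

    # Toy illustration of the pattern this commit adopts: build args inside
    # main() and pass it to the helper explicitly.
    from argparse import Namespace
    from typing import List, Tuple


    def fill_blanks_demo(raw_text: str, args: Namespace) -> Tuple[List[str], List[str]]:
        # args is an explicit parameter, mirroring fill_blanks(..., strategy, args)
        answers = [raw_text.replace("[MASK]", "world")]
        return answers, answers


    def main() -> None:
        # mirrors: args = initialize(extra_args_provider=add_generation_specific_args)
        args = Namespace(with_id=False)
        answers, _ = fill_blanks_demo("Hello [MASK]!", args)
        print(answers)


    if __name__ == "__main__":
        main()
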

+ 2 - 12
server.py

@@ -36,7 +36,7 @@ def main(args):
         if args.with_id:
             query_id, raw_text = raw_text.split("\t")
 
-        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy)
+        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy, args)
 
         if torch.distributed.get_rank() == 0:
             print(answers)
@@ -138,8 +138,6 @@ def main(args):
                 """
                 An Open Bilingual Pre-Trained Model. [Visit our github repo](https://github.com/THUDM/GLM-130B)
                 GLM-130B uses two different mask tokens: `[MASK]` for short blank filling and `[gMASK]` for left-to-right long text generation. When the input does not contain any MASK token, `[gMASK]` will be automatically appended to the end of the text. We recommend that you use `[MASK]` to try text fill-in-the-blank to reduce wait time (ideally within seconds without queuing).
-                
-                Note: We suspect that there is a bug in the current FasterTransformer INT4 implementation that leads to gaps in generations compared to the FP16 model (e.g. more repititions), which we are troubleshooting, and the current model output is **for reference only**
                 """
             )
 
@@ -191,7 +189,7 @@ def main(args):
                         BaseStrategy
                         """
                     )
-                    temperature = gr.Slider(maximum=1, value=0.7, minimum=0, label="Temperature")
+                    temperature = gr.Slider(maximum=1, value=1.0, minimum=0, label="Temperature")
                     topk = gr.Slider(maximum=40, value=0, minimum=0, step=1, label="Top K")
                     topp = gr.Slider(maximum=1, value=0.7, minimum=0, label="Top P")
 
@@ -213,14 +211,6 @@ def main(args):
 
             gr_examples = gr.Examples(examples=examples, inputs=model_input)
 
-            gr.Markdown(
-                """
-                Disclaimer inspired from [BLOOM](https://huggingface.co/spaces/bigscience/bloom-book)
-                
-                GLM-130B was trained on web-crawled data, so it's hard to predict how GLM-130B will respond to particular prompts; harmful or otherwise offensive content may occur without warning. We prohibit users from knowingly generating or allowing others to knowingly generate harmful content, including Hateful, Harassment, Violence, Adult, Political, Deception, etc. 
-                """
-            )
-
         demo.launch(share=True)
     else:
         while True: