Ver código fonte

Fix top_p argument in generate.py

Zhengxiao Du 3 anos atrás
pai
commit
3bb0f456d1
1 arquivos alterados com 2 adições e 1 exclusões
  1. 2 1
      generate.py

+ 2 - 1
generate.py

@@ -159,7 +159,8 @@ def main(args):
     end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
 
     if args.sampling_strategy == "BaseStrategy":
-        strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
+        strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
+                                end_tokens=end_tokens)
     elif args.sampling_strategy == "BeamSearchStrategy":
         strategy = BeamSearchStrategy(
             1,