浏览代码

Fix top_p argument in generate.py

Zhengxiao Du 3 年之前
父节点
当前提交
3bb0f456d1
共有 1 个文件被更改,包括 2 次插入1 次删除
  1. 2 1
      generate.py

+ 2 - 1
generate.py

@@ -159,7 +159,8 @@ def main(args):
     end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
 
     if args.sampling_strategy == "BaseStrategy":
-        strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
+        strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
+                                end_tokens=end_tokens)
     elif args.sampling_strategy == "BeamSearchStrategy":
         strategy = BeamSearchStrategy(
             1,