|
@@ -159,7 +159,8 @@ def main(args):
|
|
|
end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
|
|
|
|
|
|
if args.sampling_strategy == "BaseStrategy":
|
|
|
- strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
|
|
|
+ strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
|
|
|
+ end_tokens=end_tokens)
|
|
|
elif args.sampling_strategy == "BeamSearchStrategy":
|
|
|
strategy = BeamSearchStrategy(
|
|
|
1,
|