# benchmark.py — forward-pass latency micro-benchmark
  1. import torch
  2. import time
  3. from initialize import initialize, initialize_model_and_tokenizer
  4. if __name__ == "__main__":
  5. args = initialize(extra_args_provider=lambda parser: None)
  6. model, tokenizer = initialize_model_and_tokenizer(args)
  7. for seq_len in [512, 1024, 2048]:
  8. torch.distributed.barrier()
  9. start = time.time()
  10. with torch.no_grad():
  11. _, *_ = model(
  12. torch.ones(1, seq_len, device=torch.cuda.current_device(), dtype=torch.int64),
  13. torch.arange(seq_len, device=torch.cuda.current_device(), dtype=torch.int64).view(1, -1),
  14. torch.randn(1, 1, seq_len, seq_len, device=torch.cuda.current_device()) < 0.5,
  15. )
  16. torch.distributed.barrier()
  17. if torch.distributed.get_rank() == 0:
  18. print(f"Encode {seq_len}: {(time.time() - start) * 1000:.2f} ms")