run_with_slurm.py

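"""Builds an sbatch script for a multi-node M4T training run, writes it into the
work directory, and submits it with `sbatch`. Per-job stdout/stderr are written to
<work_dir>/cluster_logs/<job_id>.out and <job_id>.err."""
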
import argparse
import logging
import os
import platform
import shutil
import subprocess
import time
from pathlib import Path

logging_format = f"%(asctime)s - {platform.node()} - %(process)s - %(levelname)s - %(name)s: %(message)s"
logging.basicConfig(
    level=logging.INFO,
    format=logging_format,
)
logger = logging.getLogger("train")


def init_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Run M4T training")
    parser.add_argument(
        "-w",
        type=Path,
        required=True,
        help="Work directory, where logs, checkpoints and core dumps will be stored",
    )
    parser.add_argument(
        "-p",
        type=Path,
        required=True,
        help="Training workflow config",
    )
    parser.add_argument(
        "-n",
        type=int,
        required=False,
        default=1,
        help="Number of training nodes",
    )
    parser.add_argument(
        "-c",
        type=str,
        required=False,
        default="seamless",
        help="Cluster partitions to use",
    )
    parser.add_argument(
        "-j",
        type=str,
        required=False,
        default="train",
        help="Slurm job name",
    )
    return parser
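
# For reference, a hypothetical invocation (paths are illustrative, not from this repo):
#   python run_with_slurm.py -w /checkpoint/$USER/m4t_run -p configs/train.yaml -n 2
# parses to args.w=Path("/checkpoint/$USER/m4t_run"), args.p=Path("configs/train.yaml"),
# args.n=2, with -c defaulting to "seamless" and -j defaulting to "train".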


def prepare_sbatch_config(
    job_name: str,
    params_file: str,
    num_nodes: int,
    partitions: str,
    work_dir: str,
    cluster_logs_dir: str,
    run_script: str,
) -> str:
    """Render the sbatch script that launches one torchrun process per node."""
    return f"""#!/bin/bash
## job name
#SBATCH --job-name={job_name}
## filename for job standard output (stdout)
## %j is the job id, %u is the user id
#SBATCH --output={cluster_logs_dir}/%j.out
## filename for job standard error output (stderr)
#SBATCH --error={cluster_logs_dir}/%j.err
## partition name
#SBATCH --partition={partitions}
## number of nodes
#SBATCH --nodes={num_nodes}
## number of GPUs per node
#SBATCH --gpus-per-node=8
## number of cpus per task
#SBATCH --cpus-per-task=96
#SBATCH --gres=gpu:8
## number of tasks per node
#SBATCH --ntasks-per-node=1
## amount of mem
#SBATCH --mem 50G
## amount of time in minutes
#SBATCH --time 2400
set -x
export WANDB_DISABLED=true
export HDF5_USE_FILE_LOCKING='FALSE'
export PARENT=`/bin/hostname -s`
export MPORT=24198
export CHILDREN=`scontrol show hostnames $SLURM_JOB_NODELIST | grep -v $PARENT`
export HOSTLIST="$PARENT $CHILDREN"
echo $HOSTLIST
export WORLD_SIZE=$SLURM_NTASKS
srun --label bash -c 'which python && torchrun \\
--nproc_per_node=8 \\
--nnodes=$SLURM_JOB_NUM_NODES \\
--node_rank="$SLURM_PROCID" \\
--master_addr="$PARENT" \\
--master_port="$MPORT" \\
--log-dir={cluster_logs_dir} \\
{run_script} --params {params_file} --wd {work_dir}'
"""


def main() -> None:
    args = init_parser().parse_args()
    params_file = args.p
    num_nodes = args.n
    partitions = args.c
    work_dir = args.w
    job_name = args.j
    assert job_name is not None
    assert len(job_name.split()) == 1, "spaces in job name not allowed"
    assert partitions and len(partitions.split()) == 1, "spaces in partitions not allowed"
    assert os.path.exists(params_file), "config file is missing"
    training_script_path = os.path.join(os.path.dirname(__file__), "run_training.py")
    assert os.path.exists(training_script_path), f"Can't find training script {training_script_path}"
    assert num_nodes > 0

    if not os.path.exists(work_dir):
        logger.info(f"Creating workdir {work_dir}")
        os.makedirs(work_dir)

    cluster_logs_dir = os.path.join(work_dir, "cluster_logs")
    if os.path.exists(cluster_logs_dir):
        logger.info(f"Clearing cluster logs dir {cluster_logs_dir}")
        shutil.rmtree(cluster_logs_dir)
    os.makedirs(cluster_logs_dir)

    config_text = prepare_sbatch_config(
        job_name=job_name,
        params_file=params_file,
        num_nodes=num_nodes,
        partitions=partitions,
        work_dir=work_dir,
        cluster_logs_dir=cluster_logs_dir,
        run_script=training_script_path,
    )
    logger.info(f"SBATCH config to launch: \n{config_text}")

    fname = f"{int(time.time())}_sbatch.sh"
    config_path = os.path.join(work_dir, fname)
    with open(config_path, "w") as fp_out:
        fp_out.write(config_text)
    logger.info(f"Saved to {config_path}")

    command = f"sbatch {config_path}"
    logger.info(f"Executing command: '{command}'")
    subprocess.Popen(command, shell=True).communicate()
    logger.info(f"Train log: {os.path.join(work_dir, 'train_log.txt')}")


if __name__ == "__main__":
    main()
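
# After submission, job progress can be checked with standard Slurm tooling, e.g.
# `squeue -u $USER` or `sacct -j <job_id>` (generic Slurm commands, not part of this
# script); per-job logs land in <work_dir>/cluster_logs/<job_id>.out and <job_id>.err.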