# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
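
"""Launcher script for M4T (UnitY) training.

Parses command-line arguments, loads the workflow configuration from a YAML
file, initializes distributed training, builds the model and the train/eval
data loaders, and runs the trainer. Logs, checkpoints, and a snapshot of the
config are written to the given work directory.
"""
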
import argparse
import logging
import os
import platform
import shutil
import time
from pathlib import Path
from typing import List

import torch
import yaml

from m4t_scripts.train import dataloader as _dataloader
from m4t_scripts.train import dist_utils
from m4t_scripts.train import model as _model
from m4t_scripts.train import trainer as _trainer
from m4t_scripts.train.configs import WorkflowParams
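
# The log format below includes the hostname (platform.node()) and the process id,
# which helps tell apart interleaved log lines coming from different ranks and nodes
# in a distributed run.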
logging_format = f"%(asctime)s - {platform.node()} - %(process)s - %(levelname)s - %(name)s: %(message)s"
logging.basicConfig(
    level=logging.INFO,
    format=logging_format,
)

logger = logging.getLogger("train")


def init_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Run M4T training")
    parser.add_argument(
        "--wd",
        type=Path,
        required=True,
        help="Work directory, where logs, checkpoints and core dumps will be stored",
    )
    parser.add_argument(
        "--params",
        type=Path,
        required=True,
        help="Config with training parameters",
    )
    return parser
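
# Example launch (hypothetical paths, and assuming dist_utils.init_distributed picks up
# the standard torch.distributed environment variables set by a launcher such as torchrun):
#
#   torchrun --standalone --nproc_per_node=8 run_training.py \
#       --wd /tmp/m4t_work_dir \
#       --params /tmp/m4t_train_params.yaml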


def run_training(parameters: WorkflowParams, work_dir: str, checkpoint_dir: str) -> None:
    logger.info(f"Workflow params: {parameters}")
    rank, world_size = dist_utils.get_rank(), dist_utils.get_world_size()
    logger.info(f"Rank: {rank}, world_size: {world_size}")
    assert torch.cuda.device_count() > 0, "GPU is not available"
    device = torch.device("cuda")
    float_dtype = _trainer.UnitYTrainer._get_float_dtype(parameters.training.float_dtype)
    logger.info(f"Device: {device}, float dtype: {float_dtype}")
    model = _model.ModelBuilder(config=parameters.model, dtype=float_dtype, device=device).build_model()
    logger.info(f"Model: {model}")
    train_data = _dataloader.UnityDataLoader(
        config=parameters.train_data, rank=rank, world_size=world_size, target_device=device, float_dtype=float_dtype
    )
    eval_data = _dataloader.UnityDataLoader(
        config=parameters.eval_data, rank=rank, world_size=world_size, target_device=device, float_dtype=float_dtype
    )
    trainer = _trainer.UnitYTrainer(
        model=model,
        params=parameters.training,
        train_data_loader=train_data,
        eval_data_loader=eval_data,
        chck_save_dir=checkpoint_dir,
        device=device,
    )
    trainer.run()
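
# Sketch of the workflow config consumed by run_training and main() below. This is an
# assumption based only on the fields accessed in this file: WorkflowParams.deserialize
# is expected to map the top-level YAML keys to these fields, and nested options are omitted.
#
#   model: ...        # config for _model.ModelBuilder
#   train_data: ...   # UnityDataLoader config for training batches
#   eval_data: ...    # UnityDataLoader config for evaluation batches
#   training: ...     # UnitYTrainer parameters, e.g. float_dtype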


def get_loggers() -> List[logging.Logger]:
    return [logger, _trainer.logger, _dataloader.logger, _model.logger, dist_utils.logger]


def set_file_output_for_loggers(log_filename: str) -> None:
    handler = logging.FileHandler(filename=log_filename, mode="a", delay=False)
    formatter = logging.Formatter(logging_format)
    handler.setFormatter(formatter)
    for logger in get_loggers():
        logger.handlers.append(handler)


def main() -> None:
    args = init_parser().parse_args()
    dist_utils.init_distributed(get_loggers())
    is_master = dist_utils.is_main_process()
    with open(args.params, "r") as fp_in:
        parameters = WorkflowParams.deserialize(yaml.load(fp_in, Loader=yaml.FullLoader))
    ts = str(int(time.time()))
    work_dir = args.wd
    checkpoint_dir = os.path.join(work_dir, "checkpoints")
    if not os.path.exists(checkpoint_dir) and is_master:
        logger.info(f"Creating checkpoint dir: {checkpoint_dir}")
        # checkpoint_dir is not used before downstream syncs,
        # so no race condition is expected and no barrier is run here
        os.makedirs(checkpoint_dir)
    config_path = os.path.join(work_dir, f"{ts}_config.yaml")
    # copy the config to the work dir to keep a snapshot of the workflow settings
    if is_master:
        shutil.copy(args.params, config_path)
    log_path = os.path.join(work_dir, "train_log.txt")
    logger.info(f"Set logging to {log_path}")
    set_file_output_for_loggers(log_path)
    try:
        run_training(parameters=parameters, work_dir=work_dir, checkpoint_dir=checkpoint_dir)
    except Exception:
        # make sure the stack trace is logged to the log file
        logger.exception("Training failed")


if __name__ == "__main__":
    main()