@@ -47,21 +47,35 @@ def init_parser() -> argparse.ArgumentParser:
     return parser


-def run_training(parameters: WorkflowParams, work_dir: str, checkpoint_dir: str) -> None:
+def run_training(
+    parameters: WorkflowParams, work_dir: str, checkpoint_dir: str
+) -> None:
     logger.info(f"Workflow params: {parameters}")
     rank, world_size = dist_utils.get_rank(), dist_utils.get_world_size()
     logger.info(f"Rank: {rank}, world_size: {world_size}")
     assert torch.cuda.device_count() > 0, "GPU is not available"
     device = torch.device("cuda")
-    float_dtype = _trainer.UnitYTrainer._get_float_dtype(parameters.training.float_dtype)
+    float_dtype = _trainer.UnitYTrainer._get_float_dtype(
+        parameters.training.float_dtype
+    )
     logger.info(f"Device: {device}, float dtype: {float_dtype}")
-    model = _model.ModelBuilder(config=parameters.model, dtype=float_dtype, device=device).build_model()
+    model = _model.ModelBuilder(
+        config=parameters.model, dtype=float_dtype, device=device
+    ).build_model()
     logger.info(f"Model: {model}")
     train_data = _dataloader.UnityDataLoader(
-        config=parameters.train_data, rank=rank, world_size=world_size, target_device=device, float_dtype=float_dtype
+        config=parameters.train_data,
+        rank=rank,
+        world_size=world_size,
+        target_device=device,
+        float_dtype=float_dtype,
     )
     eval_data = _dataloader.UnityDataLoader(
-        config=parameters.eval_data, rank=rank, world_size=world_size, target_device=device, float_dtype=float_dtype
+        config=parameters.eval_data,
+        rank=rank,
+        world_size=world_size,
+        target_device=device,
+        float_dtype=float_dtype,
     )
     trainer = _trainer.UnitYTrainer(
         model=model,
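Every change in the hunk above is a mechanical re-wrap of a line that had grown too long: each wrapped call passes exactly the same arguments as before, so behavior is unchanged. The new style (one keyword argument per line, closed with a trailing comma) is what auto-formatters such as black emit once a call no longer fits on one line. A minimal illustration with placeholder values:

    # Before: one long call.
    loader = UnityDataLoader(config=cfg, rank=0, world_size=1, target_device=dev, float_dtype=dt)

    # After: one argument per line plus a trailing ("magic") comma,
    # which keeps future one-argument changes to single-line diffs.
    loader = UnityDataLoader(
        config=cfg,
        rank=0,
        world_size=1,
        target_device=dev,
        float_dtype=dt,
    )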
@@ -75,7 +89,13 @@ def run_training(parameters: WorkflowParams, work_dir: str, checkpoint_dir: str)


 def get_loggers() -> List[logging.Logger]:
-    return [logger, _trainer.logger, _dataloader.logger, _model.logger, dist_utils.logger]
+    return [
+        logger,
+        _trainer.logger,
+        _dataloader.logger,
+        _model.logger,
+        dist_utils.logger,
+    ]


 def set_file_output_for_loggers(log_filename: str) -> None:
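This hunk only re-wraps the list returned by get_loggers(); the body of set_file_output_for_loggers sits outside the diff. As a hedged sketch of what such a helper commonly does (an assumption, not this project's actual code), it would attach one shared file handler to every logger the module knows about:

    import logging

    def set_file_output_for_loggers(log_filename: str) -> None:
        # Hypothetical sketch: mirror all known loggers into one log file.
        handler = logging.FileHandler(log_filename)
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")
        )
        for log in get_loggers():
            log.addHandler(handler)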
@@ -91,7 +111,9 @@ def main() -> None:
     dist_utils.init_distributed(get_loggers())
     is_master = dist_utils.is_main_process()
     with open(args.params, "r") as fp_in:
-        parameters = WorkflowParams.deserialize(yaml.load(fp_in, Loader=yaml.FullLoader))
+        parameters = WorkflowParams.deserialize(
+            yaml.load(fp_in, Loader=yaml.FullLoader)
+        )
     ts = str(int(time.time()))
     work_dir = args.wd
     checkpoint_dir = os.path.join(work_dir, "checkpoints")
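A side note on the wrapped line: yaml.load(fp_in, Loader=yaml.FullLoader) is PyYAML's way of parsing without the legacy unsafe loader, and for configs made of plain mappings, lists, and scalars, yaml.safe_load reads the same data while being stricter still. A sketch with a placeholder path:

    import yaml

    with open("params.yaml", "r") as fp_in:  # placeholder for args.params
        raw = yaml.safe_load(fp_in)  # same result as FullLoader for plain data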
@@ -108,7 +130,9 @@ def main() -> None:
     logger.info(f"Set logging to {log_path}")
     set_file_output_for_loggers(log_path)
     try:
-        run_training(parameters=parameters, work_dir=work_dir, checkpoint_dir=checkpoint_dir)
+        run_training(
+            parameters=parameters, work_dir=work_dir, checkpoint_dir=checkpoint_dir
+        )
     except Exception:
         # make sure that the stack trace will be logged to log files
         logger.exception("Training failed")
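Taken together, the four hunks change no arguments, control flow, or log messages; every edit re-wraps a long line at a call site or literal. That is the shape of diff an auto-formatter produces, for example (assuming black with its default 88-character line limit; the path is a placeholder):

    black finetune.py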