# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from datetime import timedelta
from typing import List

import torch
import torch.distributed as dist
import torch.multiprocessing

logger = logging.getLogger(__name__)


def is_dist_initialized() -> bool:
    """Return True if torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_rank() -> int:
    """Return the global rank, or 0 when running without distribution."""
    if not is_dist_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
    """Return the node-local rank, read from the LOCAL_RANK environment variable."""
    if not is_dist_initialized():
        return 0
    return int(os.environ["LOCAL_RANK"])


def get_world_size() -> int:
    """Return the total number of processes, or 1 when running without distribution."""
    if not is_dist_initialized():
        return 1
    return dist.get_world_size()


def is_main_process() -> bool:
    """Return True on the rank-0 (main) process."""
    return get_rank() == 0


def init_distributed(loggers: List[logging.Logger]) -> None:
    """Initializes the distributed backend."""
    torch.multiprocessing.set_start_method("spawn")
    if "RANK" not in os.environ:
        logger.error(
            "Cannot init distributed context, as environment variables are not set."
        )
        return
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    logger.info(
        f"Rank={rank} local rank={local_rank}, world_size={world_size}, is_master={rank == 0}"
    )
    # The env:// rendezvous reads MASTER_ADDR/MASTER_PORT (plus RANK and
    # WORLD_SIZE) from the environment, as set by launchers such as torchrun.
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=180),
    )
    logger.info(f"Setting cuda:{local_rank} as main device")
    if not is_main_process():
        # Mute the given loggers on non-main ranks to avoid duplicated output.
        for to_mute in loggers:
            to_mute.setLevel(logging.ERROR)
    torch.cuda.set_device(local_rank)
    # Synchronize all ranks before returning.
    dist.barrier()
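

# Minimal usage sketch (an illustration, not part of the original module).
# It assumes the process was started by a launcher that sets RANK, LOCAL_RANK
# and WORLD_SIZE, e.g. `torchrun --nproc_per_node=2 this_module.py`, on a
# machine with CUDA devices available.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    init_distributed([logger])
    device = torch.device(f"cuda:{get_local_rank()}")
    logger.info("rank %d of %d using %s", get_rank(), get_world_size(), device)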