# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from datetime import timedelta
from typing import List

import torch
import torch.distributed as dist
import torch.multiprocessing

logger = logging.getLogger(__name__)


def is_dist_initialized() -> bool:
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank() -> int:
    if not is_dist_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
    if not is_dist_initialized():
        return 0
    # LOCAL_RANK is set by the launcher (e.g. torchrun).
    return int(os.environ["LOCAL_RANK"])


def get_world_size() -> int:
    if not is_dist_initialized():
        return 1
    return dist.get_world_size()


def is_main_process() -> bool:
    return get_rank() == 0


def init_distributed(loggers: List[logging.Logger]) -> None:
    """Initializes the distributed backend."""
    torch.multiprocessing.set_start_method("spawn")
    if "RANK" not in os.environ:
        logger.error(
            "Cannot init distributed context, as environment variables are not set."
        )
        return

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    logger.info(
        f"Rank={rank} local rank={local_rank}, world_size={world_size}, is_master={rank == 0}"
    )
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=180),
    )
    logger.info(f"Setting cuda:{local_rank} as main device")
    # Mute the given loggers on non-main ranks to avoid duplicated output.
    if not is_main_process():
        for to_mute in loggers:
            to_mute.setLevel(logging.ERROR)
    torch.cuda.set_device(local_rank)
    dist.barrier()
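
# Minimal usage sketch (illustrative only, not part of the original module):
# it assumes the script is launched with a launcher such as `torchrun`, which
# sets the RANK, LOCAL_RANK, and WORLD_SIZE environment variables, and that a
# CUDA device is available for the NCCL backend.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    init_distributed([logger])
    if is_main_process():
        logger.info(
            f"Distributed setup complete: rank {get_rank()} of {get_world_size()}"
        )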
 |