diff --git a/src/levanter/distributed.py b/src/levanter/distributed.py index eefb71fc4..ea0bbb3c7 100644 --- a/src/levanter/distributed.py +++ b/src/levanter/distributed.py @@ -175,7 +175,11 @@ def _choose_port(id): def auto_ray_cluster( - address: Optional[str] = None, namespace: Optional[str] = "levanter", start_workers: bool = True, **kwargs + address: Optional[str] = None, + namespace: Optional[str] = "levanter", + start_workers: bool = True, + fail_if_cluster_already_initialized: bool = False, + **kwargs, ): """Initializes ray, automatically discovering the address if it is not provided. Currently supports slurm and TPU. @@ -220,11 +224,10 @@ def _munge_address_port(address: str): # Explicitly setting the number of CPUs on ray init stops init errors num_cpus = logical_cpu_core_count() - # it used to be that if we were coordinator, we were also process 0 - # this is no longer the case, so instead we need to check if we are the coordinator - # and if so, start the head - if _is_local_leader(): + # it used to be that if we were coordinator, we were also process 0 + # this is no longer the case, so instead we need to check if we are the coordinator + # and if so, start the head if _is_this_machine(host): logger.info(f"Starting ray head on port {ray_port}. We are process the coordinator {host}.") logger.info(f"Starting ray head with num_cpus set to {num_cpus}.") @@ -232,7 +235,19 @@ def _munge_address_port(address: str): f"ray start --head --port {ray_port} --num-cpus {num_cpus} --dashboard-host=0.0.0.0" ) if ret != 0: - raise RuntimeError(f"Failed to start ray head with exit code {ret}") + if not fail_if_cluster_already_initialized: + # see if we can connect to the head + logger.warning( + f"Failed to start ray head with exit code {ret}. Checking if we can connect to" + " the head..." + ) + ret = os.system("ray status") + if ret != 0: + raise RuntimeError(f"Failed to start ray head with exit code {ret}") + else: + logger.info(f"Ray head already running on port {ray_port}. Connecting to it.") + else: + raise RuntimeError(f"Failed to start ray head with exit code {ret}") else: logger.info(f"Successfully started ray head on port {ray_port}.")