diff --git a/dask_mpi/cli.py b/dask_mpi/cli.py index 039dd10..1538d4b 100644 --- a/dask_mpi/cli.py +++ b/dask_mpi/cli.py @@ -23,6 +23,12 @@ type=int, help="Specify scheduler port number. Defaults to random.", ) +@click.option( + "--scheduler-rank", + default=0, + type=int, + help="The MPI rank on which the scheduler will launch. Defaults to 0.", +) @click.option( "--interface", type=str, default=None, help="Network interface like 'eth0' or 'ib0'" ) @@ -56,6 +62,14 @@ default=True, help="Start workers in nanny process for management (deprecated use --worker-class instead)", ) +@click.option( + "--exclusive-workers", + default=True, + help=( + "Whether to force workers to run on unoccupied MPI ranks. If false, " + "then a worker will be launched on the same rank as the scheduler." + ), +) @click.option( "--worker-class", type=str, @@ -83,6 +97,7 @@ def main( scheduler_address, scheduler_file, + scheduler_rank, interface, nthreads, local_directory, @@ -90,6 +105,7 @@ def main( scheduler, dashboard_address, nanny, + exclusive_workers, worker_class, worker_options, scheduler_port, @@ -112,9 +128,9 @@ def main( except TypeError: worker_options = {} - if rank == 0 and scheduler: + if rank == scheduler_rank and scheduler: - async def run_scheduler(): + async def run_func(): async with Scheduler( interface=interface, protocol=protocol, @@ -125,18 +141,16 @@ async def run_scheduler(): comm.Barrier() await s.finished() - asyncio.get_event_loop().run_until_complete(run_scheduler()) - else: - comm.Barrier() - async def run_worker(): + async def run_func(): + comm.Barrier() + WorkerType = import_term(worker_class) if not nanny: raise DeprecationWarning( "Option --no-nanny is deprectaed, use --worker-class instead" ) - WorkerType = Worker opts = { "interface": interface, "protocol": protocol, @@ -149,10 +163,11 @@ async def run_worker(): } if scheduler_address: opts["scheduler_ip"] = scheduler_address + async with WorkerType(**opts) as worker: await worker.finished() - asyncio.get_event_loop().run_until_complete(run_worker()) + asyncio.get_event_loop().run_until_complete(run_func()) if __name__ == "__main__":