Skip to content

Commit

Permalink
Generalizing starting coroutines in CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
kmpaul committed Sep 27, 2023
1 parent 3a7f045 commit 91bc980
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions dask_mpi/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
type=int,
help="Specify scheduler port number. Defaults to random.",
)
@click.option(
"--scheduler-rank",
default=0,
type=int,
help="The MPI rank on which the scheduler will launch. Defaults to 0.",
)
@click.option(
"--interface", type=str, default=None, help="Network interface like 'eth0' or 'ib0'"
)
Expand Down Expand Up @@ -56,6 +62,14 @@
default=True,
help="Start workers in nanny process for management (deprecated use --worker-class instead)",
)
@click.option(
"--exclusive-workers",
default=True,
help=(
"Whether to force workers to run on unoccupied MPI ranks. If false, "
"then a worker will be launched on the same rank as the scheduler."
),
)
@click.option(
"--worker-class",
type=str,
Expand Down Expand Up @@ -83,13 +97,15 @@
def main(
scheduler_address,
scheduler_file,
scheduler_rank,
interface,
nthreads,
local_directory,
memory_limit,
scheduler,
dashboard_address,
nanny,
exclusive_workers,
worker_class,
worker_options,
scheduler_port,
Expand All @@ -112,9 +128,9 @@ def main(
except TypeError:
worker_options = {}

if rank == 0 and scheduler:
if rank == scheduler_rank and scheduler:

async def run_scheduler():
async def run_func():
async with Scheduler(
interface=interface,
protocol=protocol,
Expand All @@ -125,18 +141,16 @@ async def run_scheduler():
comm.Barrier()
await s.finished()

asyncio.get_event_loop().run_until_complete(run_scheduler())

else:
comm.Barrier()

async def run_worker():
async def run_func():
comm.Barrier()

WorkerType = import_term(worker_class)
if not nanny:
raise DeprecationWarning(
"Option --no-nanny is deprectaed, use --worker-class instead"
)
WorkerType = Worker
opts = {
"interface": interface,
"protocol": protocol,
Expand All @@ -149,10 +163,11 @@ async def run_worker():
}
if scheduler_address:
opts["scheduler_ip"] = scheduler_address

async with WorkerType(**opts) as worker:
await worker.finished()

asyncio.get_event_loop().run_until_complete(run_worker())
asyncio.get_event_loop().run_until_complete(run_func())


if __name__ == "__main__":
Expand Down

0 comments on commit 91bc980

Please sign in to comment.