Skip to content

Commit

Permalink
scheduler_rank and exclusive_workers options with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kmpaul committed Sep 27, 2023
1 parent a449b7c commit fcee51e
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 54 deletions.
79 changes: 42 additions & 37 deletions dask_mpi/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
help="Start workers in nanny process for management (deprecated use --worker-class instead)",
)
@click.option(
"--exclusive-workers",
"--exclusive-workers/--inclusive-workers",
default=True,
help=(
"Whether to force workers to run on unoccupied MPI ranks. If false, "
Expand Down Expand Up @@ -128,46 +128,51 @@ def main(
except TypeError:
worker_options = {}

if rank == scheduler_rank and scheduler:
async def run_worker():
    """Start a single Dask worker for this MPI rank and wait for it to finish.

    The worker class is resolved from ``worker_class`` via ``import_term``.
    Its keyword arguments are assembled from the CLI options captured from
    the enclosing scope; entries in ``worker_options`` are merged last, so
    they override the defaults built here.

    Raises:
        DeprecationWarning: if ``nanny`` is falsy (``--no-nanny`` was given);
            the option is no longer supported in favour of ``--worker-class``.
    """
    WorkerType = import_term(worker_class)
    if not nanny:
        raise DeprecationWarning(
            "Option --no-nanny is deprecated, use --worker-class instead"
        )
    opts = {
        "interface": interface,
        "protocol": protocol,
        "nthreads": nthreads,
        "memory_limit": memory_limit,
        "local_directory": local_directory,
        # Suffix the MPI rank so each worker gets a distinct name.
        "name": f"{name}-{rank}",
        "scheduler_file": scheduler_file,
        **worker_options,
    }
    # Only pass an explicit scheduler address when one was supplied.
    if scheduler_address:
        opts["scheduler_ip"] = scheduler_address

    async with WorkerType(**opts) as worker:
        await worker.finished()

async def run_scheduler(launch_worker=False):
async with Scheduler(
interface=interface,
protocol=protocol,
dashboard_address=dashboard_address,
scheduler_file=scheduler_file,
port=scheduler_port,
) as scheduler:
comm.Barrier()

async def run_func():
async with Scheduler(
interface=interface,
protocol=protocol,
dashboard_address=dashboard_address,
scheduler_file=scheduler_file,
port=scheduler_port,
) as s:
comm.Barrier()
await s.finished()
if launch_worker:
asyncio.get_event_loop().create_task(run_worker())

else:
await scheduler.finished()

async def run_func():
comm.Barrier()
if rank == scheduler_rank and scheduler:
asyncio.get_event_loop().run_until_complete(
run_scheduler(launch_worker=not exclusive_workers)
)
else:
comm.Barrier()

WorkerType = import_term(worker_class)
if not nanny:
raise DeprecationWarning(
"Option --no-nanny is deprectaed, use --worker-class instead"
)
opts = {
"interface": interface,
"protocol": protocol,
"nthreads": nthreads,
"memory_limit": memory_limit,
"local_directory": local_directory,
"name": f"{name}-{rank}",
"scheduler_file": scheduler_file,
**worker_options,
}
if scheduler_address:
opts["scheduler_ip"] = scheduler_address

async with WorkerType(**opts) as worker:
await worker.finished()

asyncio.get_event_loop().run_until_complete(run_func())
asyncio.get_event_loop().run_until_complete(run_worker())


if __name__ == "__main__":
Expand Down
67 changes: 50 additions & 17 deletions dask_mpi/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,27 @@ def test_basic(loop, worker_class, mpirun):
assert c.submit(lambda x: x + 1, 10).result() == 11


def test_inclusive_workers(loop, mpirun):
    """With --inclusive-workers, all 4 MPI ranks should register a worker."""
    expected_workers = 4
    with tmpfile(extension="json") as sched_file:
        launch_args = [
            "-np",
            "4",
            "dask-mpi",
            "--scheduler-file",
            sched_file,
            "--inclusive-workers",
        ]
        command = mpirun + launch_args

        with popen(command):
            with Client(scheduler_file=sched_file) as client:
                # Poll the scheduler until every rank has joined, failing
                # the test if that takes longer than 10 seconds.
                deadline = time() + 10
                while len(client.scheduler_info()["workers"]) < expected_workers:
                    assert time() < deadline
                    sleep(0.1)

                future = client.submit(lambda x: x + 1, 10)
                assert future.result() == 11


def test_small_world(mpirun):
with tmpfile(extension="json") as fn:
# Set too few processes to start cluster
Expand Down Expand Up @@ -98,6 +119,35 @@ def test_no_scheduler(loop, mpirun):
sleep(0.2)


def test_scheduler_rank(loop, mpirun):
    """With --scheduler-rank 1 on 2 ranks, the only worker is named for rank 0."""
    with tmpfile(extension="json") as sched_file:
        launch_args = [
            "-np",
            "2",
            "dask-mpi",
            "--scheduler-file",
            sched_file,
            "--exclusive-workers",
            "--scheduler-rank",
            "1",
        ]
        command = mpirun + launch_args

        with popen(command, stdin=FNULL):
            with Client(scheduler_file=sched_file) as client:
                # Wait (up to 10 seconds) for at least one worker to register.
                deadline = time() + 10
                while len(client.scheduler_info()["workers"]) < 1:
                    assert time() < deadline
                    sleep(0.2)

                workers = client.scheduler_info()["workers"]
                assert len(workers) == 1

                # Worker names end in "-<rank>"; the text after the last
                # dash must be the non-scheduler rank.
                only_worker = next(iter(workers.values()))
                rank_suffix = only_worker["name"].rpartition("-")[2]
                assert rank_suffix == "0"

                future = client.submit(lambda x: x + 1, 10)
                assert future.result() == 11


@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"])
def test_non_default_ports(loop, nanny, mpirun):
with tmpfile(extension="json") as fn:
Expand Down Expand Up @@ -150,23 +200,6 @@ def test_dashboard(loop, mpirun):
requests.get("http://localhost:59583/status/")


@pytest.mark.skip(reason="Should we expose this option?")
def test_bokeh_worker(loop, mpirun):
    """Launch dask-mpi with --bokeh-worker-port and check that port 59584 responds."""
    with tmpfile(extension="json") as sched_file:
        launch_args = [
            "-np",
            "2",
            "dask-mpi",
            "--scheduler-file",
            sched_file,
            "--bokeh-worker-port",
            "59584",
        ]
        command = mpirun + launch_args

        with popen(command, stdin=FNULL):
            check_port_okay(59584)


def tmpfile_static(extension="", dir=None):
"""
utility function for test_stale_sched test
Expand Down

0 comments on commit fcee51e

Please sign in to comment.