Custom client/scheduler MPI rank placement #110

Merged
merged 35 commits into from
Jul 3, 2024
Commits
91bc980
Generalizing starting coroutines in CLI
kmpaul Sep 27, 2023
a449b7c
Ignore scratch space
kmpaul Sep 27, 2023
fcee51e
scheduler_rank and exclusive_workers options with tests
kmpaul Sep 27, 2023
055a95d
allow 1-rank clusters
kmpaul Sep 27, 2023
ff32000
Correction to min world size calculation
kmpaul Sep 29, 2023
10d54c0
Rename module for clarity
kmpaul Sep 29, 2023
aa00b98
Create execute function
kmpaul Sep 29, 2023
483d849
NOQA on unused imports
kmpaul Oct 11, 2023
b6f2841
Set worker type if deprecated "--no-nanny" option set
kmpaul Oct 11, 2023
c850dd9
Set worker type before raise
kmpaul Oct 11, 2023
376852c
Import from dask not distributed
kmpaul Oct 11, 2023
74cf544
Update versioneer script to new Python
kmpaul Oct 11, 2023
6592545
Temporary fix for python 3.12 changes
kmpaul Oct 11, 2023
6219830
Custom rank placement logic
kmpaul Oct 11, 2023
3cdd7ba
Move no_exit test into main initialize test and rename
kmpaul Oct 12, 2023
9341961
Renaming for better clarity
kmpaul Oct 12, 2023
c61ec65
Add execute tests
kmpaul Oct 12, 2023
51f21e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 12, 2023
b1fbbf5
Set python version in readthedocs env
kmpaul Oct 12, 2023
82038c1
Try fixing python version for readthedocs build
kmpaul Oct 12, 2023
cc7bf1b
Merge branch 'rank-placement' of https://github.com/dask/dask-mpi int…
kmpaul Oct 12, 2023
2a05cca
Revert
kmpaul Oct 12, 2023
b838307
Rename to match test name
kmpaul Oct 12, 2023
1c43320
Rename test / add no_exit test for execute
kmpaul Oct 12, 2023
b67d5a1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 12, 2023
5ef8ca4
Possible to not supply a function
kmpaul Oct 13, 2023
4222801
Reorder cli options
kmpaul Oct 13, 2023
5734764
Add more execute options
kmpaul Oct 13, 2023
646f701
Merge branch 'rank-placement' of https://github.com/dask/dask-mpi int…
kmpaul Oct 13, 2023
3ebc321
move send_close_signal to execute
kmpaul Oct 13, 2023
a35c261
send_close_signal has moved
kmpaul Oct 13, 2023
d5ee1e8
Deprecate initialize, now that execute does everything
kmpaul Oct 13, 2023
bb352ca
Revert initialize deprecation warning for now
kmpaul Oct 13, 2023
7b567d6
Merge branch 'main' of https://github.com/dask/dask-mpi into rank-pla…
kmpaul Jul 2, 2024
274f195
Attempt RTF fix
kmpaul Jul 2, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -115,6 +115,7 @@ global.lock
purge.lock
/temp/
/dask-worker-space/
/dask-scratch-space/

# VSCode files
.vscode/
7 changes: 4 additions & 3 deletions dask_mpi/__init__.py
@@ -1,5 +1,6 @@
from . import _version
from .core import initialize, send_close_signal
from .exceptions import WorldTooSmallException
from ._version import get_versions
from .exceptions import WorldTooSmallException # noqa
from .execute import execute, send_close_signal # noqa
from .initialize import initialize # noqa

__version__ = _version.get_versions()["version"]
104 changes: 63 additions & 41 deletions dask_mpi/cli.py
@@ -23,6 +23,12 @@
type=int,
help="Specify scheduler port number. Defaults to random.",
)
@click.option(
"--scheduler-rank",
default=0,
type=int,
help="The MPI rank on which the scheduler will launch. Defaults to 0.",
)
@click.option(
"--interface", type=str, default=None, help="Network interface like 'eth0' or 'ib0'"
)
@@ -56,6 +62,14 @@
default=True,
help="Start workers in nanny process for management (deprecated use --worker-class instead)",
)
@click.option(
"--exclusive-workers/--inclusive-workers",
default=True,
help=(
"Whether to force workers to run on unoccupied MPI ranks. If false, "
"then a worker will be launched on the same rank as the scheduler."
),
)
@click.option(
"--worker-class",
type=str,
@@ -90,27 +104,30 @@
def main(
scheduler_address,
scheduler_file,
scheduler_port,
scheduler_rank,
interface,
protocol,
nthreads,
local_directory,
memory_limit,
local_directory,
scheduler,
dashboard,
dashboard_address,
nanny,
exclusive_workers,
worker_class,
worker_options,
scheduler_port,
protocol,
name,
):
comm = MPI.COMM_WORLD

world_size = comm.Get_size()
if scheduler and world_size < 2:
min_world_size = 1 + scheduler * max(scheduler_rank, exclusive_workers)
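# For example, with the defaults (scheduler=True, scheduler_rank=0,
# exclusive_workers=True) this is 1 + 1 * max(0, True) = 2: one rank for the
# scheduler plus at least one rank left over for a worker.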
if world_size < min_world_size:
raise WorldTooSmallException(
f"Not enough MPI ranks to start cluster, found {world_size}, "
"needs at least 2, one each for the scheduler and a worker."
f"Not enough MPI ranks to start cluster with exclusive_workers={exclusive_workers} and "
f"scheduler_rank={scheduler_rank}, found {world_size} MPI ranks but needs {min_world_size}."
)

rank = comm.Get_rank()
@@ -120,47 +137,52 @@ def main(
except TypeError:
worker_options = {}

if rank == 0 and scheduler:
async def run_worker():
WorkerType = import_term(worker_class)
if not nanny:
WorkerType = Worker
raise DeprecationWarning(
"Option --no-nanny is deprectaed, use --worker-class instead"
)
opts = {
"interface": interface,
"protocol": protocol,
"nthreads": nthreads,
"memory_limit": memory_limit,
"local_directory": local_directory,
"name": f"{name}-{rank}",
"scheduler_file": scheduler_file,
**worker_options,
}
if scheduler_address:
opts["scheduler_ip"] = scheduler_address

async def run_scheduler():
async with Scheduler(
interface=interface,
protocol=protocol,
dashboard=dashboard,
dashboard_address=dashboard_address,
scheduler_file=scheduler_file,
port=scheduler_port,
) as s:
comm.Barrier()
await s.finished()
async with WorkerType(**opts) as worker:
await worker.finished()

asyncio.get_event_loop().run_until_complete(run_scheduler())
async def run_scheduler(launch_worker=False):
async with Scheduler(
interface=interface,
protocol=protocol,
dashboard=dashboard,
dashboard_address=dashboard_address,
scheduler_file=scheduler_file,
port=scheduler_port,
) as scheduler:
comm.Barrier()

if launch_worker:
asyncio.get_event_loop().create_task(run_worker())

await scheduler.finished()

if rank == scheduler_rank and scheduler:
asyncio.get_event_loop().run_until_complete(
run_scheduler(launch_worker=not exclusive_workers)
)
else:
comm.Barrier()
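# Non-scheduler ranks block here until the scheduler rank has started its
# Scheduler and reached the matching comm.Barrier() inside run_scheduler.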

async def run_worker():
WorkerType = import_term(worker_class)
if not nanny:
raise DeprecationWarning(
"Option --no-nanny is deprectaed, use --worker-class instead"
)
WorkerType = Worker
opts = {
"interface": interface,
"protocol": protocol,
"nthreads": nthreads,
"memory_limit": memory_limit,
"local_directory": local_directory,
"name": f"{name}-{rank}",
"scheduler_file": scheduler_file,
**worker_options,
}
if scheduler_address:
opts["scheduler_ip"] = scheduler_address
async with WorkerType(**opts) as worker:
await worker.finished()

asyncio.get_event_loop().run_until_complete(run_worker())


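For context, the new CLI options might be combined roughly as follows (the rank count and scheduler file path are illustrative assumptions, not part of this diff); this places the scheduler on rank 1 and, via --inclusive-workers, lets a worker share that rank:

mpirun -np 4 dask-mpi --scheduler-file scheduler.json --scheduler-rank 1 --inclusive-workers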
210 changes: 210 additions & 0 deletions dask_mpi/execute.py
@@ -0,0 +1,210 @@
import asyncio
import threading

import dask
from distributed import Client, Nanny, Scheduler
from distributed.utils import import_term

from .exceptions import WorldTooSmallException


def execute(
client_function=None,
client_args=(),
client_kwargs=None,
client_rank=1,
scheduler=True,
scheduler_rank=0,
scheduler_address=None,
scheduler_port=None,
scheduler_file=None,
interface=None,
nthreads=1,
local_directory="",
memory_limit="auto",
nanny=False,
dashboard=True,
dashboard_address=":8787",
protocol=None,
exclusive_workers=True,
worker_class="distributed.Worker",
worker_options=None,
worker_name=None,
comm=None,
):
"""
Execute a function on a given MPI rank with a Dask cluster launched using mpi4py

By default (scheduler_rank=0, client_rank=1), MPI rank 0 launches the Scheduler,
MPI rank 1 runs the client function, and all other MPI ranks launch workers. All
ranks other than the client rank block while their event loops run.

In normal operation these ranks exit once the client function returns, at which
point send_close_signal is called automatically to shut the cluster down. If no
client function is supplied, no client code is executed and the cluster runs
until send_close_signal is called.

Parameters
----------
client_function : callable
A function containing Dask client code to execute with a Dask cluster. If
it is not callable, then no client code will be executed.
client_args : list
Arguments to the client function
client_rank : int
The MPI rank on which to run the client function.
scheduler_rank : int
The MPI rank on which to run the Dask scheduler
scheduler_address : str
IP Address of the scheduler, used if scheduler is not launched
scheduler_port : int
Specify scheduler port number. Defaults to random.
scheduler_file : str
Filename to JSON encoded scheduler information.
interface : str
Network interface like 'eth0' or 'ib0'
nthreads : int
Number of threads per worker
local_directory : str
Directory to place worker files
memory_limit : int, float, or 'auto'
Number of bytes before spilling data to disk. This can be an
integer (nbytes), float (fraction of total memory), or 'auto'.
nanny : bool
Start workers in nanny process for management (deprecated, use worker_class instead)
dashboard : bool
Enable Bokeh visual diagnostics
dashboard_address : str
Bokeh port for visual diagnostics
protocol : str
Protocol like 'inproc' or 'tcp'
exclusive_workers : bool
Whether to only run Dask workers on their own MPI ranks
worker_class : str
Class to use when creating workers
worker_options : dict
Options to pass to workers
worker_name : str
Prefix for name given to workers. If defined, each worker will be named
'{worker_name}-{rank}'. Otherwise, the name of each worker is just '{rank}'.
comm : mpi4py.MPI.Intracomm
Optional MPI communicator to use instead of COMM_WORLD
client_kwargs : dict
Keyword arguments to the client function
"""
if comm is None:
from mpi4py import MPI

comm = MPI.COMM_WORLD

world_size = comm.Get_size()
min_world_size = 1 + max(client_rank, scheduler_rank, exclusive_workers)
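# For example, with the defaults (client_rank=1, scheduler_rank=0,
# exclusive_workers=True) this is 1 + max(1, 0, True) = 2 MPI ranks.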
if world_size < min_world_size:
raise WorldTooSmallException(
f"Not enough MPI ranks to start cluster with exclusive_workers={exclusive_workers} and "
f"scheduler_rank={scheduler_rank}, found {world_size} MPI ranks but needs {min_world_size}."
)

rank = comm.Get_rank()

if not worker_options:
worker_options = {}

async def run_client():
def wrapped_function(*args, **kwargs):
client_function(*args, **kwargs)
send_close_signal()

threading.Thread(
target=wrapped_function, args=client_args, kwargs=client_kwargs
).start()

async def run_worker(with_client=False):
WorkerType = import_term(worker_class)
if nanny:
WorkerType = Nanny
raise DeprecationWarning(
"Option nanny=True is deprectaed, use worker_class='distributed.Nanny' instead"
)
opts = {
"interface": interface,
"protocol": protocol,
"nthreads": nthreads,
"memory_limit": memory_limit,
"local_directory": local_directory,
"name": rank if not worker_name else f"{worker_name}-{rank}",
**worker_options,
}
if not scheduler and scheduler_address:
opts["scheduler_ip"] = scheduler_address
async with WorkerType(**opts) as worker:
if with_client:
asyncio.get_event_loop().create_task(run_client())

await worker.finished()

async def run_scheduler(with_worker=False, with_client=False):
async with Scheduler(
interface=interface,
protocol=protocol,
dashboard=dashboard,
dashboard_address=dashboard_address,
scheduler_file=scheduler_file,
port=scheduler_port,
) as scheduler:
dask.config.set(scheduler_address=scheduler.address)
comm.bcast(scheduler.address, root=scheduler_rank)
comm.Barrier()

if with_worker:
asyncio.get_event_loop().create_task(
run_worker(with_client=with_client)
)

elif with_client:
asyncio.get_event_loop().create_task(run_client())

await scheduler.finished()

with_scheduler = scheduler and (rank == scheduler_rank)
with_client = callable(client_function) and (rank == client_rank)

if with_scheduler:
run_coro = run_scheduler(
with_worker=not exclusive_workers,
with_client=with_client,
)

else:
if scheduler:
scheduler_address = comm.bcast(None, root=scheduler_rank)
elif scheduler_address is None:
raise ValueError(
"Must provide scheduler_address if executing with scheduler=False"
)
dask.config.set(scheduler_address=scheduler_address)
comm.Barrier()
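# At this point every non-scheduler rank knows the scheduler address (received
# via comm.bcast or passed in by the caller) and has synchronized at the
# Barrier, so workers and the client can connect safely.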

if with_client and exclusive_workers:
run_coro = run_client()
else:
run_coro = run_worker(with_client=with_client)

asyncio.get_event_loop().run_until_complete(run_coro)


def send_close_signal():
"""
The client can call this function to explicitly stop
the event loop.

This is not needed in normal usage, where it is run
automatically when the client code exits Python.
You only need to call this manually when using exit=False
in initialize.
"""

with Client() as c:
c.shutdown()
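As a rough usage sketch of the new execute entry point (the script name, the mpirun rank count, and the example client function are assumptions for illustration; the keyword arguments mirror the signature above):

# my_script.py -- launched with, for example: mpirun -np 4 python my_script.py
from distributed import Client

from dask_mpi import execute


def client_code(scale):
    # Runs only on client_rank; Client() picks up the scheduler address that
    # execute() stored in the Dask config before releasing the other ranks.
    with Client() as client:
        futures = client.map(lambda x: x * scale, range(10))
        print(client.gather(futures))


if __name__ == "__main__":
    # Scheduler on rank 0, client code on rank 1, workers on the remaining ranks.
    execute(
        client_code,
        client_args=(2,),
        client_rank=1,
        scheduler_rank=0,
        exclusive_workers=True,
    )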