
Commit

Create execute function
kmpaul committed Sep 29, 2023
1 parent 10d54c0 commit aa00b98
Showing 4 changed files with 206 additions and 1 deletion.
3 changes: 2 additions & 1 deletion dask_mpi/__init__.py
@@ -1,5 +1,6 @@
 from ._version import get_versions
-from .core import initialize, send_close_signal
+from .execute import execute
+from .initialize import initialize, send_close_signal
 from .exceptions import WorldTooSmallException
 
 __version__ = get_versions()["version"]
151 changes: 151 additions & 0 deletions dask_mpi/execute.py
@@ -0,0 +1,151 @@
import asyncio
import warnings

import dask
from distributed import Nanny, Scheduler
from distributed.utils import import_term

from .exceptions import WorldTooSmallException
from .initialize import send_close_signal


def execute(
    func,
    *args,
    client_rank=1,
    scheduler_rank=0,
    interface=None,
    nthreads=1,
    local_directory="",
    memory_limit="auto",
    nanny=False,
    dashboard=True,
    dashboard_address=":8787",
    protocol=None,
    exclusive_workers=True,
    worker_class="distributed.Worker",
    worker_options=None,
    comm=None,
    **kwargs,
):
"""
Execute a function on a given MPI rank with a Dask cluster launched using mpi4py
Using mpi4py, MPI rank 0 launches the Scheduler, MPI rank 1 passes through to the
client script, and all other MPI ranks launch workers. All MPI ranks other than
MPI rank 1 block while their event loops run.
In normal operation these ranks exit once rank 1 ends. If exit=False is set they
instead return an bool indicating whether they are the client and should execute
more client code, or a worker/scheduler who should not. In this case the user is
responsible for the client calling send_close_signal when work is complete, and
checking the returned value to choose further actions.
Parameters
----------
func : callable
A function containing Dask client code to execute with a Dask cluster
args : list
Arguments to func
client_rank : int
The MPI rank on which to run func
scheduler_rank : int
The MPI rank on which to run the Dask scheduler
interface : str
Network interface like 'eth0' or 'ib0'
nthreads : int
Number of threads per worker
local_directory : str
Directory to place worker files
memory_limit : int, float, or 'auto'
Number of bytes before spilling data to disk. This can be an
integer (nbytes), float (fraction of total memory), or 'auto'.
nanny : bool
Start workers in nanny process for management (deprecated, use worker_class instead)
dashboard : bool
Enable Bokeh visual diagnostics
dashboard_address : str
Bokeh port for visual diagnostics
protocol : str
Protocol like 'inproc' or 'tcp'
exclusive_workers : bool
Whether to only run Dask workers on their own MPI ranks
worker_class : str
Class to use when creating workers
worker_options : dict
Options to pass to workers
comm: mpi4py.MPI.Intracomm
Optional MPI communicator to use instead of COMM_WORLD
Returns
-------
ret : Any
If the MPI rank equals client_rank, then the return value of the executed function.
Otherwise, returns None.
"""
    if comm is None:
        from mpi4py import MPI

        comm = MPI.COMM_WORLD

    world_size = comm.Get_size()
    min_world_size = 1 + max(client_rank, scheduler_rank, exclusive_workers)
    if world_size < min_world_size:
        raise WorldTooSmallException(
            f"Not enough MPI ranks to start cluster with exclusive_workers={exclusive_workers} and "
            f"scheduler_rank={scheduler_rank}, found {world_size} MPI ranks but needs {min_world_size}."
        )

    rank = comm.Get_rank()

    if not worker_options:
        worker_options = {}

    async def run_worker():
        WorkerType = import_term(worker_class)
        if nanny:
            warnings.warn(
                "Option nanny=True is deprecated, use worker_class='distributed.Nanny' instead",
                DeprecationWarning,
            )
            WorkerType = Nanny
        opts = {
            "interface": interface,
            "protocol": protocol,
            "nthreads": nthreads,
            "memory_limit": memory_limit,
            "local_directory": local_directory,
            "name": rank,
            **worker_options,
        }
        async with WorkerType(**opts) as worker:
            await worker.finished()

    async def run_scheduler(launch_worker=False):
        async with Scheduler(
            interface=interface,
            protocol=protocol,
            dashboard=dashboard,
            dashboard_address=dashboard_address,
        ) as scheduler:
            comm.bcast(scheduler.address, root=scheduler_rank)
            comm.Barrier()

            if launch_worker:
                asyncio.create_task(run_worker())

            await scheduler.finished()

    if rank == scheduler_rank:
        # When workers are not exclusive, the scheduler rank also hosts a worker
        asyncio.get_event_loop().run_until_complete(
            run_scheduler(launch_worker=not exclusive_workers)
        )
    else:
        scheduler_address = comm.bcast(None, root=scheduler_rank)
        dask.config.set(scheduler_address=scheduler_address)
        comm.Barrier()

        if rank == client_rank:
            ret = func(*args, **kwargs)
            send_close_signal()
            return ret
        else:
            asyncio.get_event_loop().run_until_complete(run_worker())
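
For orientation, a minimal usage sketch follows (not part of this commit): a hypothetical my_script.py launched under MPI, e.g. with mpirun -np 4 python my_script.py. With the defaults, rank 0 hosts the scheduler, rank 1 runs main(), and the remaining ranks become workers; Client() with no arguments picks up the scheduler address that execute() broadcasts and stores in the Dask configuration. The script name and the squaring workload are illustrative assumptions.

# Hypothetical my_script.py, illustrating dask_mpi.execute usage
from distributed import Client

from dask_mpi import execute


def main():
    # Client() connects via the scheduler address set in dask.config by execute()
    with Client() as client:
        futures = client.map(lambda i: i ** 2, range(10))
        print(client.gather(futures))


if __name__ == "__main__":
    # Only the client rank (rank 1 by default) runs main() and receives its
    # return value; the scheduler and worker ranks return None.
    execute(main)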
21 changes: 21 additions & 0 deletions dask_mpi/tests/execute_basic.py
@@ -0,0 +1,21 @@
from time import sleep

from distributed import Client
from distributed.metrics import time

from dask_mpi import execute


def client_func():
    with Client() as c:
        start = time()
        while len(c.scheduler_info()["workers"]) != 2:
            assert time() < start + 10
            sleep(0.2)

        assert c.submit(lambda x: x + 1, 10).result() == 11
        assert c.submit(lambda x: x + 1, 20, workers=2).result() == 21


if __name__ == "__main__":
    execute(client_func)
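
As a further hypothetical variant (not part of this commit), positional arguments after the function are forwarded to it on the client rank, and the scheduler and client ranks can be placed explicitly via scheduler_rank and client_rank:

# Hypothetical variant: forwarding an argument and choosing ranks explicitly
from distributed import Client

from dask_mpi import execute


def client_func(x):
    with Client() as c:
        return c.submit(lambda v: v * 2, x).result()


if __name__ == "__main__":
    result = execute(client_func, 21, client_rank=1, scheduler_rank=0)
    print(result)  # 42 on the client rank, None on all other ranks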
32 changes: 32 additions & 0 deletions dask_mpi/tests/test_execute_basic.py
@@ -0,0 +1,32 @@
from __future__ import absolute_import, division, print_function

import os
import subprocess
import sys

import pytest

pytest.importorskip("mpi4py")


def test_basic(mpirun):
    script_file = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "execute_basic.py"
    )

    p = subprocess.Popen(mpirun + ["-np", "4", sys.executable, script_file])

    p.communicate()
    assert p.returncode == 0


def test_small_world(mpirun):
    script_file = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "execute_basic.py"
    )

    # Set too few processes to start cluster
    p = subprocess.Popen(mpirun + ["-np", "1", sys.executable, script_file])

    p.communicate()
    assert p.returncode != 0
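
These tests rely on an mpirun pytest fixture that is not part of this diff; it presumably comes from the package's conftest.py and yields the MPI launcher command as a list of strings so that "-np", the process count, and the script path can be appended. A minimal sketch of such a fixture, assuming only that an mpirun or mpiexec executable is on the PATH (the real fixture may differ):

# Hypothetical conftest.py sketch
import shutil

import pytest


@pytest.fixture
def mpirun():
    # Return the launcher as a list so tests can append "-np", counts, and scripts
    launcher = shutil.which("mpirun") or shutil.which("mpiexec")
    if launcher is None:
        pytest.skip("no MPI launcher found on PATH")
    return [launcher]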
