-
-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
206 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import asyncio | ||
|
||
import dask | ||
from distributed import Nanny, Scheduler | ||
from distributed.utils import import_term | ||
|
||
from .initialize import send_close_signal | ||
from .exceptions import WorldTooSmallException | ||
|
||
|
||
def execute( | ||
func, | ||
*args, | ||
client_rank=1, | ||
scheduler_rank=0, | ||
interface=None, | ||
nthreads=1, | ||
local_directory="", | ||
memory_limit="auto", | ||
nanny=False, | ||
dashboard=True, | ||
dashboard_address=":8787", | ||
protocol=None, | ||
exclusive_workers=True, | ||
worker_class="distributed.Worker", | ||
worker_options=None, | ||
comm=None, | ||
**kwargs, | ||
): | ||
""" | ||
Execute a function on a given MPI rank with a Dask cluster launched using mpi4py | ||
Using mpi4py, MPI rank 0 launches the Scheduler, MPI rank 1 passes through to the | ||
client script, and all other MPI ranks launch workers. All MPI ranks other than | ||
MPI rank 1 block while their event loops run. | ||
In normal operation these ranks exit once rank 1 ends. If exit=False is set they | ||
instead return an bool indicating whether they are the client and should execute | ||
more client code, or a worker/scheduler who should not. In this case the user is | ||
responsible for the client calling send_close_signal when work is complete, and | ||
checking the returned value to choose further actions. | ||
Parameters | ||
---------- | ||
func : callable | ||
A function containing Dask client code to execute with a Dask cluster | ||
args : list | ||
Arguments to func | ||
client_rank : int | ||
The MPI rank on which to run func | ||
scheduler_rank : int | ||
The MPI rank on which to run the Dask scheduler | ||
interface : str | ||
Network interface like 'eth0' or 'ib0' | ||
nthreads : int | ||
Number of threads per worker | ||
local_directory : str | ||
Directory to place worker files | ||
memory_limit : int, float, or 'auto' | ||
Number of bytes before spilling data to disk. This can be an | ||
integer (nbytes), float (fraction of total memory), or 'auto'. | ||
nanny : bool | ||
Start workers in nanny process for management (deprecated, use worker_class instead) | ||
dashboard : bool | ||
Enable Bokeh visual diagnostics | ||
dashboard_address : str | ||
Bokeh port for visual diagnostics | ||
protocol : str | ||
Protocol like 'inproc' or 'tcp' | ||
exclusive_workers : bool | ||
Whether to only run Dask workers on their own MPI ranks | ||
worker_class : str | ||
Class to use when creating workers | ||
worker_options : dict | ||
Options to pass to workers | ||
comm: mpi4py.MPI.Intracomm | ||
Optional MPI communicator to use instead of COMM_WORLD | ||
Returns | ||
------- | ||
ret : Any | ||
If the MPI rank equals client_rank, then the return value of the executed function. | ||
Otherwise, returns None. | ||
""" | ||
if comm is None: | ||
from mpi4py import MPI | ||
|
||
comm = MPI.COMM_WORLD | ||
|
||
world_size = comm.Get_size() | ||
min_world_size = 1 + max(client_rank, scheduler_rank, exclusive_workers) | ||
if world_size < min_world_size: | ||
raise WorldTooSmallException( | ||
f"Not enough MPI ranks to start cluster with exclusive_workers={exclusive_workers} and " | ||
f"scheduler_rank={scheduler_rank}, found {world_size} MPI ranks but needs {min_world_size}." | ||
) | ||
|
||
rank = comm.Get_rank() | ||
|
||
if not worker_options: | ||
worker_options = {} | ||
|
||
async def run_worker(): | ||
WorkerType = import_term(worker_class) | ||
if nanny: | ||
raise DeprecationWarning( | ||
"Option nanny=True is deprectaed, use worker_class='distributed.Nanny' instead" | ||
) | ||
WorkerType = Nanny | ||
opts = { | ||
"interface": interface, | ||
"protocol": protocol, | ||
"nthreads": nthreads, | ||
"memory_limit": memory_limit, | ||
"local_directory": local_directory, | ||
"name": rank, | ||
**worker_options, | ||
} | ||
async with WorkerType(**opts) as worker: | ||
await worker.finished() | ||
|
||
async def run_scheduler(launch_worker=False): | ||
async with Scheduler( | ||
interface=interface, | ||
protocol=protocol, | ||
dashboard=dashboard, | ||
dashboard_address=dashboard_address, | ||
) as scheduler: | ||
comm.bcast(scheduler.address, root=0) | ||
comm.Barrier() | ||
|
||
if launch_worker: | ||
asyncio.create_task(run_worker()) | ||
|
||
await scheduler.finished() | ||
|
||
if rank == scheduler_rank: | ||
asyncio.get_event_loop().run_until_complete(run_scheduler()) | ||
|
||
else: | ||
scheduler_address = comm.bcast(None, root=scheduler_rank) | ||
dask.config.set(scheduler_address=scheduler_address) | ||
comm.Barrier() | ||
|
||
if rank == client_rank: | ||
ret = func(*args, **kwargs) | ||
send_close_signal() | ||
return ret | ||
|
||
else: | ||
asyncio.get_event_loop().run_until_complete(run_worker()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from time import sleep | ||
|
||
from distributed import Client | ||
from distributed.metrics import time | ||
|
||
from dask_mpi import execute | ||
|
||
|
||
def client_func(): | ||
with Client() as c: | ||
start = time() | ||
while len(c.scheduler_info()["workers"]) != 2: | ||
assert time() < start + 10 | ||
sleep(0.2) | ||
|
||
assert c.submit(lambda x: x + 1, 10).result() == 11 | ||
assert c.submit(lambda x: x + 1, 20, workers=2).result() == 21 | ||
|
||
|
||
if __name__ == "__main__": | ||
execute(client_func) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from __future__ import absolute_import, division, print_function | ||
|
||
import os | ||
import subprocess | ||
import sys | ||
|
||
import pytest | ||
|
||
pytest.importorskip("mpi4py") | ||
|
||
|
||
def test_basic(mpirun): | ||
script_file = os.path.join( | ||
os.path.dirname(os.path.realpath(__file__)), "execute_basic.py" | ||
) | ||
|
||
p = subprocess.Popen(mpirun + ["-np", "4", sys.executable, script_file]) | ||
|
||
p.communicate() | ||
assert p.returncode == 0 | ||
|
||
|
||
def test_small_world(mpirun): | ||
script_file = os.path.join( | ||
os.path.dirname(os.path.realpath(__file__)), "execute_basic.py" | ||
) | ||
|
||
# Set too few processes to start cluster | ||
p = subprocess.Popen(mpirun + ["-np", "1", sys.executable, script_file]) | ||
|
||
p.communicate() | ||
assert p.returncode != 0 |