Skip to content

Commit

Permalink
Raise an exception if the comm world is too small (#107)
Browse files Browse the repository at this point in the history
* Raise an exeption if the comm world is too small

* Expose exception at module level

* Tweak exceptions to be more readable

* Tweak language a little more
  • Loading branch information
jacobtomlinson authored Sep 21, 2023
1 parent ff148aa commit a8890e6
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 0 deletions.
1 change: 1 addition & 0 deletions dask_mpi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._version import get_versions
from .core import initialize, send_close_signal
from .exceptions import WorldTooSmallException

__version__ = get_versions()["version"]
del get_versions
10 changes: 10 additions & 0 deletions dask_mpi/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from distributed.utils import import_term
from mpi4py import MPI

from .exceptions import WorldTooSmallException


@click.command()
@click.argument("scheduler_address", type=str, required=False)
Expand Down Expand Up @@ -95,6 +97,14 @@ def main(
name,
):
comm = MPI.COMM_WORLD

world_size = comm.Get_size()
if scheduler and world_size < 2:
raise WorldTooSmallException(
f"Not enough MPI ranks to start cluster, found {world_size}, "
"needs at least 2, one each for the scheduler and a worker."
)

rank = comm.Get_rank()

try:
Expand Down
9 changes: 9 additions & 0 deletions dask_mpi/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from distributed import Client, Nanny, Scheduler
from distributed.utils import import_term

from .exceptions import WorldTooSmallException


def initialize(
interface=None,
Expand Down Expand Up @@ -74,6 +76,13 @@ def initialize(

comm = MPI.COMM_WORLD

world_size = comm.Get_size()
if world_size < 3:
raise WorldTooSmallException(
f"Not enough MPI ranks to start cluster, found {world_size}, "
"needs at least 3, one each for the scheduler, client and a worker."
)

rank = comm.Get_rank()

if not worker_options:
Expand Down
2 changes: 2 additions & 0 deletions dask_mpi/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class WorldTooSmallException(RuntimeError):
"""Not enough MPI ranks to start all required processes."""
18 changes: 18 additions & 0 deletions dask_mpi/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ def test_basic(loop, worker_class, mpirun):
assert c.submit(lambda x: x + 1, 10).result() == 11


def test_small_world(mpirun):
with tmpfile(extension="json") as fn:
# Set too few processes to start cluster
p = subprocess.Popen(
mpirun
+ [
"-np",
"1",
"dask-mpi",
"--scheduler-file",
fn,
]
)

p.communicate()
assert p.returncode != 0


def test_no_scheduler(loop, mpirun):
with tmpfile(extension="json") as fn:
cmd = mpirun + ["-np", "2", "dask-mpi", "--scheduler-file", fn]
Expand Down
12 changes: 12 additions & 0 deletions dask_mpi/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,15 @@ def test_basic(mpirun):

p.communicate()
assert p.returncode == 0


def test_small_world(mpirun):
script_file = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "core_basic.py"
)

# Set too few processes to start cluster
p = subprocess.Popen(mpirun + ["-np", "1", sys.executable, script_file])

p.communicate()
assert p.returncode != 0

0 comments on commit a8890e6

Please sign in to comment.