Skip to content

Commit

Permalink
fix deadlock on linux and add flood scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
JBlaschke committed Oct 6, 2021
1 parent 04a11ac commit ab600d8
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 14 deletions.
10 changes: 7 additions & 3 deletions examples/consumer_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
for elt in data:
inputs.putf(elt)
p_sum += np.sum(elt)
print(f"{rank=} {p_sum=}")
print(f"{rank=} {p_sum=}", flush=True)

#_______________________________________________________________________________
# (on all other ranks) Take data, and compute the sum locally. The result is
Expand All @@ -69,15 +69,19 @@
res = 0
p_sum = 0.
for i in range(data_size):
# print(f"{rank=} taking", flush=True)
p = inputs.take()
# print(f"{rank=} {p=}", flush=True)
# sleep(random())
if p is not None:
sp = np.sum(p)
result.putf((sp,))
p_sum += sp
res += 1
# else:
# print(f"{rank=} has quit", flush=True)

print(f"{rank=} {res=}")
print(f"{rank=} {res=}", flush=True)

#_______________________________________________________________________________
# (on rank 0) Take `result` elements (local partial sums), and finish the tally.
Expand All @@ -90,4 +94,4 @@
p_sum_r += sp[0]
# print(f"{rank=} {sp=}")

print(f"{rank=} {p_sum_r=} : {p_sum - p_sum_r}")
print(f"{rank=} {p_sum_r=} : {p_sum - p_sum_r}", flush=True)
1 change: 1 addition & 0 deletions mpi_channels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from .schedule import Policy, Action, Schedule
from .frame_buffer import FrameBuffer
from .remote_channel import RemoteChannel
3 changes: 3 additions & 0 deletions mpi_channels/frame_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def __init__(self, n_buf, n_mes, dtype=np.float64, host=0):
self.lock()
self._ptr_put((0, 0, 0))
self.unlock()
self._idx = 0
self._max = 0
self._len = 0


def __del__(self):
Expand Down
40 changes: 29 additions & 11 deletions mpi_channels/remote_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@
from mpi4py import MPI
from concurrent.futures import ThreadPoolExecutor

from . import FrameBuffer
from . import FrameBuffer, Schedule



class RemoteChannel(object):

def __init__(self, n_buf, n_mes, dtype=np.float64, host=0, n_fpool=4):
def __init__(self,
n_buf, n_mes, dtype=np.float64, host=0,
s_threashold=4, t_wait=0, r_rwait=1):
"""
RemoteChannel(n_buf, n_mes, dtype=np.float64, host=0, n_fpool=4)
RemoteChannel(n_buf, n_mes, dtype=np.float64, host=0,
s_threashold=4, t_wait=0, r_rwait=1)
Create a RemoteChannel containing at most `n_buf` messages (with
maximum message size `n_mes`). Messages are arrays of type `dtype`. The
Expand All @@ -25,10 +28,12 @@ def __init__(self, n_buf, n_mes, dtype=np.float64, host=0, n_fpool=4):
The RemoteChannel's internal state is held in a FrameBuffer object (in
the `buf` attribute).
The RemoteChannel uses a ThreadPoolExecutor to handle futures,
`n_fpool` sets the size of this pool. Future objects resulting from
calls to `putf` are stored in `put_futures`, and Future instances
resulting from calles to `takef` are stored in `take_futures`.
The RemoteChannel uses a ThreadPoolExecutor (of size 1) to handle
futures. The size of this pool needs to be limited to prevent multiple
threads locking each other out of the MPI RMA buffer. For this reason
we choose a size of 1. Future objects resulting from calls to `putf`
are stored in `put_futures`, and Future instances resulting from calls
to `takef` are stored in `take_futures`.
"""
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
Expand All @@ -37,10 +42,13 @@ def __init__(self, n_buf, n_mes, dtype=np.float64, host=0, n_fpool=4):

self.buf = FrameBuffer(n_buf, n_mes, dtype=self.dtype, host=self.host)

self.pool = ThreadPoolExecutor(5)
self.pool = ThreadPoolExecutor(1)
self.put_futures = list()
self.take_futures = list()

self.schedule = Schedule(threshold=128)


def put(self, src):
"""
put(src)
Expand All @@ -53,14 +61,19 @@ def put(self, src):
Blocks if buffer is full. For non-blocking version see `putf`
"""
while True:
# print(f"{self.rank=} start putting 1 {self.buf.idx=} {self.buf.max=} {self.buf.len=}", flush=True)
self.buf.lock()
# print(f"{self.rank=} start putting 2", flush=True)
self.buf.sync()
# print(f"{self.rank=} start putting 3", flush=True)

# Check if there is space in the buffer for new elements. If the
# buffer is full, spin and watch for space
# print(f"putting {self.buf.idx=} {self.buf.max=} {self.buf.len=}")
if self.buf.max - self.buf.idx >= self.buf.n_buf:
self.buf.unlock()
# print(f"PUT {self.buf.max=} {self.buf.idx=} {self.buf.n_buf=}", flush=True)
self.schedule()
continue

self.buf.put(src)
Expand Down Expand Up @@ -122,22 +135,28 @@ def take(self):
Blocks if buffer is empty. For non-blocking version see `takef`
"""
while True:
# print(f"{self.rank=} start taking 1 {self.buf.idx=} {self.buf.max=} {self.buf.len=}", flush=True)
self.buf.lock()
# print(f"{self.rank=} start taking 2", flush=True)
self.buf.sync()
# print(f"{self.rank=} start taking 3", flush=True)

if self.buf.idx >= self.buf.len:
self.buf.unlock()
# print(f"Overrunning Src {self.buf.idx=}, {self.buf.len=}")
# print(f"TAKE {self.buf.idx=}, {self.buf.len=}", flush=True)
return None

if self.buf.idx >= self.buf.max:
# print(f"{self.rank=} peeking {src_offset=} {src_capacity=} {src_len=}")
self.buf.unlock()
# print(f"TAKE {self.rank=} peeking {self.buf.idx=} {self.buf.max=} {self.buf.len=}", flush=True)
self.schedule()
continue

buf = self.buf.take()
# print(f"{self.rank=} taking", flush=True)
# print(f"{self.rank=} taking {src_offset=}, {src_capacity=}, {src_len=}")
self.buf.unlock()
# print(f"{self.rank=} done taking", flush=True)

return buf

Expand All @@ -157,4 +176,3 @@ def takef(self):
)

return self.take_futures[-1]

54 changes: 54 additions & 0 deletions mpi_channels/schedule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-


from random import random
from enum import Enum, auto
from time import sleep



class Policy(Enum):
    """
    Policy

    Selects what a `Schedule` does once its call counter has exceeded the
    configured threshold.
    """

    # Member values are the integers that `auto()` would generate (1, 2).
    WAIT = 1  # `Schedule` sleeps for its wait time and reports `Action.WAIT`
    KILL = 2  # `Schedule` reports `Action.KILL`


class Action(Enum):
    """
    Action

    Outcome reported by a call to a `Schedule` instance.
    """

    # Member values are the integers that `auto()` would generate (1, 2, 3).
    CONTINUE = 1  # returned while the call counter is within the threshold
    WAIT = 2      # returned after the schedule has slept (`Policy.WAIT`)
    KILL = 3      # returned when the schedule's policy is `Policy.KILL`


class Schedule(object):
    """
    Schedule(threshold, t_wait=0, r_wait=1, policy=Policy.WAIT)

    Call counter with a back-off policy. Every call to the instance
    increments an internal counter. While the counter does not exceed
    `threshold` the call returns `Action.CONTINUE`. Once it does, the counter
    is reset and `policy` is applied:

      * `Policy.WAIT`: sleep for `wait` seconds, then return `Action.WAIT`.
      * `Policy.KILL`: return `Action.KILL` without sleeping.

    The wait time is `t_wait` plus a random fraction of `r_wait`, drawn anew
    on every access of the `wait` property.
    """

    def __init__(self, threshold, t_wait=0, r_wait=1, policy=Policy.WAIT):
        """
        Schedule(threshold, t_wait=0, r_wait=1, policy=Policy.WAIT)

        `threshold` is the number of calls tolerated before `policy` is
        applied. `t_wait` is the fixed part of the sleep time and `r_wait`
        scales the random part (both are passed to `time.sleep`, i.e. they
        are in seconds).
        """
        self._threshold = threshold
        self._t_wait = t_wait
        self._r_wait = r_wait
        self._policy = policy
        self._reset()

    @property
    def count(self):
        """
        Number of calls made since the counter was last reset.
        """
        return self._count

    @property
    def wait(self):
        """
        Sleep time for the next wait: `t_wait + r_wait*random()`. A new
        random number is drawn every time this property is read.
        """
        return self._t_wait + self._r_wait*random()

    def _reset(self):
        """
        Set the call counter back to zero.
        """
        self._count = 0

    def __call__(self):
        """
        Register one call and return the resulting `Action`.

        Returns `Action.CONTINUE` while the counter is within the threshold.
        When the threshold is exceeded the counter is reset, and the return
        value is `Action.WAIT` (after sleeping) or `Action.KILL`, depending
        on the policy.
        """
        self._count += 1

        if self._count > self._threshold:
            self._reset()
            if self._policy == Policy.WAIT:
                sleep(self.wait)
                return Action.WAIT
            if self._policy == Policy.KILL:
                return Action.KILL

        # Always hand back an Action member: this covers both the
        # below-threshold case and a policy that is neither WAIT nor KILL,
        # so that no path ends in an implicit `None`.
        return Action.CONTINUE

0 comments on commit ab600d8

Please sign in to comment.