diff --git a/examples/consumer_futures.py b/examples/consumer_futures.py
index d6abe52..af41dcb 100644
--- a/examples/consumer_futures.py
+++ b/examples/consumer_futures.py
@@ -59,7 +59,7 @@
 for elt in data:
     inputs.putf(elt)
     p_sum += np.sum(elt)
-print(f"{rank=} {p_sum=}")
+print(f"{rank=} {p_sum=}", flush=True)
 
 #_______________________________________________________________________________
 # (on all other ranks) Take data, and compute the sum locally. The result is
@@ -69,15 +69,19 @@
 res = 0
 p_sum = 0.
 for i in range(data_size):
+    # print(f"{rank=} taking", flush=True)
     p = inputs.take()
+    # print(f"{rank=} {p=}", flush=True)
     # sleep(random())
     if p is not None:
         sp = np.sum(p)
         result.putf((sp,))
         p_sum += sp
         res += 1
+    # else:
+    #     print(f"{rank=} has quit", flush=True)
 
-print(f"{rank=} {res=}")
+print(f"{rank=} {res=}", flush=True)
 
 #_______________________________________________________________________________
 # (on rank 0) Take `result` elements (local partial sums), and finish the tally.
@@ -90,4 +94,4 @@
     p_sum_r += sp[0]
     # print(f"{rank=} {sp=}")
 
-print(f"{rank=} {p_sum_r=} : {p_sum - p_sum_r}")
+print(f"{rank=} {p_sum_r=} : {p_sum - p_sum_r}", flush=True)
diff --git a/mpi_channels/__init__.py b/mpi_channels/__init__.py
index 82013ed..83542e4 100644
--- a/mpi_channels/__init__.py
+++ b/mpi_channels/__init__.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+from .schedule import Policy, Action, Schedule
 from .frame_buffer import FrameBuffer
 from .remote_channel import RemoteChannel
 
diff --git a/mpi_channels/frame_buffer.py b/mpi_channels/frame_buffer.py
index 961fd1a..6c74d0e 100644
--- a/mpi_channels/frame_buffer.py
+++ b/mpi_channels/frame_buffer.py
@@ -89,6 +89,9 @@ def __init__(self, n_buf, n_mes, dtype=np.float64, host=0):
         self.lock()
         self._ptr_put((0, 0, 0))
         self.unlock()
+        self._idx = 0  # local pointer state; defined before the first sync()
+        self._max = 0
+        self._len = 0
 
 
     def __del__(self):
diff --git a/mpi_channels/remote_channel.py b/mpi_channels/remote_channel.py
index d3a5317..1d02009 100644
--- a/mpi_channels/remote_channel.py
+++ b/mpi_channels/remote_channel.py
@@ -8,15 +8,18 @@
 from mpi4py import MPI
 
 from concurrent.futures import ThreadPoolExecutor
 
-from . import FrameBuffer
+from . import FrameBuffer, Schedule
 
 
 class RemoteChannel(object):
 
-    def __init__(self, n_buf, n_mes, dtype=np.float64, host=0, n_fpool=4):
+    def __init__(self,
+                 n_buf, n_mes, dtype=np.float64, host=0,
+                 s_threshold=128, t_wait=0, r_wait=1):
         """
-        RemoteChannel(n_buf, n_mes, dtype=np.float64, host=0, n_fpool=4)
+        RemoteChannel(n_buf, n_mes, dtype=np.float64, host=0,
+                      s_threshold=128, t_wait=0, r_wait=1)
 
         Create a RemoteChannel containing at most `n_buf` messages (with
         maximum message size `n_mes`). Messages are arrays of type `dtype`. The
@@ -25,10 +28,17 @@
         The RemoteChannel's internal state is held in a FrameBuffer object
         (in the `buf` attribute).
 
-        The RemoteChannel uses a ThreadPoolExecutor to handle futures,
-        `n_fpool` sets the size of this pool. Future objects resulting from
-        calls to `putf` are stored in `put_futures`, and Future instances
-        resulting from calles to `takef` are stored in `take_futures`.
+        The RemoteChannel uses a ThreadPoolExecutor to handle futures. The
+        pool is fixed at size 1, since a larger pool would let multiple
+        threads lock each other out of the MPI RMA buffer.
+        Future objects resulting from calls to `putf` are stored in
+        `put_futures`, and Future instances resulting from calls to `takef`
+        are stored in `take_futures`.
+
+        `s_threshold` sets how many times a blocking `put` or `take` may
+        poll a full (or empty) buffer before backing off; `t_wait` and
+        `r_wait` set the fixed and random parts of the back-off delay (see
+        `Schedule`).
         """
         self.comm = MPI.COMM_WORLD
         self.rank = self.comm.Get_rank()
@@ -37,10 +47,11 @@
         self.buf = FrameBuffer(n_buf, n_mes, dtype=self.dtype, host=self.host)
 
-        self.pool = ThreadPoolExecutor(5)
+        self.pool = ThreadPoolExecutor(1)
         self.put_futures = list()
         self.take_futures = list()
+        self.schedule = Schedule(threshold=s_threshold, t_wait=t_wait, r_wait=r_wait)
 
 
     def put(self, src):
         """
         put(src)
@@ -53,14 +64,19 @@ def put(self, src):
         Blocks if buffer is full. For non-blocking version see `putf`
         """
         while True:
+            # print(f"{self.rank=} start putting 1 {self.buf.idx=} {self.buf.max=} {self.buf.len=}", flush=True)
             self.buf.lock()
+            # print(f"{self.rank=} start putting 2", flush=True)
             self.buf.sync()
+            # print(f"{self.rank=} start putting 3", flush=True)
 
             # Check if there is space in the buffer for new elements. If the
             # buffer is full, spin and watch for space
             # print(f"putting {self.buf.idx=} {self.buf.max=} {self.buf.len=}")
             if self.buf.max - self.buf.idx >= self.buf.n_buf:
                 self.buf.unlock()
+                # print(f"PUT {self.buf.max=} {self.buf.idx=} {self.buf.n_buf=}", flush=True)
+                self.schedule()  # back off before retrying; sleeps under Policy.WAIT
                 continue
 
             self.buf.put(src)
@@ -122,22 +138,28 @@ def take(self):
 
         Blocks if buffer is empty. For non-blocking version see `takef`
         """
         while True:
+            # print(f"{self.rank=} start taking 1 {self.buf.idx=} {self.buf.max=} {self.buf.len=}", flush=True)
             self.buf.lock()
+            # print(f"{self.rank=} start taking 2", flush=True)
             self.buf.sync()
+            # print(f"{self.rank=} start taking 3", flush=True)
 
             if self.buf.idx >= self.buf.len:
                 self.buf.unlock()
-                # print(f"Overrunning Src {self.buf.idx=}, {self.buf.len=}")
+                # print(f"TAKE {self.buf.idx=}, {self.buf.len=}", flush=True)
                 return None
 
             if self.buf.idx >= self.buf.max:
-                # print(f"{self.rank=} peeking {src_offset=} {src_capacity=} {src_len=}")
                 self.buf.unlock()
+                # print(f"TAKE {self.rank=} peeking {self.buf.idx=} {self.buf.max=} {self.buf.len=}", flush=True)
+                self.schedule()  # back off before polling the empty buffer again
                 continue
 
             buf = self.buf.take()
+            # print(f"{self.rank=} taking", flush=True)
             # print(f"{self.rank=} taking {src_offset=}, {src_capacity=}, {src_len=}")
             self.buf.unlock()
+            # print(f"{self.rank=} done taking", flush=True)
 
             return buf
@@ -157,4 +179,3 @@ def takef(self):
         )
         return self.take_futures[-1]
 
-
diff --git a/mpi_channels/schedule.py b/mpi_channels/schedule.py
new file mode 100644
index 0000000..f84b67f
--- /dev/null
+++ b/mpi_channels/schedule.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+
+from random import random
+from enum import Enum, auto
+from time import sleep
+
+
+
+class Policy(Enum):
+    WAIT = auto()
+    KILL = auto()
+
+
+class Action(Enum):
+    CONTINUE = auto()
+    WAIT = auto()
+    KILL = auto()
+
+
+class Schedule(object):
+    """Count failed polls: below `threshold`, continue; above it, back off
+    by sleeping (Policy.WAIT) or tell the caller to stop (Policy.KILL)."""
+
+    def __init__(self, threshold, t_wait=0, r_wait=1, policy=Policy.WAIT):
+        self._threshold = threshold
+        self._t_wait = t_wait
+        self._r_wait = r_wait
+        self._policy = policy
+        self._reset()
+
+    @property
+    def count(self):
+        return self._count
+
+    @property
+    def wait(self):
+        # a fixed delay plus a random jitter in [0, r_wait)
+        return self._t_wait + self._r_wait*random()
+
+    def _reset(self):
+        self._count = 0
+
+    def __call__(self):
+        self._count += 1
+
+        if self._count > self._threshold:
+            self._reset()
+            if self._policy == Policy.WAIT:
+                # print("waiting", flush=True)
+                sleep(self.wait)
+                return Action.WAIT
+            if self._policy == Policy.KILL:
+                return Action.KILL
+        else:
+            return Action.CONTINUE
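
Note: a minimal sketch of how the new `Schedule` back-off behaves under the
default `Policy.WAIT` (the threshold and wait values below are arbitrary;
importing `mpi_channels.schedule` executes the package `__init__`, so mpi4py
needs to be installed even though no MPI ranks are involved):

    from mpi_channels.schedule import Schedule, Policy

    # Tolerate 3 consecutive failed polls; on the 4th call the schedule
    # sleeps for t_wait + r_wait*random() seconds, returns Action.WAIT,
    # and resets its counter.
    sched = Schedule(threshold=3, t_wait=0.1, r_wait=0.5, policy=Policy.WAIT)

    for attempt in range(5):
        print(attempt, sched())
    # 0-2 -> Action.CONTINUE; 3 -> Action.WAIT (after sleeping); the counter
    # has reset, so 4 -> Action.CONTINUE again.

This is the pattern `RemoteChannel.put` and `RemoteChannel.take` rely on:
whenever the locked-and-synced buffer turns out to be full (or not yet
filled), they unlock, call `self.schedule()`, and retry, so a spinning rank
backs off for a randomized interval instead of hammering the RMA window.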