Skip to content

Commit

Permalink
fix: fix DAQJobRemote bug and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
furkan-bilgin committed Nov 6, 2024
1 parent 1524bdf commit d4bad71
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 27 deletions.
42 changes: 22 additions & 20 deletions src/daq/jobs/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DAQJobRemote(DAQJob):
_zmq_remotes: dict[str, zmq.Socket]
_message_class_cache: dict[str, type[DAQJobMessage]]
_remote_message_ids: set[str]
_receive_threads: dict[str, threading.Thread]
_receive_thread: threading.Thread

def __init__(self, config: DAQJobRemoteConfig):
super().__init__(config)
Expand All @@ -47,17 +47,11 @@ def __init__(self, config: DAQJobRemoteConfig):
self._zmq_remotes = {}
self._zmq_local.bind(config.zmq_local_url)

self._receive_threads = {}
for remote_url in config.zmq_remote_urls:
self._logger.debug(f"Connecting to {remote_url}")
zmq_remote = self._zmq_context.socket(zmq.SUB)
zmq_remote.connect(remote_url)
self._zmq_remotes[remote_url] = zmq_remote
self._receive_threads[remote_url] = threading.Thread(
target=self._start_receive_thread,
args=(remote_url, zmq_remote),
daemon=True,
)
self._receive_thread = threading.Thread(
target=self._start_receive_thread,
args=(config.zmq_remote_urls,),
daemon=True,
)
self._message_class_cache = {}

self._message_class_cache = {
Expand All @@ -77,23 +71,31 @@ def handle_message(self, message: DAQJobMessage) -> bool:
self._zmq_local.send(self._pack_message(message))
return True

def _start_receive_thread(self, remote_url: str, zmq_remote: zmq.Socket):
def _create_zmq_sub(self, remote_urls: list[str]):
ctx = zmq.Context()
zmq_sub = ctx.socket(zmq.SUB)
for remote_url in remote_urls:
self._logger.debug(f"Connecting to {remote_url}")
zmq_sub.connect(remote_url)
zmq_sub.subscribe("")
return zmq_sub

def _start_receive_thread(self, remote_urls: list[str]):
zmq_sub = self._create_zmq_sub(remote_urls)

while True:
message = zmq_remote.recv()
self._logger.debug(
f"Received {len(message)} bytes from remote ({remote_url})"
)
message = zmq_sub.recv()
self._logger.debug(f"Received {len(message)} bytes from")
recv_message = self._unpack_message(message)
recv_message.is_remote = True
# remote message_in -> message_out
self.message_out.put(recv_message)

def start(self):
for remote_url in self._zmq_remotes.keys():
self._receive_threads[remote_url].start()
self._receive_thread.start()

while True:
if not any(x.is_alive() for x in self._receive_threads.values()):
if not self._receive_thread.is_alive():
raise RuntimeError("Receive thread died")
# message_in -> remote message_out
self.consume()
Expand Down
11 changes: 4 additions & 7 deletions src/tests/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def setUp(self, MockZmqContext):
)
self.daq_job_remote = DAQJobRemote(self.config)
self.daq_job_remote._zmq_local = self.mock_sender
self.daq_job_remote._zmq_remotes = {"tcp://localhost:5556": self.mock_receiver}

def test_handle_message(self):
message = DAQJobMessage(
Expand All @@ -41,9 +40,7 @@ def stop_receive_thread():
time.sleep(0.1)
mock_receive_thread.is_alive.return_value = False

self.daq_job_remote._receive_threads = {
"tcp://localhost:5556": mock_receive_thread
}
self.daq_job_remote._receive_thread = mock_receive_thread
threading.Thread(target=stop_receive_thread, daemon=True).start()

with self.assertRaises(RuntimeError):
Expand All @@ -70,12 +67,12 @@ def side_effect():
raise RuntimeError("Stop receive thread")
return self.daq_job_remote._pack_message(message)

self.daq_job_remote._create_zmq_sub = MagicMock(return_value=self.mock_receiver)
self.mock_receiver.recv.side_effect = side_effect

with self.assertRaises(RuntimeError):
self.daq_job_remote._start_receive_thread(
"tcp://localhost:5556", self.mock_receiver
)
self.daq_job_remote._start_receive_thread(["tcp://localhost:5556"])

assert_msg = message
assert_msg.is_remote = True
self.daq_job_remote.message_out.put.assert_called_once_with(assert_msg)
Expand Down

0 comments on commit d4bad71

Please sign in to comment.