-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add
pyzmq
, DAQJobRemote
and tests for it
- `DAQJobRemote` is a DAQJob that connects two seperate ENRGDAQ instances
- Loading branch information
1 parent
e9700c1
commit 8d82dc9
Showing
7 changed files
with
235 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import pickle | ||
import threading | ||
import time | ||
from dataclasses import dataclass | ||
|
||
import zmq | ||
|
||
from daq.base import DAQJob | ||
from daq.models import DAQJobConfig, DAQJobMessage | ||
|
||
|
||
@dataclass | ||
class DAQJobRemoteConfig(DAQJobConfig): | ||
zmq_sender_url: str | ||
zmq_receiver_url: str | ||
|
||
|
||
class DAQJobRemote(DAQJob): | ||
""" | ||
DAQJobRemote is a DAQJob that connects two seperate ENRGDAQ instances. | ||
It sends to and receives from a remote ENRGDAQ, in such that: | ||
- message_in -> remote message_out | ||
- remote message_in -> message_out | ||
""" | ||
|
||
allowed_message_in_types = [DAQJobMessage] # accept all message types | ||
config = DAQJobRemoteConfig | ||
|
||
def __init__(self, config: DAQJobRemoteConfig): | ||
super().__init__(config) | ||
self._zmq_context = zmq.Context() | ||
self._zmq_sender = self._zmq_context.socket(zmq.PUSH) | ||
self._zmq_receiver = self._zmq_context.socket(zmq.PULL) | ||
self._zmq_sender.connect(config.zmq_sender_url) | ||
self._zmq_receiver.connect(config.zmq_receiver_url) | ||
|
||
self._receive_thread = threading.Thread( | ||
target=self._start_receive_thread, daemon=True | ||
) | ||
|
||
def handle_message(self, message: DAQJobMessage) -> bool: | ||
self._zmq_sender.send(pickle.dumps(message)) | ||
return True | ||
|
||
def _start_receive_thread(self): | ||
while True: | ||
message = self._zmq_receiver.recv() | ||
# remote message_in -> message_out | ||
self.message_out.put(pickle.loads(message)) | ||
|
||
def start(self): | ||
self._receive_thread.start() | ||
|
||
while True: | ||
if not self._receive_thread.is_alive(): | ||
raise RuntimeError("receive thread died") | ||
# message_in -> remote message_out | ||
self.consume() | ||
time.sleep(0.1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import pickle | ||
import threading | ||
import time | ||
import unittest | ||
from unittest.mock import MagicMock, patch | ||
|
||
from daq.jobs.remote import DAQJobRemote, DAQJobRemoteConfig | ||
from daq.jobs.test_job import DAQJobTest | ||
from daq.store.models import DAQJobMessageStore | ||
|
||
|
||
class TestDAQJobRemote(unittest.TestCase): | ||
@patch("daq.jobs.remote.zmq.Context") | ||
def setUp(self, MockZmqContext): | ||
self.mock_context = MockZmqContext.return_value | ||
self.mock_sender = self.mock_context.socket.return_value | ||
self.mock_receiver = self.mock_context.socket.return_value | ||
self.config = DAQJobRemoteConfig( | ||
daq_job_type="remote", | ||
zmq_sender_url="tcp://localhost:5555", | ||
zmq_receiver_url="tcp://localhost:5556", | ||
) | ||
self.daq_job_remote = DAQJobRemote(self.config) | ||
self.daq_job_remote._zmq_sender = self.mock_sender | ||
self.daq_job_remote._zmq_receiver = self.mock_receiver | ||
|
||
def test_handle_message(self): | ||
message = DAQJobMessageStore( | ||
store_config={}, data=[], keys=[], daq_job=DAQJobTest({}) | ||
) | ||
self.daq_job_remote.handle_message(message) | ||
self.mock_sender.send.assert_called_once_with(pickle.dumps(message)) | ||
|
||
def test_start(self): | ||
mock_receive_thread = MagicMock() | ||
|
||
def stop_receive_thread(): | ||
time.sleep(0.1) | ||
mock_receive_thread.is_alive.return_value = False | ||
|
||
self.daq_job_remote._receive_thread = mock_receive_thread | ||
threading.Thread(target=stop_receive_thread, daemon=True).start() | ||
|
||
with self.assertRaises(RuntimeError): | ||
self.daq_job_remote.start() | ||
|
||
def test_receive_thread(self): | ||
message = DAQJobMessageStore(store_config={}, data=[], keys=[], daq_job=None) # type: ignore | ||
self.daq_job_remote.message_out = MagicMock() | ||
|
||
call_count = 0 | ||
|
||
def side_effect(): | ||
nonlocal call_count | ||
call_count += 1 | ||
if call_count >= 2: | ||
raise Exception("Stop receive thread") | ||
return pickle.dumps(message) | ||
|
||
self.mock_receiver.recv.side_effect = side_effect | ||
|
||
with self.assertRaises(Exception): | ||
self.daq_job_remote._start_receive_thread() | ||
self.daq_job_remote.message_out.put.assert_called_once_with(message) | ||
self.assertEqual(self.daq_job_remote.message_out.put.call_count, 1) | ||
self.assertEqual(call_count, 2) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.