Skip to content

Commit

Permalink
feat: allow subscribing to multiple remotes in DAQJobRemote
Browse files Browse the repository at this point in the history
- also bug fix in `handle_messages`
  • Loading branch information
furkan-bilgin committed Oct 21, 2024
1 parent a22df53 commit 73ed894
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 23 deletions.
2 changes: 1 addition & 1 deletion configs/examples/remote.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
daq_job_type = "remote"
zmq_local_url = "tcp://localhost:10191"
zmq_remote_url = "tcp://1.2.3.4:10191"
zmq_remote_urls = ["tcp://1.2.3.4:10191"]
56 changes: 38 additions & 18 deletions src/daq/jobs/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from daq.models import DAQJobConfig, DAQJobMessage
from daq.store.models import DAQJobMessageStore

DAQ_JOB_REMOTE_MAX_REMOTE_MESSAGE_ID_COUNT = 1000
DAQ_JOB_REMOTE_MAX_REMOTE_MESSAGE_ID_COUNT = 10000


@dataclass
class DAQJobRemoteConfig(DAQJobConfig):
zmq_local_url: str
zmq_remote_url: str
zmq_remote_urls: list[str]


class DAQJobRemote(DAQJob):
Expand All @@ -26,28 +26,40 @@ class DAQJobRemote(DAQJob):
- message_in -> remote message_out
- remote message_in -> message_out
TODO: Use zmq CURVE security
"""

allowed_message_in_types = [DAQJobMessage] # accept all message types
config_type = DAQJobRemoteConfig
config: DAQJobRemoteConfig
_zmq_local: zmq.Socket
_zmq_remote: zmq.Socket
_message_class_cache: dict
_zmq_remotes: dict[str, zmq.Socket]
_message_class_cache: dict[str, type[DAQJobMessage]]
_remote_message_ids: set[str]
_receive_threads: dict[str, threading.Thread]

def __init__(self, config: DAQJobRemoteConfig):
super().__init__(config)
self._zmq_context = zmq.Context()
self._zmq_local = self._zmq_context.socket(zmq.PUSH)
self._zmq_remote = self._zmq_context.socket(zmq.PULL)
self._zmq_local.connect(config.zmq_local_url)
self._zmq_remote.connect(config.zmq_remote_url)
self._logger.debug(f"Listening on {config.zmq_local_url}")
self._zmq_local = self._zmq_context.socket(zmq.PUB)
self._zmq_remotes = {}
self._zmq_local.bind(config.zmq_local_url)

self._receive_threads = {}
for remote_url in config.zmq_remote_urls:
self._logger.debug(f"Connecting to {remote_url}")
zmq_remote = self._zmq_context.socket(zmq.SUB)
zmq_remote.connect(remote_url)
self._zmq_remotes[remote_url] = zmq_remote
self._receive_threads[remote_url] = threading.Thread(
target=self._start_receive_thread,
args=(remote_url, zmq_remote),
daemon=True,
)
self._message_class_cache = {}

self._receive_thread = threading.Thread(
target=self._start_receive_thread, daemon=True
)
self._message_class_cache = {
x.__name__: x for x in DAQJobMessage.__subclasses__()
}
Expand All @@ -62,25 +74,26 @@ def handle_message(self, message: DAQJobMessage) -> bool:
or message.id in self._remote_message_ids
or not super().handle_message(message)
):
return False
return True # Silently ignore

self._zmq_local.send(self._pack_message(message))
return True

def _start_receive_thread(self):
def _start_receive_thread(self, remote_url: str, zmq_remote: zmq.Socket):
while True:
message = self._zmq_remote.recv()
message = zmq_remote.recv()
self._logger.debug(
f"Received {len(message)} bytes from remote ({self.config.zmq_remote_url})"
f"Received {len(message)} bytes from remote ({remote_url})"
)
# remote message_in -> message_out
self.message_out.put(self._unpack_message(message))

def start(self):
self._receive_thread.start()
for remote_url in self._zmq_remotes.keys():
self._receive_threads[remote_url].start()

while True:
if not self._receive_thread.is_alive():
if not any(x.is_alive() for x in self._receive_threads.values()):
raise RuntimeError("Receive thread died")
# message_in -> remote message_out
self.consume()
Expand All @@ -98,7 +111,7 @@ def _unpack_message(self, message: bytes) -> DAQJobMessage:

message_class = self._message_class_cache[message_type]

res: DAQJobMessage = message_class.from_json(data)
res = message_class.from_json(data)
if res.id is None:
raise Exception("Message id is not set")

Expand All @@ -107,3 +120,10 @@ def _unpack_message(self, message: bytes) -> DAQJobMessage:
self._remote_message_ids.pop()
self._logger.debug(f"Unpacked message {message_type} ({res.id})")
return res

def __del__(self):
for remote_url in self._zmq_remotes.keys():
self._zmq_remotes[remote_url].close()
self._zmq_local.close()

return super().__del__()
12 changes: 8 additions & 4 deletions src/tests/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ def setUp(self, MockZmqContext):
self.config = DAQJobRemoteConfig(
daq_job_type="remote",
zmq_local_url="tcp://localhost:5555",
zmq_remote_url="tcp://localhost:5556",
zmq_remote_urls=["tcp://localhost:5556"],
)
self.daq_job_remote = DAQJobRemote(self.config)
self.daq_job_remote._zmq_local = self.mock_sender
self.daq_job_remote._zmq_remote = self.mock_receiver
self.daq_job_remote._zmq_remotes = {"tcp://localhost:5556": self.mock_receiver}

def test_handle_message(self):
message = DAQJobMessage(
Expand All @@ -40,7 +40,9 @@ def stop_receive_thread():
time.sleep(0.1)
mock_receive_thread.is_alive.return_value = False

self.daq_job_remote._receive_thread = mock_receive_thread
self.daq_job_remote._receive_threads = {
"tcp://localhost:5556": mock_receive_thread
}
threading.Thread(target=stop_receive_thread, daemon=True).start()

with self.assertRaises(RuntimeError):
Expand Down Expand Up @@ -68,7 +70,9 @@ def side_effect():
self.mock_receiver.recv.side_effect = side_effect

with self.assertRaises(RuntimeError):
self.daq_job_remote._start_receive_thread()
self.daq_job_remote._start_receive_thread(
"tcp://localhost:5556", self.mock_receiver
)
self.daq_job_remote.message_out.put.assert_called_once_with(message)
self.assertEqual(self.daq_job_remote.message_out.put.call_count, 1)
self.assertEqual(call_count, 2)
Expand Down

0 comments on commit 73ed894

Please sign in to comment.