Skip to content

Commit

Permalink
feat: add DAQJobRemoteProxy and its tests, refactor DAQJobRemote (#7
Browse files Browse the repository at this point in the history
)
  • Loading branch information
furkan-bilgin committed Nov 30, 2024
1 parent 880dde3 commit d1e626a
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 19 deletions.
42 changes: 25 additions & 17 deletions src/enrgdaq/daq/jobs/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ class DAQJobRemoteConfig(DAQJobConfig):
topics (list[str]): List of topics to subscribe to.
"""

zmq_local_url: str
zmq_remote_urls: list[str]
zmq_proxy_sub_urls: list[str]
topics: list[str] = []
use_xsub: bool = False
zmq_proxy_pub_url: Optional[str] = None


class DAQJobRemote(DAQJob):
Expand All @@ -46,10 +47,6 @@ class DAQJobRemote(DAQJob):
config_type (type): Configuration type for the job.
config (DAQJobRemoteConfig): Configuration instance.
restart_offset (timedelta): Restart offset time.
_zmq_pub_ctx (zmq.Context): ZMQ context for publishing.
_zmq_sub_ctx (zmq.Context): ZMQ context for subscribing.
_zmq_pub (zmq.Socket): ZMQ socket for publishing.
_zmq_sub (Optional[zmq.Socket]): ZMQ socket for subscribing.
_message_class_cache (dict): Cache for message classes.
_remote_message_ids (set): Set of remote message IDs.
_receive_thread (threading.Thread): Thread for receiving messages.
Expand All @@ -60,27 +57,25 @@ class DAQJobRemote(DAQJob):
config: DAQJobRemoteConfig
restart_offset = timedelta(seconds=5)

_zmq_pub_ctx: zmq.Context
_zmq_sub_ctx: zmq.Context

_zmq_pub: zmq.Socket
_zmq_sub: Optional[zmq.Socket]
_message_class_cache: dict[str, type[DAQJobMessage]]
_remote_message_ids: set[str]
_receive_thread: threading.Thread

def __init__(self, config: DAQJobRemoteConfig, **kwargs):
super().__init__(config, **kwargs)

self._zmq_pub_ctx = zmq.Context()
self._logger.debug(f"Listening on {config.zmq_local_url}")
self._zmq_pub = self._zmq_pub_ctx.socket(zmq.PUB)
self._zmq_pub.bind(config.zmq_local_url)
if config.zmq_proxy_pub_url is not None:
self._zmq_pub_ctx = zmq.Context()
self._zmq_pub = self._zmq_pub_ctx.socket(zmq.PUB)
self._zmq_pub.connect(config.zmq_proxy_pub_url)
else:
self._zmq_pub_ctx = None
self._zmq_pub = None
self._zmq_sub = None

self._receive_thread = threading.Thread(
target=self._start_receive_thread,
args=(config.zmq_remote_urls,),
args=(config.zmq_proxy_sub_urls,),
daemon=True,
)
self._message_class_cache = {}
Expand All @@ -100,6 +95,8 @@ def handle_message(self, message: DAQJobMessage) -> bool:
or not super().handle_message(message)
# Ignore if the message is remote, meaning it was sent by another Supervisor
or message.is_remote
# Ignore if we are not connected to the proxy
or self._zmq_pub is None
):
return True # Silently ignore

Expand All @@ -120,7 +117,7 @@ def _create_zmq_sub(self, remote_urls: list[str]) -> zmq.Socket:
"""
Create a ZMQ subscriber socket.
Args:
Args:g
remote_urls (list[str]): List of remote URLs to connect to.
Returns:
Expand Down Expand Up @@ -157,6 +154,17 @@ def _start_receive_thread(self, remote_urls: list[str]):
except zmq.ContextTerminated:
break
recv_message = self._unpack_message(message)
if (
recv_message.daq_job_info is not None
and recv_message.daq_job_info.supervisor_config is not None
and self.info.supervisor_config is not None
and recv_message.daq_job_info.supervisor_config.supervisor_id
== self.info.supervisor_config.supervisor_id
):
self._logger.warning(
f"Received own message '{type(recv_message).__name__}' on topic '{topic.decode()}', ignoring message. This should NOT happen. Check the config."
)
continue
self._logger.debug(
f"Received {len(message)} bytes for message '{type(recv_message).__name__}' on topic '{topic.decode()}'"
)
Expand Down
69 changes: 69 additions & 0 deletions src/enrgdaq/daq/jobs/remote_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import zmq

from enrgdaq.daq.base import DAQJob
from enrgdaq.daq.models import DAQJobConfig


class DAQJobRemoteProxyConfig(DAQJobConfig):
"""
Configuration for DAQJobRemoteProxy.
Attributes:
zmq_xsub_url (str): ZMQ xsub URL.
zmq_xpub_url (str): ZMQ xpub URL.
"""

zmq_xsub_url: str
zmq_xpub_url: str


class DAQJobRemoteProxy(DAQJob):
"""
DAQJobRemoteProxy is a DAQJob that acts as a proxy between two ZMQ sockets.
It uses zmq.proxy to forward messages between xsub and xpub.
pub -> xsub -> xpub -> sub
When you want to the DAQJobRemoteProxy:
- For pub, connect to xsub
- For sub, connect to xpub
Attributes:
config_type (type): Configuration type for the job.
config (DAQJobRemoteProxyConfig): Configuration instance.
"""

config_type = DAQJobRemoteProxyConfig
config: DAQJobRemoteProxyConfig

def __init__(self, config: DAQJobRemoteProxyConfig, **kwargs):
super().__init__(config, **kwargs)

self._zmq_ctx = zmq.Context()
self._xsub_sock = self._zmq_ctx.socket(zmq.XSUB)
self._xsub_sock.bind(config.zmq_xsub_url)

self._xpub_sock = self._zmq_ctx.socket(zmq.XPUB)
self._xpub_sock.bind(config.zmq_xpub_url)

self._logger.info(
f"Proxying between {config.zmq_xsub_url} -> {config.zmq_xpub_url}"
)

def start(self):
"""
Start the ZMQ proxy.
"""
try:
zmq.proxy(self._xsub_sock, self._xpub_sock)
except zmq.ContextTerminated:
pass

def __del__(self):
"""
Destructor for DAQJobRemoteProxy.
"""
if getattr(self, "_zmq_ctx", None) is not None:
self._zmq_ctx.destroy()

return super().__del__()
4 changes: 2 additions & 2 deletions src/tests/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def setUp(self, MockZmqContext):
self.mock_receiver = self.mock_context.socket.return_value
self.config = DAQJobRemoteConfig(
daq_job_type="remote",
zmq_local_url="tcp://localhost:5555",
zmq_remote_urls=["tcp://localhost:5556"],
zmq_proxy_pub_url="tcp://localhost:5555",
zmq_proxy_sub_urls=["tcp://localhost:5556"],
)
self.daq_job_remote = DAQJobRemote(self.config)
self.daq_job_remote._zmq_pub = self.mock_sender
Expand Down
51 changes: 51 additions & 0 deletions src/tests/test_remote_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import unittest
from unittest.mock import patch

import zmq

from enrgdaq.daq.jobs.remote_proxy import DAQJobRemoteProxy, DAQJobRemoteProxyConfig


class TestDAQJobRemoteProxy(unittest.TestCase):
@patch("enrgdaq.daq.jobs.remote_proxy.zmq.Context")
def setUp(self, MockZmqContext):
self.mock_context = MockZmqContext.return_value
self.config = DAQJobRemoteProxyConfig(
daq_job_type="remote_proxy",
zmq_xsub_url="tcp://localhost:5557",
zmq_xpub_url="tcp://localhost:5558",
)
self.daq_job_remote_proxy = DAQJobRemoteProxy(self.config)

def test_initialization(self):
self.mock_context.socket.assert_any_call(zmq.XSUB)
self.mock_context.socket.assert_any_call(zmq.XPUB)
self.mock_context.socket.return_value.bind.assert_any_call(
"tcp://localhost:5557"
)
self.mock_context.socket.return_value.bind.assert_any_call(
"tcp://localhost:5558"
)
self.assertEqual(self.daq_job_remote_proxy.config, self.config)
self.assertEqual(
self.daq_job_remote_proxy._xsub_sock, self.mock_context.socket.return_value
)
self.assertEqual(
self.daq_job_remote_proxy._xpub_sock, self.mock_context.socket.return_value
)

@patch("enrgdaq.daq.jobs.remote_proxy.zmq.proxy")
def test_start(self, mock_zmq_proxy):
self.daq_job_remote_proxy.start()
mock_zmq_proxy.assert_called_once_with(
self.daq_job_remote_proxy._xsub_sock,
self.daq_job_remote_proxy._xpub_sock,
)

def test_destructor(self):
del self.daq_job_remote_proxy
self.mock_context.destroy.assert_called_once()


if __name__ == "__main__":
unittest.main()

0 comments on commit d1e626a

Please sign in to comment.