From d1e626a27bfd1b97fffe168b261fe9b313210df5 Mon Sep 17 00:00:00 2001 From: Furkan Date: Sat, 30 Nov 2024 20:13:32 +0300 Subject: [PATCH] feat: add `DAQJobRemoteProxy` and its tests, refactor `DAQJobRemote` (#7) --- src/enrgdaq/daq/jobs/remote.py | 42 ++++++++++------- src/enrgdaq/daq/jobs/remote_proxy.py | 69 ++++++++++++++++++++++++++++ src/tests/test_remote.py | 4 +- src/tests/test_remote_proxy.py | 51 ++++++++++++++++++++ 4 files changed, 147 insertions(+), 19 deletions(-) create mode 100644 src/enrgdaq/daq/jobs/remote_proxy.py create mode 100644 src/tests/test_remote_proxy.py diff --git a/src/enrgdaq/daq/jobs/remote.py b/src/enrgdaq/daq/jobs/remote.py index 8295d07..ba802c7 100644 --- a/src/enrgdaq/daq/jobs/remote.py +++ b/src/enrgdaq/daq/jobs/remote.py @@ -28,9 +28,10 @@ class DAQJobRemoteConfig(DAQJobConfig): topics (list[str]): List of topics to subscribe to. """ - zmq_local_url: str - zmq_remote_urls: list[str] + zmq_proxy_sub_urls: list[str] topics: list[str] = [] + use_xsub: bool = False + zmq_proxy_pub_url: Optional[str] = None class DAQJobRemote(DAQJob): @@ -46,10 +47,6 @@ class DAQJobRemote(DAQJob): config_type (type): Configuration type for the job. config (DAQJobRemoteConfig): Configuration instance. restart_offset (timedelta): Restart offset time. - _zmq_pub_ctx (zmq.Context): ZMQ context for publishing. - _zmq_sub_ctx (zmq.Context): ZMQ context for subscribing. - _zmq_pub (zmq.Socket): ZMQ socket for publishing. - _zmq_sub (Optional[zmq.Socket]): ZMQ socket for subscribing. _message_class_cache (dict): Cache for message classes. _remote_message_ids (set): Set of remote message IDs. _receive_thread (threading.Thread): Thread for receiving messages. @@ -60,11 +57,6 @@ class DAQJobRemote(DAQJob): config: DAQJobRemoteConfig restart_offset = timedelta(seconds=5) - _zmq_pub_ctx: zmq.Context - _zmq_sub_ctx: zmq.Context - - _zmq_pub: zmq.Socket - _zmq_sub: Optional[zmq.Socket] _message_class_cache: dict[str, type[DAQJobMessage]] _remote_message_ids: set[str] _receive_thread: threading.Thread @@ -72,15 +64,18 @@ class DAQJobRemote(DAQJob): def __init__(self, config: DAQJobRemoteConfig, **kwargs): super().__init__(config, **kwargs) - self._zmq_pub_ctx = zmq.Context() - self._logger.debug(f"Listening on {config.zmq_local_url}") - self._zmq_pub = self._zmq_pub_ctx.socket(zmq.PUB) - self._zmq_pub.bind(config.zmq_local_url) + if config.zmq_proxy_pub_url is not None: + self._zmq_pub_ctx = zmq.Context() + self._zmq_pub = self._zmq_pub_ctx.socket(zmq.PUB) + self._zmq_pub.connect(config.zmq_proxy_pub_url) + else: + self._zmq_pub_ctx = None + self._zmq_pub = None self._zmq_sub = None self._receive_thread = threading.Thread( target=self._start_receive_thread, - args=(config.zmq_remote_urls,), + args=(config.zmq_proxy_sub_urls,), daemon=True, ) self._message_class_cache = {} @@ -100,6 +95,8 @@ def handle_message(self, message: DAQJobMessage) -> bool: or not super().handle_message(message) # Ignore if the message is remote, meaning it was sent by another Supervisor or message.is_remote + # Ignore if we are not connected to the proxy + or self._zmq_pub is None ): return True # Silently ignore @@ -120,7 +117,7 @@ def _create_zmq_sub(self, remote_urls: list[str]) -> zmq.Socket: """ Create a ZMQ subscriber socket. - Args: + Args:g remote_urls (list[str]): List of remote URLs to connect to. Returns: @@ -157,6 +154,17 @@ def _start_receive_thread(self, remote_urls: list[str]): except zmq.ContextTerminated: break recv_message = self._unpack_message(message) + if ( + recv_message.daq_job_info is not None + and recv_message.daq_job_info.supervisor_config is not None + and self.info.supervisor_config is not None + and recv_message.daq_job_info.supervisor_config.supervisor_id + == self.info.supervisor_config.supervisor_id + ): + self._logger.warning( + f"Received own message '{type(recv_message).__name__}' on topic '{topic.decode()}', ignoring message. This should NOT happen. Check the config." + ) + continue self._logger.debug( f"Received {len(message)} bytes for message '{type(recv_message).__name__}' on topic '{topic.decode()}'" ) diff --git a/src/enrgdaq/daq/jobs/remote_proxy.py b/src/enrgdaq/daq/jobs/remote_proxy.py new file mode 100644 index 0000000..17a8ada --- /dev/null +++ b/src/enrgdaq/daq/jobs/remote_proxy.py @@ -0,0 +1,69 @@ +import zmq + +from enrgdaq.daq.base import DAQJob +from enrgdaq.daq.models import DAQJobConfig + + +class DAQJobRemoteProxyConfig(DAQJobConfig): + """ + Configuration for DAQJobRemoteProxy. + + Attributes: + zmq_xsub_url (str): ZMQ xsub URL. + zmq_xpub_url (str): ZMQ xpub URL. + """ + + zmq_xsub_url: str + zmq_xpub_url: str + + +class DAQJobRemoteProxy(DAQJob): + """ + DAQJobRemoteProxy is a DAQJob that acts as a proxy between two ZMQ sockets. + It uses zmq.proxy to forward messages between xsub and xpub. + + pub -> xsub -> xpub -> sub + + When you want to the DAQJobRemoteProxy: + - For pub, connect to xsub + - For sub, connect to xpub + + Attributes: + config_type (type): Configuration type for the job. + config (DAQJobRemoteProxyConfig): Configuration instance. + """ + + config_type = DAQJobRemoteProxyConfig + config: DAQJobRemoteProxyConfig + + def __init__(self, config: DAQJobRemoteProxyConfig, **kwargs): + super().__init__(config, **kwargs) + + self._zmq_ctx = zmq.Context() + self._xsub_sock = self._zmq_ctx.socket(zmq.XSUB) + self._xsub_sock.bind(config.zmq_xsub_url) + + self._xpub_sock = self._zmq_ctx.socket(zmq.XPUB) + self._xpub_sock.bind(config.zmq_xpub_url) + + self._logger.info( + f"Proxying between {config.zmq_xsub_url} -> {config.zmq_xpub_url}" + ) + + def start(self): + """ + Start the ZMQ proxy. + """ + try: + zmq.proxy(self._xsub_sock, self._xpub_sock) + except zmq.ContextTerminated: + pass + + def __del__(self): + """ + Destructor for DAQJobRemoteProxy. + """ + if getattr(self, "_zmq_ctx", None) is not None: + self._zmq_ctx.destroy() + + return super().__del__() diff --git a/src/tests/test_remote.py b/src/tests/test_remote.py index 9169116..9991822 100644 --- a/src/tests/test_remote.py +++ b/src/tests/test_remote.py @@ -21,8 +21,8 @@ def setUp(self, MockZmqContext): self.mock_receiver = self.mock_context.socket.return_value self.config = DAQJobRemoteConfig( daq_job_type="remote", - zmq_local_url="tcp://localhost:5555", - zmq_remote_urls=["tcp://localhost:5556"], + zmq_proxy_pub_url="tcp://localhost:5555", + zmq_proxy_sub_urls=["tcp://localhost:5556"], ) self.daq_job_remote = DAQJobRemote(self.config) self.daq_job_remote._zmq_pub = self.mock_sender diff --git a/src/tests/test_remote_proxy.py b/src/tests/test_remote_proxy.py new file mode 100644 index 0000000..1e94fc5 --- /dev/null +++ b/src/tests/test_remote_proxy.py @@ -0,0 +1,51 @@ +import unittest +from unittest.mock import patch + +import zmq + +from enrgdaq.daq.jobs.remote_proxy import DAQJobRemoteProxy, DAQJobRemoteProxyConfig + + +class TestDAQJobRemoteProxy(unittest.TestCase): + @patch("enrgdaq.daq.jobs.remote_proxy.zmq.Context") + def setUp(self, MockZmqContext): + self.mock_context = MockZmqContext.return_value + self.config = DAQJobRemoteProxyConfig( + daq_job_type="remote_proxy", + zmq_xsub_url="tcp://localhost:5557", + zmq_xpub_url="tcp://localhost:5558", + ) + self.daq_job_remote_proxy = DAQJobRemoteProxy(self.config) + + def test_initialization(self): + self.mock_context.socket.assert_any_call(zmq.XSUB) + self.mock_context.socket.assert_any_call(zmq.XPUB) + self.mock_context.socket.return_value.bind.assert_any_call( + "tcp://localhost:5557" + ) + self.mock_context.socket.return_value.bind.assert_any_call( + "tcp://localhost:5558" + ) + self.assertEqual(self.daq_job_remote_proxy.config, self.config) + self.assertEqual( + self.daq_job_remote_proxy._xsub_sock, self.mock_context.socket.return_value + ) + self.assertEqual( + self.daq_job_remote_proxy._xpub_sock, self.mock_context.socket.return_value + ) + + @patch("enrgdaq.daq.jobs.remote_proxy.zmq.proxy") + def test_start(self, mock_zmq_proxy): + self.daq_job_remote_proxy.start() + mock_zmq_proxy.assert_called_once_with( + self.daq_job_remote_proxy._xsub_sock, + self.daq_job_remote_proxy._xpub_sock, + ) + + def test_destructor(self): + del self.daq_job_remote_proxy + self.mock_context.destroy.assert_called_once() + + +if __name__ == "__main__": + unittest.main()