diff --git a/configs/examples/remote.toml b/configs/examples/remote.toml index 6980b57..7e13f32 100644 --- a/configs/examples/remote.toml +++ b/configs/examples/remote.toml @@ -1,3 +1,3 @@ daq_job_type = "remote" zmq_local_url = "tcp://localhost:10191" -zmq_remote_url = "tcp://1.2.3.4:10191" +zmq_remote_urls = ["tcp://1.2.3.4:10191"] diff --git a/src/daq/jobs/remote.py b/src/daq/jobs/remote.py index 3edff71..d029914 100644 --- a/src/daq/jobs/remote.py +++ b/src/daq/jobs/remote.py @@ -10,13 +10,13 @@ from daq.models import DAQJobConfig, DAQJobMessage from daq.store.models import DAQJobMessageStore -DAQ_JOB_REMOTE_MAX_REMOTE_MESSAGE_ID_COUNT = 1000 +DAQ_JOB_REMOTE_MAX_REMOTE_MESSAGE_ID_COUNT = 10000 @dataclass class DAQJobRemoteConfig(DAQJobConfig): zmq_local_url: str - zmq_remote_url: str + zmq_remote_urls: list[str] class DAQJobRemote(DAQJob): @@ -26,28 +26,40 @@ class DAQJobRemote(DAQJob): - message_in -> remote message_out - remote message_in -> message_out + + TODO: Use zmq CURVE security """ allowed_message_in_types = [DAQJobMessage] # accept all message types config_type = DAQJobRemoteConfig config: DAQJobRemoteConfig _zmq_local: zmq.Socket - _zmq_remote: zmq.Socket - _message_class_cache: dict + _zmq_remotes: dict[str, zmq.Socket] + _message_class_cache: dict[str, type[DAQJobMessage]] _remote_message_ids: set[str] + _receive_threads: dict[str, threading.Thread] def __init__(self, config: DAQJobRemoteConfig): super().__init__(config) self._zmq_context = zmq.Context() - self._zmq_local = self._zmq_context.socket(zmq.PUSH) - self._zmq_remote = self._zmq_context.socket(zmq.PULL) - self._zmq_local.connect(config.zmq_local_url) - self._zmq_remote.connect(config.zmq_remote_url) + self._logger.debug(f"Listening on {config.zmq_local_url}") + self._zmq_local = self._zmq_context.socket(zmq.PUB) + self._zmq_remotes = {} + self._zmq_local.bind(config.zmq_local_url) + + self._receive_threads = {} + for remote_url in config.zmq_remote_urls: + self._logger.debug(f"Connecting to {remote_url}") + zmq_remote = self._zmq_context.socket(zmq.SUB) + zmq_remote.connect(remote_url) + self._zmq_remotes[remote_url] = zmq_remote + self._receive_threads[remote_url] = threading.Thread( + target=self._start_receive_thread, + args=(remote_url, zmq_remote), + daemon=True, + ) self._message_class_cache = {} - self._receive_thread = threading.Thread( - target=self._start_receive_thread, daemon=True - ) self._message_class_cache = { x.__name__: x for x in DAQJobMessage.__subclasses__() } @@ -62,25 +74,26 @@ def handle_message(self, message: DAQJobMessage) -> bool: or message.id in self._remote_message_ids or not super().handle_message(message) ): - return False + return True # Silently ignore self._zmq_local.send(self._pack_message(message)) return True - def _start_receive_thread(self): + def _start_receive_thread(self, remote_url: str, zmq_remote: zmq.Socket): while True: - message = self._zmq_remote.recv() + message = zmq_remote.recv() self._logger.debug( - f"Received {len(message)} bytes from remote ({self.config.zmq_remote_url})" + f"Received {len(message)} bytes from remote ({remote_url})" ) # remote message_in -> message_out self.message_out.put(self._unpack_message(message)) def start(self): - self._receive_thread.start() + for remote_url in self._zmq_remotes.keys(): + self._receive_threads[remote_url].start() while True: - if not self._receive_thread.is_alive(): + if not any(x.is_alive() for x in self._receive_threads.values()): raise RuntimeError("Receive thread died") # message_in -> remote message_out self.consume() @@ -98,7 +111,7 @@ def _unpack_message(self, message: bytes) -> DAQJobMessage: message_class = self._message_class_cache[message_type] - res: DAQJobMessage = message_class.from_json(data) + res = message_class.from_json(data) if res.id is None: raise Exception("Message id is not set") @@ -107,3 +120,10 @@ def _unpack_message(self, message: bytes) -> DAQJobMessage: self._remote_message_ids.pop() self._logger.debug(f"Unpacked message {message_type} ({res.id})") return res + + def __del__(self): + for remote_url in self._zmq_remotes.keys(): + self._zmq_remotes[remote_url].close() + self._zmq_local.close() + + return super().__del__() diff --git a/src/tests/test_remote.py b/src/tests/test_remote.py index ce4f564..6aaae71 100644 --- a/src/tests/test_remote.py +++ b/src/tests/test_remote.py @@ -18,11 +18,11 @@ def setUp(self, MockZmqContext): self.config = DAQJobRemoteConfig( daq_job_type="remote", zmq_local_url="tcp://localhost:5555", - zmq_remote_url="tcp://localhost:5556", + zmq_remote_urls=["tcp://localhost:5556"], ) self.daq_job_remote = DAQJobRemote(self.config) self.daq_job_remote._zmq_local = self.mock_sender - self.daq_job_remote._zmq_remote = self.mock_receiver + self.daq_job_remote._zmq_remotes = {"tcp://localhost:5556": self.mock_receiver} def test_handle_message(self): message = DAQJobMessage( @@ -40,7 +40,9 @@ def stop_receive_thread(): time.sleep(0.1) mock_receive_thread.is_alive.return_value = False - self.daq_job_remote._receive_thread = mock_receive_thread + self.daq_job_remote._receive_threads = { + "tcp://localhost:5556": mock_receive_thread + } threading.Thread(target=stop_receive_thread, daemon=True).start() with self.assertRaises(RuntimeError): @@ -68,7 +70,9 @@ def side_effect(): self.mock_receiver.recv.side_effect = side_effect with self.assertRaises(RuntimeError): - self.daq_job_remote._start_receive_thread() + self.daq_job_remote._start_receive_thread( + "tcp://localhost:5556", self.mock_receiver + ) self.daq_job_remote.message_out.put.assert_called_once_with(message) self.assertEqual(self.daq_job_remote.message_out.put.call_count, 1) self.assertEqual(call_count, 2)