Skip to content

Commit

Permalink
feat: add pyzmq, DAQJobRemote and tests for it
Browse files Browse the repository at this point in the history
- `DAQJobRemote` is a DAQJob that connects two seperate ENRGDAQ instances
  • Loading branch information
furkan-bilgin committed Oct 19, 2024
1 parent e9700c1 commit 8d82dc9
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 2 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies = [
"coloredlogs>=15.0.1",
"dataclasses-json>=0.6.7",
"n1081b-sdk",
"pyzmq>=26.2.0",
"slack-webhook>=1.0.7",
"uproot>=5.4.1",
]
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# uv export --no-hashes --output-file=requirements.txt
awkward==2.6.9
awkward-cpp==39
cffi==1.17.1 ; implementation_name == 'pypy'
cfgv==3.4.0
coloredlogs==15.0.1
cramjam==2.8.4
Expand All @@ -19,8 +20,10 @@ numpy==2.1.2
packaging==24.1
platformdirs==4.3.6
pre-commit==4.0.1
pycparser==2.22 ; implementation_name == 'pypy'
pyreadline3==3.5.4 ; sys_platform == 'win32'
pyyaml==6.0.2
pyzmq==26.2.0
ruff==0.6.9
slack-webhook==1.0.7
typing-extensions==4.12.2
Expand Down
60 changes: 60 additions & 0 deletions src/daq/jobs/remote.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import pickle
import threading
import time
from dataclasses import dataclass

import zmq

from daq.base import DAQJob
from daq.models import DAQJobConfig, DAQJobMessage


@dataclass
class DAQJobRemoteConfig(DAQJobConfig):
zmq_sender_url: str
zmq_receiver_url: str


class DAQJobRemote(DAQJob):
"""
DAQJobRemote is a DAQJob that connects two seperate ENRGDAQ instances.
It sends to and receives from a remote ENRGDAQ, in such that:
- message_in -> remote message_out
- remote message_in -> message_out
"""

allowed_message_in_types = [DAQJobMessage] # accept all message types
config = DAQJobRemoteConfig

def __init__(self, config: DAQJobRemoteConfig):
super().__init__(config)
self._zmq_context = zmq.Context()
self._zmq_sender = self._zmq_context.socket(zmq.PUSH)
self._zmq_receiver = self._zmq_context.socket(zmq.PULL)
self._zmq_sender.connect(config.zmq_sender_url)
self._zmq_receiver.connect(config.zmq_receiver_url)

self._receive_thread = threading.Thread(
target=self._start_receive_thread, daemon=True
)

def handle_message(self, message: DAQJobMessage) -> bool:
self._zmq_sender.send(pickle.dumps(message))
return True

def _start_receive_thread(self):
while True:
message = self._zmq_receiver.recv()
# remote message_in -> message_out
self.message_out.put(pickle.loads(message))

def start(self):
self._receive_thread.start()

while True:
if not self._receive_thread.is_alive():
raise RuntimeError("receive thread died")
# message_in -> remote message_out
self.consume()
time.sleep(0.1)
13 changes: 11 additions & 2 deletions src/daq/store/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any
from typing import Any, Optional

from dataclasses_json import DataClassJsonMixin

Expand All @@ -19,11 +19,20 @@ class DAQJobStoreConfig(DataClassJsonMixin):
@dataclass
class DAQJobMessageStore(DAQJobMessage):
store_config: dict | DAQJobStoreConfig
daq_job: DAQJob
daq_job: Optional[DAQJob]
keys: list[str]
data: list[list[Any]]
prefix: str | None = None

def __getstate__(self):
state = self.__dict__.copy()
del state["daq_job"]
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.daq_job = None # type: ignore


@dataclass
class StorableDAQJobConfig(DAQJobConfig):
Expand Down
2 changes: 2 additions & 0 deletions src/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tests.test_healthcheck import TestDAQJobHealthcheck
from tests.test_main import TestMain
from tests.test_n1081b import TestDAQJobN1081B
from tests.test_remote import TestDAQJobRemote
from tests.test_slack import TestDAQJobAlertSlack


Expand All @@ -21,6 +22,7 @@ def run_tests():
test_suite.addTests(loader.loadTestsFromTestCase(TestDAQJobAlertSlack))
test_suite.addTests(loader.loadTestsFromTestCase(TestDAQJobHealthcheck))
test_suite.addTests(loader.loadTestsFromTestCase(TestDAQJobHandleAlerts))
test_suite.addTests(loader.loadTestsFromTestCase(TestDAQJobRemote))
return test_suite


Expand Down
70 changes: 70 additions & 0 deletions src/tests/test_remote.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pickle
import threading
import time
import unittest
from unittest.mock import MagicMock, patch

from daq.jobs.remote import DAQJobRemote, DAQJobRemoteConfig
from daq.jobs.test_job import DAQJobTest
from daq.store.models import DAQJobMessageStore


class TestDAQJobRemote(unittest.TestCase):
@patch("daq.jobs.remote.zmq.Context")
def setUp(self, MockZmqContext):
self.mock_context = MockZmqContext.return_value
self.mock_sender = self.mock_context.socket.return_value
self.mock_receiver = self.mock_context.socket.return_value
self.config = DAQJobRemoteConfig(
daq_job_type="remote",
zmq_sender_url="tcp://localhost:5555",
zmq_receiver_url="tcp://localhost:5556",
)
self.daq_job_remote = DAQJobRemote(self.config)
self.daq_job_remote._zmq_sender = self.mock_sender
self.daq_job_remote._zmq_receiver = self.mock_receiver

def test_handle_message(self):
message = DAQJobMessageStore(
store_config={}, data=[], keys=[], daq_job=DAQJobTest({})
)
self.daq_job_remote.handle_message(message)
self.mock_sender.send.assert_called_once_with(pickle.dumps(message))

def test_start(self):
mock_receive_thread = MagicMock()

def stop_receive_thread():
time.sleep(0.1)
mock_receive_thread.is_alive.return_value = False

self.daq_job_remote._receive_thread = mock_receive_thread
threading.Thread(target=stop_receive_thread, daemon=True).start()

with self.assertRaises(RuntimeError):
self.daq_job_remote.start()

def test_receive_thread(self):
message = DAQJobMessageStore(store_config={}, data=[], keys=[], daq_job=None) # type: ignore
self.daq_job_remote.message_out = MagicMock()

call_count = 0

def side_effect():
nonlocal call_count
call_count += 1
if call_count >= 2:
raise Exception("Stop receive thread")
return pickle.dumps(message)

self.mock_receiver.recv.side_effect = side_effect

with self.assertRaises(Exception):
self.daq_job_remote._start_receive_thread()
self.daq_job_remote.message_out.put.assert_called_once_with(message)
self.assertEqual(self.daq_job_remote.message_out.put.call_count, 1)
self.assertEqual(call_count, 2)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 8d82dc9

Please sign in to comment.