Skip to content

Commit

Permalink
feat: use pickle by default for DAQJobRemote
Browse files Browse the repository at this point in the history
- it *is* unsafe, but has less overhead and is more straightforward, also we still fallback to json so it is possible to interop with theoretical C++ ENRGDAQ programs if needed in the future
- add `store_config` to tests
  • Loading branch information
furkan-bilgin committed Oct 22, 2024
1 parent 73ed894 commit 1a5b048
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
24 changes: 17 additions & 7 deletions src/daq/jobs/remote.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import pickle
import threading
import time
from dataclasses import dataclass
Expand Down Expand Up @@ -99,22 +100,31 @@ def start(self):
self.consume()
time.sleep(0.1)

def _pack_message(self, message: DAQJobMessage) -> bytes:
def _pack_message(self, message: DAQJobMessage, use_pickle: bool = True) -> bytes:
if use_pickle:
return pickle.dumps(message)

message_type = type(message).__name__
self._logger.debug(f"Packing message {message_type} ({message.id})")
return json.dumps([message_type, message.to_json()]).encode("utf-8")

def _unpack_message(self, message: bytes) -> DAQJobMessage:
message_type, data = json.loads(message.decode("utf-8"))
if message_type not in self._message_class_cache:
raise Exception(f"Invalid message type: {message_type}")
try:
res = pickle.loads(message)
if not isinstance(res, DAQJobMessage):
raise Exception("Message is not DAQJobMessage")
message_type = type(res).__name__
except pickle.UnpicklingError:
message_type, data = json.loads(message.decode("utf-8"))
if message_type not in self._message_class_cache:
raise Exception(f"Invalid message type: {message_type}")

message_class = self._message_class_cache[message_type]

message_class = self._message_class_cache[message_type]
res = message_class.from_json(data)

res = message_class.from_json(data)
if res.id is None:
raise Exception("Message id is not set")

self._remote_message_ids.add(res.id)
if len(self._remote_message_ids) > DAQ_JOB_REMOTE_MAX_REMOTE_MESSAGE_ID_COUNT:
self._remote_message_ids.pop()
Expand Down
1 change: 1 addition & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def send_messages_to_daq_jobs(
daq_job_stats: DAQJobStatsDict,
):
for message in daq_messages:
# TODO: Make this into a generalized interface
if isinstance(message, DAQJobMessageStore) and isinstance(
message.store_config, dict
):
Expand Down
5 changes: 4 additions & 1 deletion src/tests/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import MagicMock, patch

from daq.jobs.remote import DAQJobRemote, DAQJobRemoteConfig
from daq.jobs.store.csv import DAQJobStoreConfigCSV
from daq.jobs.test_job import DAQJobTest
from daq.models import DAQJobMessage
from daq.store.models import DAQJobMessageStore
Expand Down Expand Up @@ -51,7 +52,9 @@ def stop_receive_thread():
def test_receive_thread(self):
message = DAQJobMessageStore(
id="testmsg",
store_config={},
store_config=DAQJobStoreConfigCSV(
daq_job_store_type="csv", file_path="test", add_date=True
),
data=[],
keys=[],
daq_job_info=DAQJobTest({"daq_job_type": "test"}).get_info(),
Expand Down

0 comments on commit 1a5b048

Please sign in to comment.