Skip to content

Commit

Permalink
feat: migrate from dataclasses-json to msgspec
Browse files Browse the repository at this point in the history
  • Loading branch information
furkan-bilgin committed Nov 7, 2024
1 parent 2da1b05 commit f37c417
Show file tree
Hide file tree
Showing 24 changed files with 90 additions and 153 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ requires-python = ">=3.12"
dependencies = [
"caen-libs>=1.1.0",
"coloredlogs>=15.0.1",
"dataclasses-json>=0.6.7",
"msgspec>=0.18.6",
"n1081b-sdk",
"pyzmq>=26.2.0",
"slack-webhook>=1.0.7",
Expand Down
6 changes: 1 addition & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ cffi==1.17.1 ; implementation_name == 'pypy'
cfgv==3.4.0
coloredlogs==15.0.1
cramjam==2.8.4
dataclasses-json==0.6.7
distlib==0.3.9
filelock==3.16.1
fsspec==2024.9.0
humanfriendly==10.0
identify==2.6.1
marshmallow==3.22.0
mypy-extensions==1.0.0
msgspec==0.18.6
n1081b-sdk @ git+https://github.com/ENRG-tr/N1081B-SDK-Python.git@3217fb0fdb5e5997116c687b1d15bed174fc0400
nodeenv==1.9.1
numpy==2.1.2
Expand All @@ -27,8 +25,6 @@ pyyaml==6.0.2
pyzmq==26.2.0
ruff==0.6.9
slack-webhook==1.0.7
typing-extensions==4.12.2
typing-inspect==0.9.0
uproot==5.4.1
virtualenv==20.26.6
websocket-client==1.8.0
Expand Down
3 changes: 0 additions & 3 deletions src/daq/alert/alert_slack.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from dataclasses import dataclass

from slack_webhook import Slack

from daq.alert.base import DAQAlertSeverity, DAQJobAlert, DAQJobMessageAlert
Expand All @@ -12,7 +10,6 @@
}


@dataclass
class DAQJobAlertSlackConfig(DAQJobConfig):
slack_webhook_url: str

Expand Down
7 changes: 2 additions & 5 deletions src/daq/alert/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import time
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any

from dataclasses_json import DataClassJsonMixin
from msgspec import Struct

from daq.base import DAQJob, DAQJobInfo
from daq.models import DAQJobMessage
Expand All @@ -16,13 +15,11 @@ class DAQAlertSeverity(str, Enum):
ERROR = "error"


@dataclass
class DAQAlertInfo(DataClassJsonMixin):
class DAQAlertInfo(Struct):
message: str
severity: DAQAlertSeverity


@dataclass
class DAQJobMessageAlert(DAQJobMessage):
daq_job_info: DAQJobInfo
date: datetime
Expand Down
9 changes: 9 additions & 0 deletions src/daq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,12 @@ class DAQJobInfo:
daq_job_class_name: str # has type(self).__name__
unique_id: str
instance_id: int

@staticmethod
def mock() -> "DAQJobInfo":
return DAQJobInfo(
daq_job_type="mock",
daq_job_class_name="mock",
unique_id="mock",
instance_id=0,
)
14 changes: 7 additions & 7 deletions src/daq/daq_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
import os
import threading

import tomllib
import msgspec

from daq.base import DAQJob, DAQJobThread
from daq.models import DAQJobConfig
from daq.store.models import DAQJobStoreConfig
from daq.types import DAQ_JOB_TYPE_TO_CLASS


def build_daq_job(toml_config: dict) -> DAQJob:
generic_daq_job_config = DAQJobConfig.from_dict(toml_config)
def build_daq_job(toml_config: bytes) -> DAQJob:
generic_daq_job_config = msgspec.toml.decode(toml_config, type=DAQJobConfig)

if generic_daq_job_config.daq_job_type not in DAQ_JOB_TYPE_TO_CLASS:
raise Exception(f"Invalid DAQ job type: {generic_daq_job_config.daq_job_type}")
Expand All @@ -22,7 +22,7 @@ def build_daq_job(toml_config: dict) -> DAQJob:
daq_job_config_class: DAQJobConfig = daq_job_class.config_type

# Load the config in
config = daq_job_config_class.schema().load(toml_config)
config = msgspec.toml.decode(toml_config, type=daq_job_config_class)

return daq_job_class(config)

Expand All @@ -32,9 +32,9 @@ def load_daq_jobs(job_config_dir: str) -> list[DAQJob]:
job_files = glob.glob(os.path.join(job_config_dir, "*.toml"))
for job_file in job_files:
with open(job_file, "rb") as f:
job_config = tomllib.load(f)
job_config_raw = f.read()

jobs.append(build_daq_job(job_config))
jobs.append(build_daq_job(job_config_raw))

return jobs

Expand Down Expand Up @@ -72,4 +72,4 @@ def parse_store_config(config: dict) -> DAQJobStoreConfig:
daq_job_store_type = config["daq_job_store_type"]
store_config_class = DAQ_STORE_CONFIG_TYPE_TO_CLASS[daq_job_store_type]

return store_config_class.schema().load(config)
return store_config_class(**config)
2 changes: 0 additions & 2 deletions src/daq/jobs/caen/n1081b.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time
from dataclasses import dataclass

from N1081B import N1081B
from websocket import WebSocket, create_connection
Expand All @@ -13,7 +12,6 @@
N1081B_WEBSOCKET_TIMEOUT_SECONDS = 5


@dataclass
class DAQJobN1081BConfig(StorableDAQJobConfig):
host: str
port: str
Expand Down
2 changes: 0 additions & 2 deletions src/daq/jobs/handle_alerts.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import time
from dataclasses import dataclass

from daq.alert.base import DAQJobMessageAlert
from daq.base import DAQJob
from daq.store.models import DAQJobMessageStore, StorableDAQJobConfig
from utils.time import get_unix_timestamp_ms


@dataclass
class DAQJobHandleAlertsConfig(StorableDAQJobConfig):
pass

Expand Down
3 changes: 0 additions & 3 deletions src/daq/jobs/handle_stats.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Optional

Expand All @@ -11,12 +10,10 @@
DAQJobStatsDict = Dict[type[DAQJob], DAQJobStats]


@dataclass
class DAQJobHandleStatsConfig(StorableDAQJobConfig):
pass


@dataclass
class DAQJobMessageStats(DAQJobMessage):
stats: DAQJobStatsDict

Expand Down
11 changes: 4 additions & 7 deletions src/daq/jobs/healthcheck.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import time
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Optional

from dataclasses_json import DataClassJsonMixin
import msgspec
from msgspec import Struct

from daq.alert.base import DAQAlertInfo, DAQAlertSeverity, DAQJobMessageAlert
from daq.base import DAQJob
Expand All @@ -17,12 +17,10 @@ class AlertCondition(str, Enum):
UNSATISFIED = "unsatisfied"


@dataclass
class HealthcheckItem(DataClassJsonMixin):
class HealthcheckItem(Struct):
alert_info: DAQAlertInfo


@dataclass
class HealthcheckStatsItem(HealthcheckItem):
daq_job_type: str
stats_key: str
Expand Down Expand Up @@ -50,7 +48,6 @@ def parse_interval(self) -> timedelta:
raise ValueError(f"Invalid interval unit: {unit}")


@dataclass
class DAQJobHealthcheckConfig(DAQJobConfig):
healthcheck_stats: list[HealthcheckStatsItem]
enable_alerts_on_restart: bool = True
Expand Down Expand Up @@ -157,7 +154,7 @@ def handle_checks(self):

# Alert if it's new
for item, should_alert in res:
item_id = hash(item.to_json())
item_id = hash(msgspec.json.encode(item))
if should_alert and item_id not in self._sent_alert_items:
self._sent_alert_items.add(item_id)
self.send_alert(item)
Expand Down
12 changes: 5 additions & 7 deletions src/daq/jobs/remote.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import json
import pickle
import threading
import time
from dataclasses import dataclass

import msgspec
import zmq

from daq.base import DAQJob
Expand All @@ -13,7 +12,6 @@
DAQ_JOB_REMOTE_MAX_REMOTE_MESSAGE_ID_COUNT = 10000


@dataclass
class DAQJobRemoteConfig(DAQJobConfig):
zmq_local_url: str
zmq_remote_urls: list[str]
Expand Down Expand Up @@ -107,22 +105,22 @@ def _pack_message(self, message: DAQJobMessage, use_pickle: bool = True) -> byte
if use_pickle:
return pickle.dumps(message, protocol=pickle.HIGHEST_PROTOCOL)

return json.dumps([message_type, message.to_json()]).encode("utf-8")
return msgspec.msgpack.encode([message_type, message])

def _unpack_message(self, message: bytes) -> DAQJobMessage:
# TODO: fix unpack without pickle
try:
res = pickle.loads(message)
if not isinstance(res, DAQJobMessage):
raise Exception("Message is not DAQJobMessage")
message_type = type(res).__name__
except pickle.UnpicklingError:
message_type, data = json.loads(message.decode("utf-8"))
message_type, data = msgspec.msgpack.decode(message)
if message_type not in self._message_class_cache:
raise Exception(f"Invalid message type: {message_type}")

message_class = self._message_class_cache[message_type]

res = message_class.from_json(data)
res = message_class(**data)

if res.id is None:
raise Exception("Message id is not set")
Expand Down
2 changes: 0 additions & 2 deletions src/daq/jobs/serve_http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import http.server
import threading
from dataclasses import dataclass
from datetime import datetime, timedelta

from daq.base import DAQJob
Expand All @@ -13,7 +12,6 @@
from socketserver import ThreadingMixIn as ForkingMixIn


@dataclass
class DAQJobServeHTTPConfig(DAQJobConfig):
serve_path: str
host: str
Expand Down
2 changes: 0 additions & 2 deletions src/daq/jobs/store/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
DAQ_JOB_STORE_CSV_WRITE_BATCH_SIZE = 1000


@dataclass
class DAQJobStoreConfigCSV(DAQJobStoreConfig):
file_path: str
add_date: bool
overwrite: Optional[bool] = None


@dataclass
class DAQJobStoreCSVConfig(DAQJobConfig):
out_dir: str = "out/"

Expand Down
3 changes: 0 additions & 3 deletions src/daq/jobs/store/root.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from dataclasses import dataclass
from typing import Any, cast

import uproot
Expand All @@ -10,13 +9,11 @@
from utils.file import modify_file_path


@dataclass
class DAQJobStoreConfigROOT(DAQJobStoreConfig):
file_path: str
add_date: bool


@dataclass
class DAQJobStoreROOTConfig(DAQJobConfig):
pass

Expand Down
2 changes: 0 additions & 2 deletions src/daq/jobs/test_job.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time
from dataclasses import dataclass
from random import randint

from N1081B import N1081B
Expand All @@ -8,7 +7,6 @@
from daq.store.models import DAQJobMessageStore, StorableDAQJobConfig


@dataclass
class DAQJobTestConfig(StorableDAQJobConfig):
rand_min: int
rand_max: int
Expand Down
Loading

0 comments on commit f37c417

Please sign in to comment.