Skip to content

Commit

Permalink
refactor: move get_daq_job_class to daq.types and update DAQJobHeal…
Browse files Browse the repository at this point in the history
…thcheck using that
  • Loading branch information
furkan-bilgin committed Nov 8, 2024
1 parent 6d86cea commit eaa4ef1
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 22 deletions.
19 changes: 4 additions & 15 deletions src/daq/daq_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,14 @@

from daq.base import DAQJob, DAQJobThread
from daq.models import DAQJobConfig
from daq.types import DAQ_JOB_TYPE_TO_CLASS
from utils.subclasses import all_subclasses

ALL_DAQ_JOBS = all_subclasses(DAQJob)
from daq.types import get_daq_job_class


def build_daq_job(toml_config: bytes) -> DAQJob:
generic_daq_job_config = msgspec.toml.decode(toml_config, type=DAQJobConfig)
daq_job_class = None

if generic_daq_job_config.daq_job_type in DAQ_JOB_TYPE_TO_CLASS:
daq_job_class = DAQ_JOB_TYPE_TO_CLASS[generic_daq_job_config.daq_job_type]
logging.warning(
f"DAQ job type '{generic_daq_job_config.daq_job_type}' is deprecated, please use the '{daq_job_class.__name__}' instead"
)
else:
for daq_job in ALL_DAQ_JOBS:
if daq_job.__name__ == generic_daq_job_config.daq_job_type:
daq_job_class = daq_job
daq_job_class = get_daq_job_class(
generic_daq_job_config.daq_job_type, warn_deprecated=True
)

if daq_job_class is None:
raise Exception(f"Invalid DAQ job type: {generic_daq_job_config.daq_job_type}")
Expand Down
15 changes: 8 additions & 7 deletions src/daq/jobs/healthcheck.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
from datetime import datetime, timedelta
from enum import Enum
from typing import Optional
from typing import Callable, Optional

import msgspec
from msgspec import Struct
Expand Down Expand Up @@ -64,26 +64,27 @@ class DAQJobHealthcheck(DAQJob):
_daq_job_type_to_class: dict[str, type[DAQJob]]

_healthcheck_stats: list[HealthcheckStatsItem]
_get_daq_job_class: Callable[[str], Optional[type[DAQJob]]]

def __init__(self, config: DAQJobHealthcheckConfig):
from daq.types import DAQ_JOB_TYPE_TO_CLASS
from daq.types import ALL_DAQ_JOBS, get_daq_job_class

self._daq_job_type_to_class = DAQ_JOB_TYPE_TO_CLASS
self._get_daq_job_class = get_daq_job_class
self._current_stats = {}

super().__init__(config)

self._healthcheck_stats = []

if config.enable_alerts_on_restart:
for daq_job_type, daq_job_type_class in self._daq_job_type_to_class.items():
for daq_job_type_class in ALL_DAQ_JOBS:
self._healthcheck_stats.append(
HealthcheckStatsItem(
alert_info=DAQAlertInfo(
message=f"{daq_job_type_class.__name__} crashed and got restarted!",
severity=DAQAlertSeverity.ERROR,
),
daq_job_type=daq_job_type,
daq_job_type=daq_job_type_class.__name__,
alert_if_interval_is=AlertCondition.SATISFIED,
stats_key="restart_stats",
interval="1m",
Expand All @@ -100,7 +101,7 @@ def __init__(self, config: DAQJobHealthcheckConfig):
)
if item.stats_key not in DAQJobStats.__annotations__.keys():
raise ValueError(f"Invalid stats key: {item.stats_key}")
if item.daq_job_type not in self._daq_job_type_to_class:
if self._get_daq_job_class(item.daq_job_type) is None:
raise ValueError(f"Invalid DAQ job type: {item.daq_job_type}")
if item.interval is None and item.amount is None:
raise ValueError("interval or amount must be specified")
Expand Down Expand Up @@ -131,7 +132,7 @@ def handle_checks(self):

for item in self._healthcheck_stats:
# Get the current DAQJobStats by daq_job_type of item
item_daq_job_type = self._daq_job_type_to_class[item.daq_job_type]
item_daq_job_type = self._get_daq_job_class(item.daq_job_type)
if item_daq_job_type not in self._current_stats:
continue

Expand Down
24 changes: 24 additions & 0 deletions src/daq/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import logging
from typing import Optional

from daq.alert.alert_slack import DAQJobAlertSlack
from daq.base import DAQJob
from daq.jobs.caen.n1081b import DAQJobN1081B
Expand All @@ -9,6 +12,10 @@
from daq.jobs.store.csv import DAQJobStoreCSV
from daq.jobs.store.root import DAQJobStoreROOT
from daq.jobs.test_job import DAQJobTest
from utils.subclasses import all_subclasses

ALL_DAQ_JOBS = all_subclasses(DAQJob)


DAQ_JOB_TYPE_TO_CLASS: dict[str, type[DAQJob]] = {
"n1081b": DAQJobN1081B,
Expand All @@ -22,3 +29,20 @@
"healthcheck": DAQJobHealthcheck,
"remote": DAQJobRemote,
}


def get_daq_job_class(
daq_job_type: str, warn_deprecated: bool = False
) -> Optional[type[DAQJob]]:
daq_job_class = None
if daq_job_type in DAQ_JOB_TYPE_TO_CLASS:
daq_job_class = DAQ_JOB_TYPE_TO_CLASS[daq_job_type]
if warn_deprecated:
logging.warning(
f"DAQ job type '{daq_job_type}' is deprecated, please use the '{daq_job_class.__name__}' instead"
)
else:
for daq_job in ALL_DAQ_JOBS:
if daq_job.__name__ == daq_job_type:
daq_job_class = daq_job
return daq_job_class

0 comments on commit eaa4ef1

Please sign in to comment.