diff --git a/src/daq/base.py b/src/daq/base.py index 81a17d5..00f7818 100644 --- a/src/daq/base.py +++ b/src/daq/base.py @@ -6,6 +6,9 @@ from daq.models import DAQJobMessage, DAQJobMessageStop, DAQJobStopError +daq_job_instance_id = 0 +daq_job_instance_id_lock = threading.Lock() + class DAQJob: allowed_message_in_types: list[type[DAQJobMessage]] = [] @@ -13,14 +16,22 @@ class DAQJob: config: Any message_in: Queue[DAQJobMessage] message_out: Queue[DAQJobMessage] + instance_id: int _logger: logging.Logger def __init__(self, config: Any): + global daq_job_instance_id, daq_job_instance_id_lock + + with daq_job_instance_id_lock: + self.instance_id = daq_job_instance_id + daq_job_instance_id += 1 + self._logger = logging.getLogger(f"{type(self).__name__}({self.instance_id})") + self.config = config self.message_in = Queue() self.message_out = Queue() - self._logger = logging.getLogger(type(self).__name__) + self._should_stop = False def consume(self): diff --git a/src/daq/daq_job.py b/src/daq/daq_job.py index 224b691..48f8ec2 100644 --- a/src/daq/daq_job.py +++ b/src/daq/daq_job.py @@ -47,6 +47,14 @@ def start_daq_job(daq_job: DAQJob) -> DAQJobThread: return DAQJobThread(daq_job, thread) +def restart_daq_job(daq_job: DAQJob) -> DAQJobThread: + logging.info(f"Restarting {type(daq_job).__name__}") + new_daq_job = type(daq_job)(daq_job.config) + thread = threading.Thread(target=new_daq_job.start, daemon=True) + thread.start() + return DAQJobThread(new_daq_job, thread) + + def start_daq_jobs(daq_jobs: list[DAQJob]) -> list[DAQJobThread]: threads = [] for daq_job in daq_jobs: diff --git a/src/main.py b/src/main.py index 8bc9a77..060e53e 100644 --- a/src/main.py +++ b/src/main.py @@ -6,7 +6,12 @@ from daq.alert.base import DAQJobAlert from daq.base import DAQJob, DAQJobThread -from daq.daq_job import load_daq_jobs, parse_store_config, start_daq_job, start_daq_jobs +from daq.daq_job import ( + load_daq_jobs, + parse_store_config, + restart_daq_job, + start_daq_jobs, +) from daq.jobs.handle_stats import DAQJobMessageStats, DAQJobStatsDict from daq.models import DAQJobMessage, DAQJobStats from daq.store.base import DAQJobStore @@ -31,7 +36,7 @@ def loop( # Restart jobs that have stopped for thread in dead_threads: - daq_job_threads.append(start_daq_job(thread.daq_job)) + daq_job_threads.append(restart_daq_job(thread.daq_job)) # Update restart stats get_daq_job_stats(daq_job_stats, type(thread.daq_job)).restart_stats.increase() diff --git a/src/tests/test_main.py b/src/tests/test_main.py index aa214c7..6901e32 100644 --- a/src/tests/test_main.py +++ b/src/tests/test_main.py @@ -29,7 +29,8 @@ def test_start_daq_job_threads(self, mock_load_daq_jobs, mock_start_daq_jobs): self.assertEqual(result, ["thread1", "thread2"]) @patch("main.start_daq_job") - def test_loop(self, mock_start_daq_job): + @patch("main.restart_daq_job") + def test_loop(self, mock_start_daq_job, mock_restart_daq_job): RUN_COUNT = 3 mock_thread_alive = MagicMock(name="thread_alive") @@ -57,6 +58,7 @@ def test_loop(self, mock_start_daq_job): mock_thread_alive.daq_job.message_out.put(mock_store_message) mock_start_daq_job.return_value = mock_thread_dead + mock_restart_daq_job.return_value = mock_thread_store daq_job_threads = [mock_thread_alive, mock_thread_dead, mock_thread_store] daq_job_threads: list[DAQJobThread] = daq_job_threads