diff --git a/src/enrgdaq/daq/base.py b/src/enrgdaq/daq/base.py index e7e864f..ecb3fc6 100644 --- a/src/enrgdaq/daq/base.py +++ b/src/enrgdaq/daq/base.py @@ -71,6 +71,8 @@ def __init__( if supervisor_config is not None: self._supervisor_config = supervisor_config + else: + self._supervisor_config = None self.info = self._create_info() def consume(self, nowait=True): diff --git a/src/enrgdaq/daq/jobs/handle_stats.py b/src/enrgdaq/daq/jobs/handle_stats.py index ea7b245..dfb4b5c 100644 --- a/src/enrgdaq/daq/jobs/handle_stats.py +++ b/src/enrgdaq/daq/jobs/handle_stats.py @@ -1,7 +1,15 @@ -from datetime import datetime +from collections import defaultdict +from datetime import datetime, timedelta from typing import Dict, Optional +import msgspec + from enrgdaq.daq.base import DAQJob +from enrgdaq.daq.jobs.remote import ( + DAQJobMessageStatsRemote, + DAQJobMessageStatsRemoteDict, + SupervisorRemoteStats, +) from enrgdaq.daq.models import DAQJobMessage, DAQJobStats, DAQJobStatsRecord from enrgdaq.daq.store.models import ( DAQJobMessageStoreTabular, @@ -12,6 +20,7 @@ DAQJobStatsDict = Dict[type[DAQJob], DAQJobStats] DAQ_JOB_HANDLE_STATS_SLEEP_INTERVAL_SECONDS = 1 +DAQ_JOB_HANDLE_STATS_REMOTE_ALIVE_SECONDS = 30 class DAQJobHandleStatsConfig(StorableDAQJobConfig): @@ -34,24 +43,29 @@ class DAQJobHandleStats(DAQJob): It extracts relevant statistics from the messages and stores them. """ - allowed_message_in_types = [DAQJobMessageStats] + allowed_message_in_types = [DAQJobMessageStats, DAQJobMessageStatsRemote] config_type = DAQJobHandleStatsConfig config: DAQJobHandleStatsConfig _stats: dict[str, DAQJobStatsDict] + _remote_stats: dict[str, DAQJobMessageStatsRemoteDict] def __init__(self, config: DAQJobHandleStatsConfig, **kwargs): super().__init__(config, **kwargs) self._stats = {} + self._remote_stats = defaultdict() def start(self): while True: start_time = datetime.now() self.consume() self._save_stats() + self._save_remote_stats() sleep_for(DAQ_JOB_HANDLE_STATS_SLEEP_INTERVAL_SECONDS, start_time) - def handle_message(self, message: DAQJobMessageStats) -> bool: + def handle_message( + self, message: DAQJobMessageStats | DAQJobMessageStatsRemote + ) -> bool: if not super().handle_message(message): return False @@ -59,9 +73,14 @@ def handle_message(self, message: DAQJobMessageStats) -> bool: if not message.daq_job_info or not message.daq_job_info.supervisor_config: return True - self._stats[message.daq_job_info.supervisor_config.supervisor_id] = ( - message.stats - ) + if isinstance(message, DAQJobMessageStats): + self._stats[message.daq_job_info.supervisor_config.supervisor_id] = ( + message.stats + ) + elif isinstance(message, DAQJobMessageStatsRemote): + self._remote_stats[message.daq_job_info.supervisor_config.supervisor_id] = ( + message.stats + ) return True def _save_stats(self): @@ -109,3 +128,54 @@ def unpack_record(record: DAQJobStatsRecord): data=data_to_send, ) ) + + def _save_remote_stats(self): + keys = [ + "supervisor", + "is_alive", + "last_active", + "message_in_count", + "message_in_bytes", + "message_out_count", + "message_out_bytes", + ] + data_to_send = [] + + # Combine remote stats from all supervisors + remote_stats_combined = defaultdict(lambda: SupervisorRemoteStats()) + for _, remote_stats_dict in self._remote_stats.items(): + # For each remote stats dict, combine the values + for ( + supervisor_id, + remote_stats_dict_serialized_item, + ) in remote_stats_dict.items(): + # Convert the supervisor remote stats to a dict + remote_stats_dict_serialized = msgspec.structs.asdict( + remote_stats_dict_serialized_item + ) + for item, value in remote_stats_dict_serialized.items(): + setattr(remote_stats_combined[supervisor_id], item, value) + + for supervisor_id, remote_stats in remote_stats_combined.items(): + is_remote_alive = datetime.now() - remote_stats.last_active <= timedelta( + seconds=DAQ_JOB_HANDLE_STATS_REMOTE_ALIVE_SECONDS + ) + data_to_send.append( + [ + supervisor_id, + str(is_remote_alive).lower(), + remote_stats.last_active, + remote_stats.message_in_count, + remote_stats.message_in_bytes, + remote_stats.message_out_count, + remote_stats.message_out_bytes, + ] + ) + self._put_message_out( + DAQJobMessageStoreTabular( + store_config=self.config.store_config, + keys=keys, + data=data_to_send, + tag="remote", + ) + ) diff --git a/src/enrgdaq/daq/jobs/remote.py b/src/enrgdaq/daq/jobs/remote.py index 1b003dc..ab9ec3b 100644 --- a/src/enrgdaq/daq/jobs/remote.py +++ b/src/enrgdaq/daq/jobs/remote.py @@ -1,11 +1,12 @@ import pickle import threading -import time -from datetime import timedelta +from collections import defaultdict +from datetime import datetime, timedelta from typing import Optional import msgspec import zmq +from msgspec import Struct, field from enrgdaq.daq.base import DAQJob from enrgdaq.daq.models import ( @@ -13,8 +14,41 @@ DAQJobConfig, DAQJobMessage, ) +from enrgdaq.utils.time import sleep_for DAQ_JOB_REMOTE_MAX_REMOTE_MESSAGE_ID_COUNT = 10000 +DAQ_JOB_REMOTE_SLEEP_INTERVAL = 1 + + +class SupervisorRemoteStats(Struct): + """Statistics for a remote supervisor.""" + + message_in_count: int = 0 + message_in_bytes: int = 0 + + message_out_count: int = 0 + message_out_bytes: int = 0 + + last_active: datetime = field(default_factory=datetime.now) + + def update_message_in_stats(self, message_in_bytes: int): + self.message_in_count += 1 + self.message_in_bytes += message_in_bytes + self.last_active = datetime.now() + + def update_message_out_stats(self, message_out_bytes: int): + self.message_out_count += 1 + self.message_out_bytes += message_out_bytes + self.last_active = datetime.now() + + +class DAQJobMessageStatsRemote(DAQJobMessage): + """Message class containing remote statistics.""" + + stats: "DAQJobMessageStatsRemoteDict" + + +DAQJobMessageStatsRemoteDict = defaultdict[str, SupervisorRemoteStats] class DAQJobRemoteConfig(DAQJobConfig): @@ -59,6 +93,7 @@ class DAQJobRemote(DAQJob): _message_class_cache: dict[str, type[DAQJobMessage]] _remote_message_ids: set[str] _receive_thread: threading.Thread + _remote_stats: DAQJobMessageStatsRemoteDict def __init__(self, config: DAQJobRemoteConfig, **kwargs): super().__init__(config, **kwargs) @@ -83,6 +118,7 @@ def __init__(self, config: DAQJobRemoteConfig, **kwargs): x.__name__: x for x in DAQJobMessage.__subclasses__() } self._remote_message_ids = set() + self._remote_stats = defaultdict(lambda: SupervisorRemoteStats()) def handle_message(self, message: DAQJobMessage) -> bool: if ( @@ -101,10 +137,16 @@ def handle_message(self, message: DAQJobMessage) -> bool: return True remote_topic = message.remote_config.remote_topic or DEFAULT_REMOTE_TOPIC + remote_topic = remote_topic.encode() + packed_message = self._pack_message(message) + self._zmq_pub.send_multipart([remote_topic, packed_message]) + + # Update remote stats + if self._supervisor_config: + self._remote_stats[ + self._supervisor_config.supervisor_id + ].update_message_out_stats(len(packed_message) + len(remote_topic)) - self._zmq_pub.send_multipart( - [remote_topic.encode(), self._pack_message(message)] - ) self._logger.debug( f"Sent message '{type(message).__name__}' to topic '{remote_topic}'" ) @@ -169,6 +211,19 @@ def _start_receive_thread(self, remote_urls: list[str]): # remote message_in -> message_out self.message_out.put(recv_message) + # Update remote stats + if self._supervisor_config: + self._remote_stats[ + self._supervisor_config.supervisor_id + ].update_message_in_stats(len(message)) + if ( + recv_message.daq_job_info + and recv_message.daq_job_info.supervisor_config + ): + self._remote_stats[ + recv_message.daq_job_info.supervisor_config.supervisor_id + ].update_message_out_stats(len(message)) + def start(self): """ Start the receive thread and the DAQ job. @@ -176,11 +231,13 @@ def start(self): self._receive_thread.start() while True: + start_time = datetime.now() if not self._receive_thread.is_alive(): raise RuntimeError("Receive thread died") # message_in -> remote message_out self.consume() - time.sleep(0.1) + self._send_remote_stats_message() + sleep_for(DAQ_JOB_REMOTE_SLEEP_INTERVAL, start_time) def _pack_message(self, message: DAQJobMessage, use_pickle: bool = True) -> bytes: """ @@ -232,6 +289,9 @@ def _unpack_message(self, message: bytes) -> DAQJobMessage: self._remote_message_ids.pop() return res + def _send_remote_stats_message(self): + self._put_message_out(DAQJobMessageStatsRemote(self._remote_stats)) + def __del__(self): """ Destructor for DAQJobRemote.