feat: combine remote stats
furkan-bilgin committed Dec 4, 2024
1 parent ae6a449 commit b1ee017
Showing 3 changed files with 144 additions and 12 deletions.
2 changes: 2 additions & 0 deletions src/enrgdaq/daq/base.py
@@ -71,6 +71,8 @@ def __init__(

if supervisor_config is not None:
self._supervisor_config = supervisor_config
else:
self._supervisor_config = None
self.info = self._create_info()

def consume(self, nowait=True):
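The two added lines guarantee that `_supervisor_config` is always defined on a `DAQJob`, so later code can simply test its truthiness instead of risking an `AttributeError`. A minimal sketch of the pattern, assuming a `SupervisorConfig` with a `supervisor_id` field (the class body here is illustrative, not copied from the repository):

    # Sketch: keep the attribute defined even when no supervisor is attached,
    # so downstream guards like `if self._supervisor_config:` cannot raise.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SupervisorConfig:
        supervisor_id: str

    class Job:
        def __init__(self, supervisor_config: Optional[SupervisorConfig] = None):
            if supervisor_config is not None:
                self._supervisor_config = supervisor_config
            else:
                self._supervisor_config = None

        def supervisor_id(self) -> Optional[str]:
            # Safe: the attribute always exists after __init__.
            if self._supervisor_config:
                return self._supervisor_config.supervisor_id
            return None

    assert Job().supervisor_id() is None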
82 changes: 76 additions & 6 deletions src/enrgdaq/daq/jobs/handle_stats.py
@@ -1,7 +1,15 @@
from datetime import datetime
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, Optional

import msgspec

from enrgdaq.daq.base import DAQJob
from enrgdaq.daq.jobs.remote import (
DAQJobMessageStatsRemote,
DAQJobMessageStatsRemoteDict,
SupervisorRemoteStats,
)
from enrgdaq.daq.models import DAQJobMessage, DAQJobStats, DAQJobStatsRecord
from enrgdaq.daq.store.models import (
DAQJobMessageStoreTabular,
@@ -12,6 +20,7 @@
DAQJobStatsDict = Dict[type[DAQJob], DAQJobStats]

DAQ_JOB_HANDLE_STATS_SLEEP_INTERVAL_SECONDS = 1
DAQ_JOB_HANDLE_STATS_REMOTE_ALIVE_SECONDS = 30


class DAQJobHandleStatsConfig(StorableDAQJobConfig):
@@ -34,34 +43,44 @@ class DAQJobHandleStats(DAQJob):
It extracts relevant statistics from the messages and stores them.
"""

allowed_message_in_types = [DAQJobMessageStats]
allowed_message_in_types = [DAQJobMessageStats, DAQJobMessageStatsRemote]
config_type = DAQJobHandleStatsConfig
config: DAQJobHandleStatsConfig

_stats: dict[str, DAQJobStatsDict]
_remote_stats: dict[str, DAQJobMessageStatsRemoteDict]

def __init__(self, config: DAQJobHandleStatsConfig, **kwargs):
super().__init__(config, **kwargs)
self._stats = {}
self._remote_stats = defaultdict()

def start(self):
while True:
start_time = datetime.now()
self.consume()
self._save_stats()
self._save_remote_stats()
sleep_for(DAQ_JOB_HANDLE_STATS_SLEEP_INTERVAL_SECONDS, start_time)

def handle_message(self, message: DAQJobMessageStats) -> bool:
def handle_message(
self, message: DAQJobMessageStats | DAQJobMessageStatsRemote
) -> bool:
if not super().handle_message(message):
return False

# Ignore if the message has no supervisor info
if not message.daq_job_info or not message.daq_job_info.supervisor_config:
return True

self._stats[message.daq_job_info.supervisor_config.supervisor_id] = (
message.stats
)
if isinstance(message, DAQJobMessageStats):
self._stats[message.daq_job_info.supervisor_config.supervisor_id] = (
message.stats
)
elif isinstance(message, DAQJobMessageStatsRemote):
self._remote_stats[message.daq_job_info.supervisor_config.supervisor_id] = (
message.stats
)
return True

def _save_stats(self):
@@ -109,3 +128,54 @@ def unpack_record(record: DAQJobStatsRecord):
data=data_to_send,
)
)

def _save_remote_stats(self):
keys = [
"supervisor",
"is_alive",
"last_active",
"message_in_count",
"message_in_bytes",
"message_out_count",
"message_out_bytes",
]
data_to_send = []

# Combine remote stats from all supervisors
remote_stats_combined = defaultdict(lambda: SupervisorRemoteStats())
for _, remote_stats_dict in self._remote_stats.items():
# For each remote stats dict, combine the values
for (
supervisor_id,
remote_stats_dict_serialized_item,
) in remote_stats_dict.items():
# Convert the supervisor remote stats to a dict
remote_stats_dict_serialized = msgspec.structs.asdict(
remote_stats_dict_serialized_item
)
for item, value in remote_stats_dict_serialized.items():
setattr(remote_stats_combined[supervisor_id], item, value)

for supervisor_id, remote_stats in remote_stats_combined.items():
is_remote_alive = datetime.now() - remote_stats.last_active <= timedelta(
seconds=DAQ_JOB_HANDLE_STATS_REMOTE_ALIVE_SECONDS
)
data_to_send.append(
[
supervisor_id,
str(is_remote_alive).lower(),
remote_stats.last_active,
remote_stats.message_in_count,
remote_stats.message_in_bytes,
remote_stats.message_out_count,
remote_stats.message_out_bytes,
]
)
self._put_message_out(
DAQJobMessageStoreTabular(
store_config=self.config.store_config,
keys=keys,
data=data_to_send,
tag="remote",
)
)
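For reference, a minimal, self-contained sketch of the merge that `_save_remote_stats` performs, reusing the `SupervisorRemoteStats` struct added in `remote.py` below; the report data is invented. Note that the `setattr` loop copies whole field values, so the last report seen for a given supervisor wins rather than being summed:

    from collections import defaultdict
    from datetime import datetime, timedelta

    import msgspec


    class SupervisorRemoteStats(msgspec.Struct):
        message_in_count: int = 0
        message_in_bytes: int = 0
        message_out_count: int = 0
        message_out_bytes: int = 0
        last_active: datetime = msgspec.field(default_factory=datetime.now)


    # Each reporting supervisor sends {observed_supervisor_id: stats}.
    reports = {
        "supervisor_a": {"supervisor_b": SupervisorRemoteStats(message_in_count=3)},
        "supervisor_c": {"supervisor_b": SupervisorRemoteStats(message_out_count=5)},
    }

    combined = defaultdict(SupervisorRemoteStats)
    for report in reports.values():
        for supervisor_id, stats in report.items():
            # Copy every serialized field onto the combined entry.
            for name, value in msgspec.structs.asdict(stats).items():
                setattr(combined[supervisor_id], name, value)

    # Liveness check used when building the "remote" table.
    alive = datetime.now() - combined["supervisor_b"].last_active <= timedelta(seconds=30)
    print(msgspec.structs.asdict(combined["supervisor_b"]), alive)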
72 changes: 66 additions & 6 deletions src/enrgdaq/daq/jobs/remote.py
@@ -1,20 +1,54 @@
import pickle
import threading
import time
from datetime import timedelta
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Optional

import msgspec
import zmq
from msgspec import Struct, field

from enrgdaq.daq.base import DAQJob
from enrgdaq.daq.models import (
DEFAULT_REMOTE_TOPIC,
DAQJobConfig,
DAQJobMessage,
)
from enrgdaq.utils.time import sleep_for

DAQ_JOB_REMOTE_MAX_REMOTE_MESSAGE_ID_COUNT = 10000
DAQ_JOB_REMOTE_SLEEP_INTERVAL = 1


class SupervisorRemoteStats(Struct):
"""Statistics for a remote supervisor."""

message_in_count: int = 0
message_in_bytes: int = 0

message_out_count: int = 0
message_out_bytes: int = 0

last_active: datetime = field(default_factory=datetime.now)

def update_message_in_stats(self, message_in_bytes: int):
self.message_in_count += 1
self.message_in_bytes += message_in_bytes
self.last_active = datetime.now()

def update_message_out_stats(self, message_out_bytes: int):
self.message_out_count += 1
self.message_out_bytes += message_out_bytes
self.last_active = datetime.now()


class DAQJobMessageStatsRemote(DAQJobMessage):
"""Message class containing remote statistics."""

stats: "DAQJobMessageStatsRemoteDict"


DAQJobMessageStatsRemoteDict = defaultdict[str, SupervisorRemoteStats]


class DAQJobRemoteConfig(DAQJobConfig):
@@ -59,6 +93,7 @@ class DAQJobRemote(DAQJob):
_message_class_cache: dict[str, type[DAQJobMessage]]
_remote_message_ids: set[str]
_receive_thread: threading.Thread
_remote_stats: DAQJobMessageStatsRemoteDict

def __init__(self, config: DAQJobRemoteConfig, **kwargs):
super().__init__(config, **kwargs)
@@ -83,6 +118,7 @@ def __init__(self, config: DAQJobRemoteConfig, **kwargs):
x.__name__: x for x in DAQJobMessage.__subclasses__()
}
self._remote_message_ids = set()
self._remote_stats = defaultdict(lambda: SupervisorRemoteStats())

def handle_message(self, message: DAQJobMessage) -> bool:
if (
@@ -101,10 +137,16 @@ def handle_message(self, message: DAQJobMessage) -> bool:
return True

remote_topic = message.remote_config.remote_topic or DEFAULT_REMOTE_TOPIC
remote_topic = remote_topic.encode()
packed_message = self._pack_message(message)
self._zmq_pub.send_multipart([remote_topic, packed_message])

# Update remote stats
if self._supervisor_config:
self._remote_stats[
self._supervisor_config.supervisor_id
].update_message_out_stats(len(packed_message) + len(remote_topic))

self._zmq_pub.send_multipart(
[remote_topic.encode(), self._pack_message(message)]
)
self._logger.debug(
f"Sent message '{type(message).__name__}' to topic '{remote_topic}'"
)
@@ -169,18 +211,33 @@ def _start_receive_thread(self, remote_urls: list[str]):
# remote message_in -> message_out
self.message_out.put(recv_message)

# Update remote stats
if self._supervisor_config:
self._remote_stats[
self._supervisor_config.supervisor_id
].update_message_in_stats(len(message))
if (
recv_message.daq_job_info
and recv_message.daq_job_info.supervisor_config
):
self._remote_stats[
recv_message.daq_job_info.supervisor_config.supervisor_id
].update_message_out_stats(len(message))

def start(self):
"""
Start the receive thread and the DAQ job.
"""
self._receive_thread.start()

while True:
start_time = datetime.now()
if not self._receive_thread.is_alive():
raise RuntimeError("Receive thread died")
# message_in -> remote message_out
self.consume()
time.sleep(0.1)
self._send_remote_stats_message()
sleep_for(DAQ_JOB_REMOTE_SLEEP_INTERVAL, start_time)

def _pack_message(self, message: DAQJobMessage, use_pickle: bool = True) -> bytes:
"""
@@ -232,6 +289,9 @@ def _unpack_message(self, message: bytes) -> DAQJobMessage:
self._remote_message_ids.pop()
return res

def _send_remote_stats_message(self):
self._put_message_out(DAQJobMessageStatsRemote(self._remote_stats))

def __del__(self):
"""
Destructor for DAQJobRemote.
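A condensed sketch of the accounting `DAQJobRemote` now keeps, with invented supervisor IDs and plain dicts standing in for `SupervisorRemoteStats`: an outgoing publish is credited to this supervisor, while a received message counts as `message_in` for this supervisor and as `message_out` for the supervisor that produced it, mirroring `handle_message` and `_start_receive_thread` above:

    from collections import defaultdict

    remote_stats = defaultdict(lambda: {"in_count": 0, "in_bytes": 0,
                                        "out_count": 0, "out_bytes": 0})

    def on_publish(own_id: str, topic: bytes, payload: bytes) -> None:
        # Byte count includes the topic frame, as in handle_message().
        stats = remote_stats[own_id]
        stats["out_count"] += 1
        stats["out_bytes"] += len(payload) + len(topic)

    def on_receive(own_id: str, sender_id: str, raw: bytes) -> None:
        remote_stats[own_id]["in_count"] += 1
        remote_stats[own_id]["in_bytes"] += len(raw)
        remote_stats[sender_id]["out_count"] += 1
        remote_stats[sender_id]["out_bytes"] += len(raw)

    on_publish("supervisor_a", b"DAQ", b"x" * 64)
    on_receive("supervisor_a", "supervisor_b", b"y" * 32)
    print(dict(remote_stats))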
