Skip to content

Commit

Permalink
Proper multithreaded stats (#60)
Browse files Browse the repository at this point in the history
- Stats have their own thread for printing at a configurable cadence
- There are 1 hr trailing window and cumulative stats
- File rate is now reported as files/24h
- Stats live in a separate file from the downloader
- Stats code is (hopefully) easier to read
  • Loading branch information
mtauraso authored Sep 11, 2024
1 parent fde4c18 commit 3ee8380
Show file tree
Hide file tree
Showing 2 changed files with 397 additions and 158 deletions.
165 changes: 7 additions & 158 deletions src/fibad/download.py
Original file line number Diff line number Diff line change
@@ -1,171 +1,17 @@
import datetime
import itertools
import logging
import time
import urllib.request
from pathlib import Path
from threading import Lock, Thread
from typing import Optional, Union
from threading import Thread
from typing import Optional

from astropy.table import Table, hstack

import fibad.downloadCutout.downloadCutout as dC
from fibad.download_stats import DownloadStats

logger = logging.getLogger(__name__)


class DownloadStats:
"""Subsytem for keeping statistics on downloads:
Accumulates time spent on request, responses as well as sizes for same and number of snapshots
Can be used as a context manager for pretty printing.
"""

def __init__(self, print_interval_s=30):
self.lock = Lock()
self.stats = {
"request_duration": datetime.timedelta(), # Time from request sent to first byte from the server
"response_duration": datetime.timedelta(), # Total time spent recieving and processing a response
"request_size_bytes": 0, # Total size of all requests
"response_size_bytes": 0, # Total size of all responses
"snapshots": 0, # Number of fits snapshots downloaded
}

# Reference count active threads and whether we've started
self.active_threads = 0
self.num_threads = 0
self.data_start = None

# How often the watcher thread prints (seconds)
self.print_interval_s = print_interval_s

# Start our watcher thread to print stats to the log
self.watcher_thread = Thread(
target=self._watcher_thread, name="stats watcher thread", args=(logging.INFO,), daemon=True
)
self.watcher_thread.start()

def __enter__(self):
# Count how many threads are using stats
with self.lock:
self.active_threads += 1
self.num_threads += 1

return self.hook

def __exit__(self, exc_type, exc_value, traceback):
# Count how many threads are using stats
with self.lock:
self.active_threads -= 1

def _stat_accumulate(self, name: str, value: Union[int, datetime.timedelta]):
"""Accumulate a sum into the global stats dict
Parameters
----------
name : str
Name of the stat. Assumed to exist in the dict already.
value : Union[int, datetime.timedelta]
How much time or count to add to the stat
"""
self.stats[name] += value

def _watcher_thread(self, log_level):
# Simple polling loop to print
while self.active_threads != 0 or not self.data_start:
if self.data_start:
self._print_stats(log_level)
time.sleep(self.print_interval_s)

def _print_stats(self, log_level):
"""Print the accumulated stats including bandwidth calculated from duration and sizes
This prints out multiple lines with `\r` at the end in order to create a continuously updating
line of text during download if your terminal supports it.
If you use this class as a context manager, the end of context will output a newline, perserving
the last line of stats in your terminal
"""

def _div(num, denom, default=0.0):
return num / denom if denom != 0 else default

with self.lock:
now = datetime.datetime.now()

wall_clock_dur_s = (now - self.data_start).total_seconds() if self.data_start else 0

# This is the duration across all threads added up
total_dur_s = (self.stats["request_duration"] + self.stats["response_duration"]).total_seconds()

resp_s = self.stats["response_duration"].total_seconds()
down_rate_mb_s = _div(self.stats["response_size_bytes"] / (1024**2), resp_s)
down_rate_mb_s_overall = _div(self.stats["response_size_bytes"] / (1024**2), wall_clock_dur_s)

req_s = self.stats["request_duration"].total_seconds()
up_rate_mb_s = _div(self.stats["request_size_bytes"] / (1024**2), req_s)

snapshot_rate = _div(self.stats["snapshots"], wall_clock_dur_s)
snapshot_rate_thread = _div(self.stats["snapshots"], total_dur_s)

connnection_efficiency = _div(total_dur_s, wall_clock_dur_s * self.num_threads)

thread_avg_dur = _div(total_dur_s, self.num_threads)

stats_message = "Overall stats: "
stats_message += f"Wall-clock Duration: {wall_clock_dur_s:.2f} s, "
stats_message += f"Files: {self.stats['snapshots']}, "
stats_message += f"Download rate: {down_rate_mb_s_overall:.2f} MB/s, "
stats_message += f"File rate: {snapshot_rate:.2f} files/s, "
stats_message += f"Conn eff: {connnection_efficiency:.2f}"
logger.log(log_level, stats_message)

stats_message = f"Per Thread Averages ({self.num_threads} threads): "
stats_message += f"Duration: {thread_avg_dur:.2f} s, "
stats_message += f"Upload: {up_rate_mb_s:.2f} MB/s, "
stats_message += f"Download: {down_rate_mb_s:.2f} MB/s, "
stats_message += f"File rate: {snapshot_rate_thread:.2f} files/s, "
logger.log(log_level, stats_message)

def hook(
self,
request: urllib.request.Request,
request_start: datetime.datetime,
response_start: datetime.datetime,
response_size: int,
chunk_size: int,
):
"""This hook is called on each chunk of snapshots downloaded.
It is called immediately after the server has finished responding to the
request, so datetime.datetime.now() is the end moment of the request
Parameters
----------
request : urllib.request.Request
The request object relevant to this call
request_start : datetime.datetime
The moment the request was handed off to urllib.request.urlopen()
response_start : datetime.datetime
The moment there were bytes from the server to process
response_size : int
The size of the response from the server in bytes
chunk_size : int
The number of cutout files recieved in this request
"""
now = datetime.datetime.now()

with self.lock:
if not self.data_start:
self.data_start = request_start

self._stat_accumulate("request_duration", response_start - request_start)
self._stat_accumulate("response_duration", now - response_start)
self._stat_accumulate("request_size_bytes", len(request.data))
self._stat_accumulate("response_size_bytes", response_size)
self._stat_accumulate("snapshots", chunk_size)


class Downloader:
"""Class with primarily static methods to namespace downloader related constants and functions."""

Expand Down Expand Up @@ -209,11 +55,13 @@ def run(config):
cutout_path = Path(config.get("cutout_dir")).resolve()
logger.info(f"Downloading cutouts to {cutout_path}")

logger.info("Making a list of cutouts...")
# Make a list of rects to pass to downloadCutout
rects = Downloader.create_rects(
locations, offset=0, default=Downloader.rect_from_config(config), path=cutout_path
)

logger.info("Checking the list against currently downloaded cutouts...")
# Prune any previously downloaded rects from our list using the manifest from the previous download
rects = Downloader._prune_downloaded_rects(cutout_path, rects)

Expand All @@ -240,6 +88,7 @@ def _batched(iterable, n):
while batch := tuple(itertools.islice(iterator, n)):
yield batch

logger.info("Dividing cutouts among threads...")
thread_rects = list(_batched(rects, int(len(rects) / num_threads))) if num_threads != 1 else [rects]

# Empty dictionaries for the threads to create download manifests in
Expand All @@ -248,7 +97,7 @@ def _batched(iterable, n):
shared_thread_args = (
config["username"],
config["password"],
DownloadStats(print_interval_s=config.get("stats_print_interval", 30)),
DownloadStats(print_interval_s=config.get("stats_print_interval", 60)),
)

shared_thread_kwargs = {
Expand Down
Loading

0 comments on commit 3ee8380

Please sign in to comment.