Skip to content

Commit

Permalink
feat: add DaqJobStoreMySQL and tests #8
Browse files Browse the repository at this point in the history
  • Loading branch information
furkan-bilgin committed Nov 14, 2024
1 parent 75e7474 commit b1e3e2e
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 0 deletions.
107 changes: 107 additions & 0 deletions src/daq/jobs/store/mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import re
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional, cast

import pymysql

from daq.models import DAQJobConfig
from daq.store.base import DAQJobStore
from daq.store.models import DAQJobMessageStore, DAQJobStoreConfigMySQL

DAQ_JOB_STORE_MYSQL_FLUSH_INTERVAL_SECONDS = 15


class DAQJobStoreMySQLConfig(DAQJobConfig):
host: str
user: str
password: str
database: str
port: int = 3306


@dataclass
class MySQLWriteQueueItem:
table_name: str
keys: list[str]
rows: list[Any]


class DAQJobStoreMySQL(DAQJobStore):
config_type = DAQJobStoreMySQLConfig
allowed_store_config_types = [DAQJobStoreConfigMySQL]
allowed_message_in_types = [DAQJobMessageStore]

_write_queue: deque[MySQLWriteQueueItem]
_last_flush_date: datetime
_connection: Optional[pymysql.connections.Connection]

def __init__(self, config: DAQJobStoreMySQLConfig, **kwargs):
super().__init__(config, **kwargs)

self._write_queue = deque()
self._last_flush_date = datetime.now()
self._connection = None

def start(self):
self._connection = pymysql.connect(
host=self.config.host,
user=self.config.user,
port=self.config.port,
password=self.config.password,
database=self.config.database,
)
super().start()

def handle_message(self, message: DAQJobMessageStore) -> bool:
if not super().handle_message(message):
return False

store_config = cast(DAQJobStoreConfigMySQL, message.store_config.mysql)

# Append rows to write_queue
for row in message.data:
self._write_queue.append(
MySQLWriteQueueItem(store_config.table_name, message.keys, row)
)

return True

def _flush(self, force=False):
assert self._connection is not None
if (
datetime.now() - self._last_flush_date
).total_seconds() < DAQ_JOB_STORE_MYSQL_FLUSH_INTERVAL_SECONDS and not force:
return

self._connection.commit()
self._last_flush_date = datetime.now()

def _sanitize_text(self, text: str) -> str:
# replace anything but letters, numbers, and underscores with underscores
return re.sub(r"[^a-zA-Z0-9_]", "_", text)

def store_loop(self):
assert self._connection is not None
with self._connection.cursor() as cursor:
while self._write_queue:
item = self._write_queue.popleft()

table_name = self._sanitize_text(item.table_name)
keys = ",".join(self._sanitize_text(key) for key in item.keys)
values = ",".join(["%s"] * len(item.keys))
query = f"INSERT INTO {table_name} ({keys}) VALUES ({values})"
cursor.execute(
query,
tuple(item.rows),
)
self._flush()

def __del__(self):
if self._connection is not None:
self._flush(force=True)
if self._connection.open:
self._connection.close()

return super().__del__()
5 changes: 5 additions & 0 deletions src/daq/store/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class DAQJobStoreConfig(Struct, dict=True):

csv: "Optional[DAQJobStoreConfigCSV]" = None
root: "Optional[DAQJobStoreConfigROOT]" = None
mysql: "Optional[DAQJobStoreConfigMySQL]" = None

def has_store_config(self, store_type: Any) -> bool:
for key in dir(self):
Expand Down Expand Up @@ -63,6 +64,10 @@ class DAQJobStoreConfigCSV(DAQJobStoreConfigBase):
overwrite: Optional[bool] = None


class DAQJobStoreConfigMySQL(DAQJobStoreConfigBase):
table_name: str


class DAQJobStoreConfigROOT(DAQJobStoreConfigBase):
file_path: str
add_date: bool
6 changes: 6 additions & 0 deletions src/run_tests.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
import sys
import unittest

from tests.test_csv import TestDAQJobStoreCSV
from tests.test_handle_alerts import TestDAQJobHandleAlerts
from tests.test_handle_stats import TestDAQJobHandleStats
from tests.test_healthcheck import TestDAQJobHealthcheck
from tests.test_mysql import TestDAQJobStoreMySQL
from tests.test_n1081b import TestDAQJobN1081B
from tests.test_remote import TestDAQJobRemote
from tests.test_slack import TestDAQJobAlertSlack
Expand All @@ -23,10 +25,14 @@ def run_tests():
test_suite.addTests(loader.loadTestsFromTestCase(TestDAQJobHealthcheck))
test_suite.addTests(loader.loadTestsFromTestCase(TestDAQJobHandleAlerts))
test_suite.addTests(loader.loadTestsFromTestCase(TestDAQJobRemote))
test_suite.addTests(loader.loadTestsFromTestCase(TestDAQJobStoreMySQL))
return test_suite


if __name__ == "__main__":
# supress all logs but errors globally
logging.getLogger().setLevel(logging.ERROR)

test_suite = run_tests()
runner = unittest.TextTestRunner(verbosity=1)
result = runner.run(test_suite)
Expand Down
115 changes: 115 additions & 0 deletions src/tests/test_mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import unittest
from collections import deque
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch

from daq.jobs.store.mysql import (
DAQ_JOB_STORE_MYSQL_FLUSH_INTERVAL_SECONDS,
DAQJobStoreMySQL,
DAQJobStoreMySQLConfig,
MySQLWriteQueueItem,
)
from daq.store.models import DAQJobMessageStore, DAQJobStoreConfigMySQL


class TestDAQJobStoreMySQL(unittest.TestCase):
def setUp(self):
self.config = DAQJobStoreMySQLConfig(
daq_job_type="",
host="localhost",
user="user",
password="password",
database="test_db",
)
self.store = DAQJobStoreMySQL(self.config)
self.store._connection = MagicMock(open=True)

@patch("time.sleep", return_value=None, side_effect=StopIteration)
@patch("pymysql.connect")
def test_start(self, mock_connect, mock_sleep):
with self.assertRaises(StopIteration):
self.store.start()
mock_connect.assert_called_once_with(
host="localhost",
user="user",
port=3306,
password="password",
database="test_db",
)
self.assertIsNotNone(self.store._connection)

def test_handle_message(self):
message = MagicMock(spec=DAQJobMessageStore)
message.store_config = MagicMock(
mysql=DAQJobStoreConfigMySQL(table_name="test_table")
)
message.keys = ["header1", "header2"]
message.data = [["row1_col1", "row1_col2"], ["row2_col1", "row2_col2"]]

result = self.store.handle_message(message)

self.assertTrue(result)
self.assertEqual(len(self.store._write_queue), 2)
self.assertEqual(self.store._write_queue[0].table_name, "test_table")
self.assertEqual(self.store._write_queue[0].keys, ["header1", "header2"])
self.assertEqual(self.store._write_queue[0].rows, ["row1_col1", "row1_col2"])

def test_flush(self):
mock_commit = MagicMock()
self.store._connection.commit = mock_commit # type: ignore
self.store._last_flush_date = datetime.now() - timedelta(
seconds=DAQ_JOB_STORE_MYSQL_FLUSH_INTERVAL_SECONDS + 1
)

self.store._flush()

mock_commit.assert_called_once()
self.assertAlmostEqual(
self.store._last_flush_date, datetime.now(), delta=timedelta(seconds=1)
)

def test_store_loop(self):
mock_cursor = MagicMock()
mock_cursor.return_value.__enter__.return_value.execute = MagicMock()
self.store._connection.cursor = mock_cursor # type: ignore
self.store._last_flush_date = datetime.now() - timedelta(days=7)

self.store._write_queue = deque(
[
MySQLWriteQueueItem(
"test_table", ["header1", "header2"], ["row1_col1", "row1_col2"]
),
MySQLWriteQueueItem(
"test_table", ["header1", "header2"], ["row2_col1", "row2_col2"]
),
]
)

mock_cursor_instance = mock_cursor.return_value.__enter__.return_value

self.store.store_loop()

mock_cursor_instance.execute.assert_any_call(
"INSERT INTO test_table (header1,header2) VALUES (%s,%s)",
("row1_col1", "row1_col2"),
)
mock_cursor_instance.execute.assert_any_call(
"INSERT INTO test_table (header1,header2) VALUES (%s,%s)",
("row2_col1", "row2_col2"),
)
self.assertEqual(len(self.store._write_queue), 0)

def test_del(self):
mock_commit = MagicMock()
self.store._connection.commit = mock_commit # type: ignore
mock_close = MagicMock()
self.store._connection.close = mock_close # type: ignore

del self.store

mock_commit.assert_called_once()
mock_close.assert_called_once()


if __name__ == "__main__":
unittest.main()

0 comments on commit b1e3e2e

Please sign in to comment.