-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add
DaqJobStoreMySQL
and tests #8
- Loading branch information
1 parent
75e7474
commit b1e3e2e
Showing
4 changed files
with
233 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import re | ||
from collections import deque | ||
from dataclasses import dataclass | ||
from datetime import datetime | ||
from typing import Any, Optional, cast | ||
|
||
import pymysql | ||
|
||
from daq.models import DAQJobConfig | ||
from daq.store.base import DAQJobStore | ||
from daq.store.models import DAQJobMessageStore, DAQJobStoreConfigMySQL | ||
|
||
DAQ_JOB_STORE_MYSQL_FLUSH_INTERVAL_SECONDS = 15 | ||
|
||
|
||
class DAQJobStoreMySQLConfig(DAQJobConfig): | ||
host: str | ||
user: str | ||
password: str | ||
database: str | ||
port: int = 3306 | ||
|
||
|
||
@dataclass | ||
class MySQLWriteQueueItem: | ||
table_name: str | ||
keys: list[str] | ||
rows: list[Any] | ||
|
||
|
||
class DAQJobStoreMySQL(DAQJobStore): | ||
config_type = DAQJobStoreMySQLConfig | ||
allowed_store_config_types = [DAQJobStoreConfigMySQL] | ||
allowed_message_in_types = [DAQJobMessageStore] | ||
|
||
_write_queue: deque[MySQLWriteQueueItem] | ||
_last_flush_date: datetime | ||
_connection: Optional[pymysql.connections.Connection] | ||
|
||
def __init__(self, config: DAQJobStoreMySQLConfig, **kwargs): | ||
super().__init__(config, **kwargs) | ||
|
||
self._write_queue = deque() | ||
self._last_flush_date = datetime.now() | ||
self._connection = None | ||
|
||
def start(self): | ||
self._connection = pymysql.connect( | ||
host=self.config.host, | ||
user=self.config.user, | ||
port=self.config.port, | ||
password=self.config.password, | ||
database=self.config.database, | ||
) | ||
super().start() | ||
|
||
def handle_message(self, message: DAQJobMessageStore) -> bool: | ||
if not super().handle_message(message): | ||
return False | ||
|
||
store_config = cast(DAQJobStoreConfigMySQL, message.store_config.mysql) | ||
|
||
# Append rows to write_queue | ||
for row in message.data: | ||
self._write_queue.append( | ||
MySQLWriteQueueItem(store_config.table_name, message.keys, row) | ||
) | ||
|
||
return True | ||
|
||
def _flush(self, force=False): | ||
assert self._connection is not None | ||
if ( | ||
datetime.now() - self._last_flush_date | ||
).total_seconds() < DAQ_JOB_STORE_MYSQL_FLUSH_INTERVAL_SECONDS and not force: | ||
return | ||
|
||
self._connection.commit() | ||
self._last_flush_date = datetime.now() | ||
|
||
def _sanitize_text(self, text: str) -> str: | ||
# replace anything but letters, numbers, and underscores with underscores | ||
return re.sub(r"[^a-zA-Z0-9_]", "_", text) | ||
|
||
def store_loop(self): | ||
assert self._connection is not None | ||
with self._connection.cursor() as cursor: | ||
while self._write_queue: | ||
item = self._write_queue.popleft() | ||
|
||
table_name = self._sanitize_text(item.table_name) | ||
keys = ",".join(self._sanitize_text(key) for key in item.keys) | ||
values = ",".join(["%s"] * len(item.keys)) | ||
query = f"INSERT INTO {table_name} ({keys}) VALUES ({values})" | ||
cursor.execute( | ||
query, | ||
tuple(item.rows), | ||
) | ||
self._flush() | ||
|
||
def __del__(self): | ||
if self._connection is not None: | ||
self._flush(force=True) | ||
if self._connection.open: | ||
self._connection.close() | ||
|
||
return super().__del__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import unittest | ||
from collections import deque | ||
from datetime import datetime, timedelta | ||
from unittest.mock import MagicMock, patch | ||
|
||
from daq.jobs.store.mysql import ( | ||
DAQ_JOB_STORE_MYSQL_FLUSH_INTERVAL_SECONDS, | ||
DAQJobStoreMySQL, | ||
DAQJobStoreMySQLConfig, | ||
MySQLWriteQueueItem, | ||
) | ||
from daq.store.models import DAQJobMessageStore, DAQJobStoreConfigMySQL | ||
|
||
|
||
class TestDAQJobStoreMySQL(unittest.TestCase): | ||
def setUp(self): | ||
self.config = DAQJobStoreMySQLConfig( | ||
daq_job_type="", | ||
host="localhost", | ||
user="user", | ||
password="password", | ||
database="test_db", | ||
) | ||
self.store = DAQJobStoreMySQL(self.config) | ||
self.store._connection = MagicMock(open=True) | ||
|
||
@patch("time.sleep", return_value=None, side_effect=StopIteration) | ||
@patch("pymysql.connect") | ||
def test_start(self, mock_connect, mock_sleep): | ||
with self.assertRaises(StopIteration): | ||
self.store.start() | ||
mock_connect.assert_called_once_with( | ||
host="localhost", | ||
user="user", | ||
port=3306, | ||
password="password", | ||
database="test_db", | ||
) | ||
self.assertIsNotNone(self.store._connection) | ||
|
||
def test_handle_message(self): | ||
message = MagicMock(spec=DAQJobMessageStore) | ||
message.store_config = MagicMock( | ||
mysql=DAQJobStoreConfigMySQL(table_name="test_table") | ||
) | ||
message.keys = ["header1", "header2"] | ||
message.data = [["row1_col1", "row1_col2"], ["row2_col1", "row2_col2"]] | ||
|
||
result = self.store.handle_message(message) | ||
|
||
self.assertTrue(result) | ||
self.assertEqual(len(self.store._write_queue), 2) | ||
self.assertEqual(self.store._write_queue[0].table_name, "test_table") | ||
self.assertEqual(self.store._write_queue[0].keys, ["header1", "header2"]) | ||
self.assertEqual(self.store._write_queue[0].rows, ["row1_col1", "row1_col2"]) | ||
|
||
def test_flush(self): | ||
mock_commit = MagicMock() | ||
self.store._connection.commit = mock_commit # type: ignore | ||
self.store._last_flush_date = datetime.now() - timedelta( | ||
seconds=DAQ_JOB_STORE_MYSQL_FLUSH_INTERVAL_SECONDS + 1 | ||
) | ||
|
||
self.store._flush() | ||
|
||
mock_commit.assert_called_once() | ||
self.assertAlmostEqual( | ||
self.store._last_flush_date, datetime.now(), delta=timedelta(seconds=1) | ||
) | ||
|
||
def test_store_loop(self): | ||
mock_cursor = MagicMock() | ||
mock_cursor.return_value.__enter__.return_value.execute = MagicMock() | ||
self.store._connection.cursor = mock_cursor # type: ignore | ||
self.store._last_flush_date = datetime.now() - timedelta(days=7) | ||
|
||
self.store._write_queue = deque( | ||
[ | ||
MySQLWriteQueueItem( | ||
"test_table", ["header1", "header2"], ["row1_col1", "row1_col2"] | ||
), | ||
MySQLWriteQueueItem( | ||
"test_table", ["header1", "header2"], ["row2_col1", "row2_col2"] | ||
), | ||
] | ||
) | ||
|
||
mock_cursor_instance = mock_cursor.return_value.__enter__.return_value | ||
|
||
self.store.store_loop() | ||
|
||
mock_cursor_instance.execute.assert_any_call( | ||
"INSERT INTO test_table (header1,header2) VALUES (%s,%s)", | ||
("row1_col1", "row1_col2"), | ||
) | ||
mock_cursor_instance.execute.assert_any_call( | ||
"INSERT INTO test_table (header1,header2) VALUES (%s,%s)", | ||
("row2_col1", "row2_col2"), | ||
) | ||
self.assertEqual(len(self.store._write_queue), 0) | ||
|
||
def test_del(self): | ||
mock_commit = MagicMock() | ||
self.store._connection.commit = mock_commit # type: ignore | ||
mock_close = MagicMock() | ||
self.store._connection.close = mock_close # type: ignore | ||
|
||
del self.store | ||
|
||
mock_commit.assert_called_once() | ||
mock_close.assert_called_once() | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |