Skip to content

Commit

Permalink
Gracefully migrate transactions
Browse files Browse the repository at this point in the history
  • Loading branch information
dogversioning committed Jul 10, 2024
1 parent ba64146 commit e4cb339
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 1 deletion.
4 changes: 4 additions & 0 deletions cumulus_library/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""Package metadata"""

from .base_utils import StudyConfig
from .study_manifest import StudyManifest

# Public API of the package: the two names imported above are re-exported here.
__all__ = ["StudyConfig", "StudyManifest"]
# Version string for this package.
__version__ = "3.0.0"
36 changes: 35 additions & 1 deletion cumulus_library/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from cumulus_library import (
__version__,
base_utils,
databases,
enums,
errors,
study_manifest,
Expand Down Expand Up @@ -86,4 +87,37 @@ def _log_table(
dataset=dataset,
type_casts=table.type_casts,
)
config.db.cursor().execute(query)
cursor = config.db.cursor()
try:
cursor.execute(query)
except Exception as e:
# Migrating logging tables
if "lib_transactions" in table_name:
cols = cursor.execute(
"SELECT column_name FROM information_schema.columns "
f"WHERE table_name ='{table_name}' "
f"AND table_schema ='{db_schema}'"
).fetchall()
cols = [col[0] for col in cols]
# Table schema pre-v3 library release
if sorted(cols) == [
"event_time",
"library_version",
"status",
"study_name",
]:
alter_query = ""
if isinstance(config.db, databases.AthenaDatabaseBackend):
alter_query = (
f"ALTER TABLE {db_schema}.{table_name} "
"ADD COLUMNS(message string)"
)
elif isinstance(config.db, databases.DuckDatabaseBackend):
alter_query = (
f"ALTER TABLE {db_schema}.{table_name} "
"ADD COLUMN message varchar"
)
cursor.execute(alter_query)
cursor.execute(query)
else:
raise e
77 changes: 77 additions & 0 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
from contextlib import nullcontext as does_not_raise
from datetime import datetime
from unittest import mock

import pytest
from freezegun import freeze_time

from cumulus_library import (
__version__,
base_utils,
databases,
enums,
errors,
log_utils,
Expand Down Expand Up @@ -88,6 +91,80 @@ def test_transactions(mock_db, schema, study, status, message, expects, raises):
assert log == expects


def test_migrate_transactions_duckdb(mock_db_config):
    """Logging to a four-column transactions table should add `message`.

    The table is first created without a `message` column; after a call to
    `log_utils.log_transaction`, the table is expected to expose five columns,
    one of which is `message`.
    """
    legacy_columns = ["study_name", "library_version", "status", "event_time"]
    legacy_types = ["varchar", "varchar", "varchar", "timestamp"]
    # Seed the database with the older table layout (no `message` column).
    create_legacy_table = base_templates.get_ctas_empty_query(
        "main",
        "study_valid__lib_transactions",
        legacy_columns,
        legacy_types,
    )
    mock_db_config.db.cursor().execute(create_legacy_table)
    manifest = study_manifest.StudyManifest("./tests/test_data/study_valid/")

    # The logging call itself must not surface an error to the caller.
    with does_not_raise():
        log_utils.log_transaction(
            config=mock_db_config,
            manifest=manifest,
            status="debug",
            message="message",
        )

    # Inspect the resulting table layout through the information schema.
    column_lookup = (
        "SELECT column_name FROM information_schema.columns "
        "WHERE table_name ='study_valid__lib_transactions' "
        "AND table_schema ='main'"
    )
    found_columns = mock_db_config.db.cursor().execute(column_lookup).fetchall()
    assert len(found_columns) == 5
    assert ("message",) in found_columns


@mock.patch.dict(
    os.environ,
    clear=True,
)
@mock.patch("pyathena.connect")
def test_migrate_transactions_athena(mock_pyathena):
    """An Athena backend should issue the `ADD COLUMNS` form of ALTER TABLE.

    `pyathena.connect` is patched and the environment is cleared, so the
    cursor is a mock whose `execute` call is scripted: the first call raises,
    the second returns an object whose `fetchall` yields the four-column
    layout, and the remaining calls return `None`.
    """
    legacy_rows = [
        ("event_time",),
        ("library_version",),
        ("status",),
        ("study_name",),
    ]
    migrated_rows = [*legacy_rows, ("message",)]
    column_lookup_result = mock.MagicMock()
    column_lookup_result.fetchall.side_effect = [legacy_rows, migrated_rows]

    # Script successive `cursor.execute` calls, in order.
    mocked_execute = mock_pyathena.return_value.cursor.return_value.execute
    mocked_execute.side_effect = [
        Exception,
        column_lookup_result,
        None,
        None,
    ]

    db = databases.AthenaDatabaseBackend(
        region="test",
        work_group="test",
        profile="test",
        schema_name="test",
    )
    manifest = study_manifest.StudyManifest("./tests/test_data/study_valid/")
    config = base_utils.StudyConfig(schema="test", db=db)
    log_utils.log_transaction(
        config=config,
        manifest=manifest,
        status="debug",
        message="message",
    )

    expected = (
        "ALTER TABLE test.study_valid__lib_transactions ADD COLUMNS(message string)"
    )
    # The third `execute` call (index 2) is expected to carry the ALTER
    # statement as its first positional argument.
    third_call = mocked_execute.call_args_list[2]
    assert expected == third_call[0][0]


@freeze_time("2024-01-01")
@pytest.mark.parametrize(
"schema,study,table_type,table_name,view_type,expects,raises",
Expand Down

0 comments on commit e4cb339

Please sign in to comment.