From d71410600a9b4164981fd206a646c9ec9a6c1cb9 Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Wed, 10 Jul 2024 14:50:59 -0400 Subject: [PATCH 1/2] Gracefully migrate transactions --- cumulus_library/__init__.py | 4 ++ cumulus_library/log_utils.py | 36 ++++++++++++++++- tests/test_logging.py | 77 ++++++++++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 1 deletion(-) diff --git a/cumulus_library/__init__.py b/cumulus_library/__init__.py index 90816ccd..3873fff8 100644 --- a/cumulus_library/__init__.py +++ b/cumulus_library/__init__.py @@ -1,3 +1,7 @@ """Package metadata""" +from .base_utils import StudyConfig +from .study_manifest import StudyManifest + +__all__ = ["StudyConfig", "StudyManifest"] __version__ = "3.0.0" diff --git a/cumulus_library/log_utils.py b/cumulus_library/log_utils.py index b44b3dc4..8ed1877b 100644 --- a/cumulus_library/log_utils.py +++ b/cumulus_library/log_utils.py @@ -3,6 +3,7 @@ from cumulus_library import ( __version__, base_utils, + databases, enums, errors, study_manifest, @@ -86,4 +87,37 @@ def _log_table( dataset=dataset, type_casts=table.type_casts, ) - config.db.cursor().execute(query) + cursor = config.db.cursor() + try: + cursor.execute(query) + except Exception as e: + # Migrating logging tables + if "lib_transactions" in table_name: + cols = cursor.execute( + "SELECT column_name FROM information_schema.columns " + f"WHERE table_name ='{table_name}' " + f"AND table_schema ='{db_schema}'" + ).fetchall() + cols = [col[0] for col in cols] + # Table schema pre-v3 library release + if sorted(cols) == [ + "event_time", + "library_version", + "status", + "study_name", + ]: + alter_query = "" + if isinstance(config.db, databases.AthenaDatabaseBackend): + alter_query = ( + f"ALTER TABLE {db_schema}.{table_name} " + "ADD COLUMNS(message string)" + ) + elif isinstance(config.db, databases.DuckDatabaseBackend): + alter_query = ( + f"ALTER TABLE {db_schema}.{table_name} " + "ADD COLUMN message varchar" + ) + cursor.execute(alter_query) + cursor.execute(query) + else: + raise e diff --git a/tests/test_logging.py b/tests/test_logging.py index 2d3c70e1..970ce1f8 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -1,5 +1,7 @@ +import os from contextlib import nullcontext as does_not_raise from datetime import datetime +from unittest import mock import pytest from freezegun import freeze_time @@ -7,6 +9,7 @@ from cumulus_library import ( __version__, base_utils, + databases, enums, errors, log_utils, @@ -88,6 +91,80 @@ def test_transactions(mock_db, schema, study, status, message, expects, raises): assert log == expects +def test_migrate_transactions_duckdb(mock_db_config): + query = base_templates.get_ctas_empty_query( + "main", + "study_valid__lib_transactions", + ["study_name", "library_version", "status", "event_time"], + ["varchar", "varchar", "varchar", "timestamp"], + ) + mock_db_config.db.cursor().execute(query) + manifest = study_manifest.StudyManifest("./tests/test_data/study_valid/") + with does_not_raise(): + log_utils.log_transaction( + config=mock_db_config, + manifest=manifest, + status="debug", + message="message", + ) + cols = ( + mock_db_config.db.cursor() + .execute( + "SELECT column_name FROM information_schema.columns " + "WHERE table_name ='study_valid__lib_transactions' " + "AND table_schema ='main'" + ) + .fetchall() + ) + assert len(cols) == 5 + assert ("message",) in cols + + +@mock.patch.dict( + os.environ, + clear=True, +) +@mock.patch("pyathena.connect") +def test_migrate_transactions_athena(mock_pyathena): + mock_fetchall = mock.MagicMock() + mock_fetchall.fetchall.side_effect = [ + [("event_time",), ("library_version",), ("status",), ("study_name",)], + [ + ("event_time",), + ("library_version",), + ("status",), + ("study_name",), + ("message",), + ], + ] + mock_pyathena.return_value.cursor.return_value.execute.side_effect = [ + Exception, + mock_fetchall, + None, + None, + ] + + db = databases.AthenaDatabaseBackend( + region="test", + work_group="test", + profile="test", + schema_name="test", + ) + manifest = study_manifest.StudyManifest("./tests/test_data/study_valid/") + config = base_utils.StudyConfig(schema="test", db=db) + log_utils.log_transaction( + config=config, + manifest=manifest, + status="debug", + message="message", + ) + expected = ( + "ALTER TABLE test.study_valid__lib_transactions" " ADD COLUMNS(message string)" + ) + call_args = mock_pyathena.return_value.cursor.return_value.execute.call_args_list + assert expected == call_args[2][0][0] + + @freeze_time("2024-01-01") @pytest.mark.parametrize( "schema,study,table_type,table_name,view_type,expects,raises", From 2abfe1a5ab56ef13cae6dfc338701c03b555782c Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Thu, 11 Jul 2024 09:22:40 -0400 Subject: [PATCH 2/2] Switch to catching operational errors --- cumulus_library/databases.py | 7 +++++-- cumulus_library/log_utils.py | 2 +- tests/test_logging.py | 3 ++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/cumulus_library/databases.py b/cumulus_library/databases.py index 55d0ad36..a1c7a958 100644 --- a/cumulus_library/databases.py +++ b/cumulus_library/databases.py @@ -274,7 +274,7 @@ def parser(self) -> DatabaseParser: return AthenaParser() def operational_errors(self) -> tuple[Exception]: - return (pyathena.OperationalError,) # pragma: no cover + return (pyathena.OperationalError,) def col_parquet_types_from_pandas(self, field_types: list) -> list: output = [] @@ -609,7 +609,10 @@ def parser(self) -> DatabaseParser: return DuckDbParser() def operational_errors(self) -> tuple[Exception]: - return (duckdb.OperationalError,) # pragma: no cover + return ( + duckdb.OperationalError, + duckdb.BinderException, + ) def create_schema(self, schema_name): """Creates a new schema object inside the database""" diff --git a/cumulus_library/log_utils.py b/cumulus_library/log_utils.py index 8ed1877b..7d01299a 100644 --- a/cumulus_library/log_utils.py +++ b/cumulus_library/log_utils.py @@ -90,7 +90,7 @@ def _log_table( cursor = config.db.cursor() try: cursor.execute(query) - except Exception as e: + except config.db.operational_errors() as e: # Migrating logging tables if "lib_transactions" in table_name: cols = cursor.execute( diff --git a/tests/test_logging.py b/tests/test_logging.py index 970ce1f8..000c607e 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -3,6 +3,7 @@ from datetime import datetime from unittest import mock +import pyathena import pytest from freezegun import freeze_time @@ -138,7 +139,7 @@ def test_migrate_transactions_athena(mock_pyathena): ], ] mock_pyathena.return_value.cursor.return_value.execute.side_effect = [ - Exception, + pyathena.OperationalError, mock_fetchall, None, None,