diff --git a/cumulus_library/__init__.py b/cumulus_library/__init__.py index 90816ccd..3873fff8 100644 --- a/cumulus_library/__init__.py +++ b/cumulus_library/__init__.py @@ -1,3 +1,7 @@ """Package metadata""" +from .base_utils import StudyConfig +from .study_manifest import StudyManifest + +__all__ = ["StudyConfig", "StudyManifest"] __version__ = "3.0.0" diff --git a/cumulus_library/log_utils.py b/cumulus_library/log_utils.py index b44b3dc4..8ed1877b 100644 --- a/cumulus_library/log_utils.py +++ b/cumulus_library/log_utils.py @@ -3,6 +3,7 @@ from cumulus_library import ( __version__, base_utils, + databases, enums, errors, study_manifest, @@ -86,4 +87,37 @@ def _log_table( dataset=dataset, type_casts=table.type_casts, ) - config.db.cursor().execute(query) + cursor = config.db.cursor() + try: + cursor.execute(query) + except Exception as e: + # Migrating logging tables + if "lib_transactions" in table_name: + cols = cursor.execute( + "SELECT column_name FROM information_schema.columns " + f"WHERE table_name ='{table_name}' " + f"AND table_schema ='{db_schema}'" + ).fetchall() + cols = [col[0] for col in cols] + # Table schema pre-v3 library release + if sorted(cols) == [ + "event_time", + "library_version", + "status", + "study_name", + ]: + alter_query = "" + if isinstance(config.db, databases.AthenaDatabaseBackend): + alter_query = ( + f"ALTER TABLE {db_schema}.{table_name} " + "ADD COLUMNS(message string)" + ) + elif isinstance(config.db, databases.DuckDatabaseBackend): + alter_query = ( + f"ALTER TABLE {db_schema}.{table_name} " + "ADD COLUMN message varchar" + ) + cursor.execute(alter_query) + cursor.execute(query) + else: + raise e diff --git a/tests/test_logging.py b/tests/test_logging.py index 2d3c70e1..970ce1f8 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -1,5 +1,7 @@ +import os from contextlib import nullcontext as does_not_raise from datetime import datetime +from unittest import mock import pytest from freezegun import freeze_time @@ -7,6 +9,7 @@ from cumulus_library import ( __version__, base_utils, + databases, enums, errors, log_utils, @@ -88,6 +91,80 @@ def test_transactions(mock_db, schema, study, status, message, expects, raises): assert log == expects +def test_migrate_transactions_duckdb(mock_db_config): + query = base_templates.get_ctas_empty_query( + "main", + "study_valid__lib_transactions", + ["study_name", "library_version", "status", "event_time"], + ["varchar", "varchar", "varchar", "timestamp"], + ) + mock_db_config.db.cursor().execute(query) + manifest = study_manifest.StudyManifest("./tests/test_data/study_valid/") + with does_not_raise(): + log_utils.log_transaction( + config=mock_db_config, + manifest=manifest, + status="debug", + message="message", + ) + cols = ( + mock_db_config.db.cursor() + .execute( + "SELECT column_name FROM information_schema.columns " + "WHERE table_name ='study_valid__lib_transactions' " + "AND table_schema ='main'" + ) + .fetchall() + ) + assert len(cols) == 5 + assert ("message",) in cols + + +@mock.patch.dict( + os.environ, + clear=True, +) +@mock.patch("pyathena.connect") +def test_migrate_transactions_athena(mock_pyathena): + mock_fetchall = mock.MagicMock() + mock_fetchall.fetchall.side_effect = [ + [("event_time",), ("library_version",), ("status",), ("study_name",)], + [ + ("event_time",), + ("library_version",), + ("status",), + ("study_name",), + ("message",), + ], + ] + mock_pyathena.return_value.cursor.return_value.execute.side_effect = [ + Exception, + mock_fetchall, + None, + None, + ] + + db = databases.AthenaDatabaseBackend( + region="test", + work_group="test", + profile="test", + schema_name="test", + ) + manifest = study_manifest.StudyManifest("./tests/test_data/study_valid/") + config = base_utils.StudyConfig(schema="test", db=db) + log_utils.log_transaction( + config=config, + manifest=manifest, + status="debug", + message="message", + ) + expected = ( + "ALTER TABLE test.study_valid__lib_transactions" " ADD COLUMNS(message string)" + ) + call_args = mock_pyathena.return_value.cursor.return_value.execute.call_args_list + assert expected == call_args[2][0][0] + + @freeze_time("2024-01-01") @pytest.mark.parametrize( "schema,study,table_type,table_name,view_type,expects,raises",