From 38ccdc84e19a258ab2be7a6b93c908c025c1b140 Mon Sep 17 00:00:00 2001 From: matt garber Date: Mon, 16 Dec 2024 13:45:57 -0500 Subject: [PATCH] Breakout connection from athena DB init (#328) * Breakout connection from athena DB init * coverage --- cumulus_library/__init__.py | 2 +- cumulus_library/databases/athena.py | 13 ++++++++++--- cumulus_library/databases/base.py | 4 ++++ cumulus_library/databases/duckdb.py | 11 ++++++++--- cumulus_library/databases/utils.py | 11 +++++++---- tests/test_athena.py | 1 + tests/test_cli.py | 9 +++++++++ tests/test_databases.py | 18 +++++++++++++----- tests/test_discovery.py | 1 + tests/test_duckdb.py | 2 ++ tests/test_logging.py | 1 + tests/test_static_file.py | 1 + tests/testbed_utils.py | 1 + 13 files changed, 59 insertions(+), 16 deletions(-) diff --git a/cumulus_library/__init__.py b/cumulus_library/__init__.py index ea82c2b6..470bb7aa 100644 --- a/cumulus_library/__init__.py +++ b/cumulus_library/__init__.py @@ -6,4 +6,4 @@ from cumulus_library.study_manifest import StudyManifest __all__ = ["BaseTableBuilder", "CountsBuilder", "StudyConfig", "StudyManifest"] -__version__ = "4.1.2" +__version__ = "4.1.3" diff --git a/cumulus_library/databases/athena.py b/cumulus_library/databases/athena.py index 6ee0bb44..2469e6c8 100644 --- a/cumulus_library/databases/athena.py +++ b/cumulus_library/databases/athena.py @@ -32,12 +32,14 @@ def __init__(self, region: str, work_group: str, profile: str, schema_name: str) self.work_group = work_group self.profile = profile self.schema_name = schema_name + self.connection = None + + def connect(self): # the profile may not be required, provided the above three AWS env vars # are set. If both are present, the env vars take precedence connect_kwargs = {} if self.profile is not None: connect_kwargs["profile_name"] = self.profile - for aws_env_name in [ "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", @@ -45,6 +47,7 @@ def __init__(self, region: str, work_group: str, profile: str, schema_name: str) ]: if aws_env_val := os.environ.get(aws_env_name): connect_kwargs[aws_env_name.lower()] = aws_env_val + self.connection = pyathena.connect( region_name=self.region, work_group=self.work_group, @@ -102,8 +105,11 @@ def col_pyarrow_types_from_sql(self, columns: list[tuple]) -> list: output.append((column[0], pyarrow.int64())) case "double": output.append((column[0], pyarrow.float64())) + # This is future proofing - we don't see this type currently. case "decimal": - output.append((column[0], pyarrow.decimal128(column[4], column[5]))) + output.append( # pragma: no cover + (column[0], pyarrow.decimal128(column[4], column[5])) + ) case "boolean": output.append((column[0], pyarrow.bool_())) case "date": @@ -168,7 +174,8 @@ def create_schema(self, schema_name) -> None: glue_client.create_database(DatabaseInput={"Name": schema_name}) def close(self) -> None: - return self.connection.close() # pragma: no cover + if self.connection is not None: # pragma: no cover + self.connection.close() class AthenaParser(base.DatabaseParser): diff --git a/cumulus_library/databases/base.py b/cumulus_library/databases/base.py index 99c33f6a..4ab679db 100644 --- a/cumulus_library/databases/base.py +++ b/cumulus_library/databases/base.py @@ -120,6 +120,10 @@ def __init__(self, schema_name: str): # technology self.db_type = None + @abc.abstractmethod + def connect(self): + """Initiates connection configuration of the database""" + @abc.abstractmethod def cursor(self) -> DatabaseCursor: """Returns a connection to the backing database""" diff --git a/cumulus_library/databases/duckdb.py b/cumulus_library/databases/duckdb.py index 619c5d70..aada2ce1 100644 --- a/cumulus_library/databases/duckdb.py +++ b/cumulus_library/databases/duckdb.py @@ -25,12 +25,16 @@ class DuckDatabaseBackend(base.DatabaseBackend): def __init__(self, db_file: str, schema_name: str | None = None): super().__init__("main") self.db_type = "duckdb" + self.db_file = db_file + self.connection = None + + def connect(self): + """Connects to the local duckdb database""" # As of the 1.0 duckdb release, local scopes, where schema names can be provided # as configuration to duckdb.connect, are not supported. # https://duckdb.org/docs/sql/statements/set.html#syntax - # This is where the connection config would be supplied when it is supported - self.connection = duckdb.connect(db_file) + self.connection = duckdb.connect(self.db_file) # Aliasing Athena's as_pandas to duckDB's df cast duckdb.DuckDBPyConnection.as_pandas = duckdb.DuckDBPyConnection.df @@ -208,7 +212,8 @@ def create_schema(self, schema_name): self.connection.sql(f"CREATE SCHEMA {schema_name}") def close(self) -> None: - self.connection.close() + if self.connection is not None: + self.connection.close() class DuckDbParser(base.DatabaseParser): diff --git a/cumulus_library/databases/utils.py b/cumulus_library/databases/utils.py index 6fd11318..4807aa10 100644 --- a/cumulus_library/databases/utils.py +++ b/cumulus_library/databases/utils.py @@ -78,14 +78,12 @@ def create_db_backend(args: dict[str, str]) -> (base.DatabaseBackend, str): # TODO: reevaluate as DuckDB's local schema support evolves. # https://duckdb.org/docs/sql/statements/set.html#syntax if not (args.get("schema_name") is None or args["schema_name"] == "main"): - print( + print( # pragma: no cover "Warning - local schema names are not yet supported by duckDB's " "python library - using 'main' instead" ) schema_name = "main" backend = duckdb.DuckDatabaseBackend(args["database"]) - if load_ndjson_dir: - backend.insert_tables(read_ndjson_dir(load_ndjson_dir)) elif db_config.db_type == "athena": if ( args.get("schema_name") is not None @@ -110,5 +108,10 @@ def create_db_backend(args: dict[str, str]) -> (base.DatabaseBackend, str): sys.exit("Loading an ndjson dir is not supported with --db-type=athena.") else: raise errors.CumulusLibraryError(f"'{db_config.db_type}' is not a supported database.") - + if "prepare" not in args.keys(): + backend.connect() + elif not args["prepare"]: + backend.connect() + if backend.connection is not None and db_config.db_type == "duckdb" and load_ndjson_dir: + backend.insert_tables(read_ndjson_dir(load_ndjson_dir)) return (backend, schema_name) diff --git a/tests/test_athena.py b/tests/test_athena.py index 18e233f4..cfc51a09 100644 --- a/tests/test_athena.py +++ b/tests/test_athena.py @@ -54,6 +54,7 @@ def test_upload_parquet_response_handling(mock_session): profile="profile", schema_name="db_schema", ) + db.connect() client = mock.MagicMock() with open(path / "test_data/aws/boto3.client.athena.get_work_group.json") as f: client.get_work_group.return_value = json.load(f) diff --git a/tests/test_cli.py b/tests/test_cli.py index 46b6d39e..b2ebacb0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -119,6 +119,7 @@ def test_cli_path_mapping(mock_load_json, monkeypatch, tmp_path, args, raises, e args = duckdb_args(args, tmp_path) cli.main(cli_args=args) db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db") + db.connect() assert (expected,) in db.cursor().execute("show tables").fetchall() @@ -140,6 +141,7 @@ def test_count_builder_mapping(tmp_path): ) cli.main(cli_args=args) db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db") + db.connect() assert [ ("study_python_counts_valid__lib_transactions",), ("study_python_counts_valid__table1",), @@ -271,6 +273,7 @@ def test_clean(tmp_path, args, expected, raises): with mock.patch.object(builtins, "input", lambda _: "y"): cli.main(cli_args=duckdb_args(args, tmp_path)) db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db") + db.connect() for table in db.cursor().execute("show tables").fetchall(): assert expected not in table @@ -455,6 +458,7 @@ def test_cli_executes_queries(tmp_path, build_args, export_args, expected_tables cli.main(cli_args=export_args) db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db") + db.connect() found_tables = ( db.cursor() .execute("SELECT table_schema,table_name FROM information_schema.tables") @@ -541,6 +545,7 @@ def test_cli_transactions(tmp_path, study, finishes, raises): ] cli.main(cli_args=args) db = databases.DuckDatabaseBackend(f"{tmp_path}/{study}_duck.db") + db.connect() query = db.cursor().execute(f"SELECT * from {study}__lib_transactions").fetchall() assert query[0][2] == "started" if finishes: @@ -579,6 +584,7 @@ def test_cli_stats_rebuild(tmp_path): cli.main(cli_args=[*arg_list, f"{tmp_path}/export"]) cli.main(cli_args=[*arg_list, f"{tmp_path}/export", "--statistics"]) db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db") + db.connect() expected = ( db.cursor() .execute( @@ -690,6 +696,7 @@ def test_cli_umls_parsing(mock_config, mode, tmp_path): def test_cli_single_builder(tmp_path): cli.main(cli_args=duckdb_args(["build", "--builder=patient", "--target=core"], tmp_path)) db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db") + db.connect() tables = {x[0] for x in db.cursor().execute("show tables").fetchall()} assert { "core__patient", @@ -708,6 +715,7 @@ def test_cli_finds_study_from_manifest_prefix(tmp_path): ) ) db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db") + db.connect() tables = {x[0] for x in db.cursor().execute("show tables").fetchall()} assert "study_different_name__table" in tables @@ -806,6 +814,7 @@ def test_dedicated_schema(tmp_path): cli.main(cli_args=core_build_args) cli.main(cli_args=build_args) db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db") + db.connect() tables = ( db.cursor() .execute("SELECT table_schema,table_name FROM information_schema.tables") diff --git a/tests/test_databases.py b/tests/test_databases.py index b726fe0e..dc4c13b3 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -223,22 +223,25 @@ def test_pyarrow_types_from_sql(db, data, expected, raises): does_not_raise(), ), ( - {**{"db_type": "athena"}, **ATHENA_KWARGS}, + {**{"db_type": "athena", "prepare": False}, **ATHENA_KWARGS}, databases.AthenaDatabaseBackend, does_not_raise(), ), ( - {**{"db_type": "athena", "database": "test"}, **ATHENA_KWARGS}, + {**{"db_type": "athena", "database": "test", "prepare": False}, **ATHENA_KWARGS}, databases.AthenaDatabaseBackend, does_not_raise(), ), ( - {**{"db_type": "athena", "database": "testtwo"}, **ATHENA_KWARGS}, + {**{"db_type": "athena", "database": "testtwo", "prepare": False}, **ATHENA_KWARGS}, databases.AthenaDatabaseBackend, pytest.raises(SystemExit), ), ( - {**{"db_type": "athena", "load_ndjson_dir": "file.json"}, **ATHENA_KWARGS}, + { + **{"db_type": "athena", "load_ndjson_dir": "file.json", "prepare": False}, + **ATHENA_KWARGS, + }, databases.AthenaDatabaseBackend, pytest.raises(SystemExit), ), @@ -253,6 +256,7 @@ def test_pyarrow_types_from_sql(db, data, expected, raises): def test_create_db_backend(args, expected_type, raises): with raises: db, schema = databases.create_db_backend(args) + db.connect() assert isinstance(db, expected_type) if args.get("schema_name"): assert args["schema_name"] == schema @@ -347,6 +351,7 @@ def test_upload_file_athena(mock_botocore, args, sse, keycount, expected, raises mock_clientobj.get_work_group.return_value = mock_data mock_clientobj.list_objects_v2.return_value = {"KeyCount": keycount} db = databases.AthenaDatabaseBackend(**ATHENA_KWARGS) + db.connect() with raises: location = db.upload_file(**args) assert location == expected @@ -383,6 +388,7 @@ def test_athena_pandas_cursor(mock_pyathena): (None, "B", None, None, None), ) db = databases.AthenaDatabaseBackend(**ATHENA_KWARGS) + db.connect() res, desc = db.execute_as_pandas("ignored query") assert res.equals( pandas.DataFrame( @@ -398,6 +404,7 @@ def test_athena_pandas_cursor(mock_pyathena): @mock.patch("pyathena.connect") def test_athena_parser(mock_pyathena): db = databases.AthenaDatabaseBackend(**ATHENA_KWARGS) + db.connect() parser = db.parser() assert isinstance(parser, databases.AthenaParser) @@ -411,7 +418,8 @@ def test_athena_parser(mock_pyathena): @mock.patch("pyathena.connect") def test_athena_env_var_priority(mock_pyathena): os.environ["AWS_ACCESS_KEY_ID"] = "secret" - databases.AthenaDatabaseBackend(**ATHENA_KWARGS) + db = databases.AthenaDatabaseBackend(**ATHENA_KWARGS) + db.connect() assert mock_pyathena.call_args[1]["aws_access_key_id"] == "secret" diff --git a/tests/test_discovery.py b/tests/test_discovery.py index 48bb7502..86a59100 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -40,6 +40,7 @@ def test_discovery(tmp_path): ) ) db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db") + db.connect() cursor = db.cursor() table_rows, cols = conftest.get_sorted_table_data(cursor, "discovery__code_sources") table_rows = [tuple(x or "" for x in row) for row in table_rows] diff --git a/tests/test_duckdb.py b/tests/test_duckdb.py index 1e8d4764..ad4a4a45 100644 --- a/tests/test_duckdb.py +++ b/tests/test_duckdb.py @@ -59,6 +59,7 @@ def test_duckdb_core_build_and_export(tmp_path): ) def test_duckdb_from_iso8601_timestamp(timestamp, expected): db = databases.DuckDatabaseBackend(":memory:") + db.connect() parsed = db.cursor().execute(f"select from_iso8601_timestamp('{timestamp}')").fetchone()[0] assert parsed == expected @@ -95,6 +96,7 @@ def test_duckdb_load_ndjson_dir(tmp_path): def test_duckdb_table_schema(): """Verify we can detect schemas correctly, even for nested camel case fields""" db = databases.DuckDatabaseBackend(":memory:") + db.connect() with tempfile.TemporaryDirectory() as tmpdir: os.mkdir(f"{tmpdir}/observation") diff --git a/tests/test_logging.py b/tests/test_logging.py index 0c619a17..a47f1a53 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -151,6 +151,7 @@ def test_migrate_transactions_athena(mock_pyathena): profile="test", schema_name="test", ) + db.connect() manifest = study_manifest.StudyManifest("./tests/test_data/study_valid/") config = base_utils.StudyConfig(schema="test", db=db) log_utils.log_transaction( diff --git a/tests/test_static_file.py b/tests/test_static_file.py index ff89ec77..2ff6e8be 100644 --- a/tests/test_static_file.py +++ b/tests/test_static_file.py @@ -18,6 +18,7 @@ def test_static_file(tmp_path): ) ) db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db") + db.connect() cursor = db.cursor() table_rows, cols = conftest.get_sorted_table_data(cursor, "study_static_file__table") expected_cols = {"CUI", "TTY", "CODE", "SAB", "STR"} diff --git a/tests/testbed_utils.py b/tests/testbed_utils.py index fd6c986e..16bd2177 100644 --- a/tests/testbed_utils.py +++ b/tests/testbed_utils.py @@ -247,6 +247,7 @@ def build(self, study="core") -> duckdb.DuckDBPyConnection: "db_type": "duckdb", "database": db_file, "load_ndjson_dir": str(self.path), + "prepare": False, } ) config = base_utils.StudyConfig(