Breakout connection from athena DB init (#328)
* Breakout connection from athena DB init

* coverage
dogversioning authored Dec 16, 2024
1 parent 59b2263 commit 38ccdc8
Showing 13 changed files with 59 additions and 16 deletions.
2 changes: 1 addition & 1 deletion cumulus_library/__init__.py
@@ -6,4 +6,4 @@
from cumulus_library.study_manifest import StudyManifest

__all__ = ["BaseTableBuilder", "CountsBuilder", "StudyConfig", "StudyManifest"]
__version__ = "4.1.2"
__version__ = "4.1.3"
13 changes: 10 additions & 3 deletions cumulus_library/databases/athena.py
@@ -32,19 +32,22 @@ def __init__(self, region: str, work_group: str, profile: str, schema_name: str)
self.work_group = work_group
self.profile = profile
self.schema_name = schema_name
self.connection = None

def connect(self):
# the profile may not be required, provided the above three AWS env vars
# are set. If both are present, the env vars take precedence
connect_kwargs = {}
if self.profile is not None:
connect_kwargs["profile_name"] = self.profile

for aws_env_name in [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
]:
if aws_env_val := os.environ.get(aws_env_name):
connect_kwargs[aws_env_name.lower()] = aws_env_val

self.connection = pyathena.connect(
region_name=self.region,
work_group=self.work_group,
@@ -102,8 +105,11 @@ def col_pyarrow_types_from_sql(self, columns: list[tuple]) -> list:
output.append((column[0], pyarrow.int64()))
case "double":
output.append((column[0], pyarrow.float64()))
# This is future proofing - we don't see this type currently.
case "decimal":
output.append((column[0], pyarrow.decimal128(column[4], column[5])))
output.append( # pragma: no cover
(column[0], pyarrow.decimal128(column[4], column[5]))
)
case "boolean":
output.append((column[0], pyarrow.bool_()))
case "date":
@@ -168,7 +174,8 @@ def create_schema(self, schema_name) -> None:
glue_client.create_database(DatabaseInput={"Name": schema_name})

def close(self) -> None:
return self.connection.close() # pragma: no cover
if self.connection is not None: # pragma: no cover
self.connection.close()


class AthenaParser(base.DatabaseParser):
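Taken together, the athena.py changes make construction side-effect free: pyathena.connect() now only runs inside the new connect() method, and close() is guarded against a never-opened connection. A minimal sketch of the resulting lifecycle, with placeholder AWS values (the constructor keywords match those exercised in tests/test_athena.py):

    from cumulus_library import databases

    db = databases.AthenaDatabaseBackend(
        region="us-east-1",        # placeholder values throughout
        work_group="cumulus",
        profile="my-profile",
        schema_name="db_schema",
    )
    db.connect()   # pyathena.connect() happens here; AWS_ACCESS_KEY_ID etc.
                   # still take precedence over the named profile
    cursor = db.cursor()
    # ... run queries ...
    db.close()     # safe even if connect() was never called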
4 changes: 4 additions & 0 deletions cumulus_library/databases/base.py
@@ -120,6 +120,10 @@ def __init__(self, schema_name: str):
# technology
self.db_type = None

@abc.abstractmethod
def connect(self):
"""Initiates connection configuration of the database"""

@abc.abstractmethod
def cursor(self) -> DatabaseCursor:
"""Returns a connection to the backing database"""
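With connect() added alongside cursor() as an abstract method, every concrete backend must now keep I/O out of its constructor. A hypothetical subclass satisfying the new contract (other abstract members elided; some_driver is a stand-in, not a real dependency):

    class ExampleBackend(base.DatabaseBackend):
        def __init__(self, schema_name: str):
            super().__init__(schema_name)
            self.connection = None  # construction stays side-effect free

        def connect(self):
            # connection setup now lives here (hypothetical driver call)
            self.connection = some_driver.connect()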
11 changes: 8 additions & 3 deletions cumulus_library/databases/duckdb.py
@@ -25,12 +25,16 @@ class DuckDatabaseBackend(base.DatabaseBackend):
def __init__(self, db_file: str, schema_name: str | None = None):
super().__init__("main")
self.db_type = "duckdb"
self.db_file = db_file
self.connection = None

def connect(self):
"""Connects to the local duckdb database"""
# As of the 1.0 duckdb release, local scopes, where schema names can be provided
# as configuration to duckdb.connect, are not supported.
# https://duckdb.org/docs/sql/statements/set.html#syntax

# This is where the connection config would be supplied when it is supported
self.connection = duckdb.connect(db_file)
self.connection = duckdb.connect(self.db_file)
# Aliasing Athena's as_pandas to duckDB's df cast
duckdb.DuckDBPyConnection.as_pandas = duckdb.DuckDBPyConnection.df

@@ -208,7 +212,8 @@ def create_schema(self, schema_name):
self.connection.sql(f"CREATE SCHEMA {schema_name}")

def close(self) -> None:
self.connection.close()
if self.connection is not None:
self.connection.close()


class DuckDbParser(base.DatabaseParser):
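duckdb.py follows the same pattern: the constructor now just records db_file, and duckdb.connect() is deferred to connect(). A short sketch mirroring the updated tests:

    from cumulus_library import databases

    db = databases.DuckDatabaseBackend(":memory:")  # path stored, nothing opened yet
    db.connect()                                    # duckdb.connect() runs here
    print(db.cursor().execute("select 1").fetchone())
    db.close()                                      # guarded no-op if never connected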
11 changes: 7 additions & 4 deletions cumulus_library/databases/utils.py
@@ -78,14 +78,12 @@ def create_db_backend(args: dict[str, str]) -> (base.DatabaseBackend, str):
# TODO: reevaluate as DuckDB's local schema support evolves.
# https://duckdb.org/docs/sql/statements/set.html#syntax
if not (args.get("schema_name") is None or args["schema_name"] == "main"):
print(
print( # pragma: no cover
"Warning - local schema names are not yet supported by duckDB's "
"python library - using 'main' instead"
)
schema_name = "main"
backend = duckdb.DuckDatabaseBackend(args["database"])
if load_ndjson_dir:
backend.insert_tables(read_ndjson_dir(load_ndjson_dir))
elif db_config.db_type == "athena":
if (
args.get("schema_name") is not None
@@ -110,5 +108,10 @@ def create_db_backend(args: dict[str, str]) -> (base.DatabaseBackend, str):
sys.exit("Loading an ndjson dir is not supported with --db-type=athena.")
else:
raise errors.CumulusLibraryError(f"'{db_config.db_type}' is not a supported database.")

if "prepare" not in args.keys():
backend.connect()
elif not args["prepare"]:
backend.connect()
if backend.connection is not None and db_config.db_type == "duckdb" and load_ndjson_dir:
backend.insert_tables(read_ndjson_dir(load_ndjson_dir))
return (backend, schema_name)
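create_db_backend() now owns the connect step: unless args carries a truthy "prepare" key, it calls backend.connect() for you, and the DuckDB ndjson load only happens once a connection exists. A hedged sketch of calling the factory (keys mirror tests/testbed_utils.py; the database path is a placeholder):

    from cumulus_library import databases

    backend, schema_name = databases.create_db_backend(
        {
            "db_type": "duckdb",
            "database": "/tmp/duck.db",  # placeholder path
            "load_ndjson_dir": None,
            "prepare": False,  # False or absent: connect() is called for you
        }
    )
    # With "prepare": True, backend.connection stays None (and nothing is
    # loaded) until you call backend.connect() yourself.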
1 change: 1 addition & 0 deletions tests/test_athena.py
@@ -54,6 +54,7 @@ def test_upload_parquet_response_handling(mock_session):
profile="profile",
schema_name="db_schema",
)
db.connect()
client = mock.MagicMock()
with open(path / "test_data/aws/boto3.client.athena.get_work_group.json") as f:
client.get_work_group.return_value = json.load(f)
9 changes: 9 additions & 0 deletions tests/test_cli.py
@@ -119,6 +119,7 @@ def test_cli_path_mapping(mock_load_json, monkeypatch, tmp_path, args, raises, e
args = duckdb_args(args, tmp_path)
cli.main(cli_args=args)
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
assert (expected,) in db.cursor().execute("show tables").fetchall()


@@ -140,6 +141,7 @@ def test_count_builder_mapping(tmp_path):
)
cli.main(cli_args=args)
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
assert [
("study_python_counts_valid__lib_transactions",),
("study_python_counts_valid__table1",),
@@ -271,6 +273,7 @@ def test_clean(tmp_path, args, expected, raises):
with mock.patch.object(builtins, "input", lambda _: "y"):
cli.main(cli_args=duckdb_args(args, tmp_path))
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
for table in db.cursor().execute("show tables").fetchall():
assert expected not in table

@@ -455,6 +458,7 @@ def test_cli_executes_queries(tmp_path, build_args, export_args, expected_tables
cli.main(cli_args=export_args)

db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
found_tables = (
db.cursor()
.execute("SELECT table_schema,table_name FROM information_schema.tables")
@@ -541,6 +545,7 @@ def test_cli_transactions(tmp_path, study, finishes, raises):
]
cli.main(cli_args=args)
db = databases.DuckDatabaseBackend(f"{tmp_path}/{study}_duck.db")
db.connect()
query = db.cursor().execute(f"SELECT * from {study}__lib_transactions").fetchall()
assert query[0][2] == "started"
if finishes:
@@ -579,6 +584,7 @@ def test_cli_stats_rebuild(tmp_path):
cli.main(cli_args=[*arg_list, f"{tmp_path}/export"])
cli.main(cli_args=[*arg_list, f"{tmp_path}/export", "--statistics"])
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
expected = (
db.cursor()
.execute(
@@ -690,6 +696,7 @@ def test_cli_umls_parsing(mock_config, mode, tmp_path):
def test_cli_single_builder(tmp_path):
cli.main(cli_args=duckdb_args(["build", "--builder=patient", "--target=core"], tmp_path))
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
tables = {x[0] for x in db.cursor().execute("show tables").fetchall()}
assert {
"core__patient",
@@ -708,6 +715,7 @@ def test_cli_finds_study_from_manifest_prefix(tmp_path):
)
)
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
tables = {x[0] for x in db.cursor().execute("show tables").fetchall()}
assert "study_different_name__table" in tables

@@ -806,6 +814,7 @@ def test_dedicated_schema(tmp_path):
cli.main(cli_args=core_build_args)
cli.main(cli_args=build_args)
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
tables = (
db.cursor()
.execute("SELECT table_schema,table_name FROM information_schema.tables")
18 changes: 13 additions & 5 deletions tests/test_databases.py
@@ -223,22 +223,25 @@ def test_pyarrow_types_from_sql(db, data, expected, raises):
does_not_raise(),
),
(
{**{"db_type": "athena"}, **ATHENA_KWARGS},
{**{"db_type": "athena", "prepare": False}, **ATHENA_KWARGS},
databases.AthenaDatabaseBackend,
does_not_raise(),
),
(
{**{"db_type": "athena", "database": "test"}, **ATHENA_KWARGS},
{**{"db_type": "athena", "database": "test", "prepare": False}, **ATHENA_KWARGS},
databases.AthenaDatabaseBackend,
does_not_raise(),
),
(
{**{"db_type": "athena", "database": "testtwo"}, **ATHENA_KWARGS},
{**{"db_type": "athena", "database": "testtwo", "prepare": False}, **ATHENA_KWARGS},
databases.AthenaDatabaseBackend,
pytest.raises(SystemExit),
),
(
{**{"db_type": "athena", "load_ndjson_dir": "file.json"}, **ATHENA_KWARGS},
{
**{"db_type": "athena", "load_ndjson_dir": "file.json", "prepare": False},
**ATHENA_KWARGS,
},
databases.AthenaDatabaseBackend,
pytest.raises(SystemExit),
),
@@ -253,6 +256,7 @@ def test_pyarrow_types_from_sql(db, data, expected, raises):
def test_create_db_backend(args, expected_type, raises):
with raises:
db, schema = databases.create_db_backend(args)
db.connect()
assert isinstance(db, expected_type)
if args.get("schema_name"):
assert args["schema_name"] == schema
@@ -347,6 +351,7 @@ def test_upload_file_athena(mock_botocore, args, sse, keycount, expected, raises
mock_clientobj.get_work_group.return_value = mock_data
mock_clientobj.list_objects_v2.return_value = {"KeyCount": keycount}
db = databases.AthenaDatabaseBackend(**ATHENA_KWARGS)
db.connect()
with raises:
location = db.upload_file(**args)
assert location == expected
@@ -383,6 +388,7 @@ def test_athena_pandas_cursor(mock_pyathena):
(None, "B", None, None, None),
)
db = databases.AthenaDatabaseBackend(**ATHENA_KWARGS)
db.connect()
res, desc = db.execute_as_pandas("ignored query")
assert res.equals(
pandas.DataFrame(
@@ -398,6 +404,7 @@
@mock.patch("pyathena.connect")
def test_athena_parser(mock_pyathena):
db = databases.AthenaDatabaseBackend(**ATHENA_KWARGS)
db.connect()
parser = db.parser()
assert isinstance(parser, databases.AthenaParser)

@@ -411,7 +418,8 @@ def test_athena_parser(mock_pyathena):
@mock.patch("pyathena.connect")
def test_athena_env_var_priority(mock_pyathena):
os.environ["AWS_ACCESS_KEY_ID"] = "secret"
databases.AthenaDatabaseBackend(**ATHENA_KWARGS)
db = databases.AthenaDatabaseBackend(**ATHENA_KWARGS)
db.connect()
assert mock_pyathena.call_args[1]["aws_access_key_id"] == "secret"


1 change: 1 addition & 0 deletions tests/test_discovery.py
@@ -40,6 +40,7 @@ def test_discovery(tmp_path):
)
)
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
cursor = db.cursor()
table_rows, cols = conftest.get_sorted_table_data(cursor, "discovery__code_sources")
table_rows = [tuple(x or "" for x in row) for row in table_rows]
2 changes: 2 additions & 0 deletions tests/test_duckdb.py
@@ -59,6 +59,7 @@ def test_duckdb_core_build_and_export(tmp_path):
)
def test_duckdb_from_iso8601_timestamp(timestamp, expected):
db = databases.DuckDatabaseBackend(":memory:")
db.connect()
parsed = db.cursor().execute(f"select from_iso8601_timestamp('{timestamp}')").fetchone()[0]
assert parsed == expected

@@ -95,6 +96,7 @@ def test_duckdb_load_ndjson_dir(tmp_path):
def test_duckdb_table_schema():
"""Verify we can detect schemas correctly, even for nested camel case fields"""
db = databases.DuckDatabaseBackend(":memory:")
db.connect()

with tempfile.TemporaryDirectory() as tmpdir:
os.mkdir(f"{tmpdir}/observation")
1 change: 1 addition & 0 deletions tests/test_logging.py
@@ -151,6 +151,7 @@ def test_migrate_transactions_athena(mock_pyathena):
profile="test",
schema_name="test",
)
db.connect()
manifest = study_manifest.StudyManifest("./tests/test_data/study_valid/")
config = base_utils.StudyConfig(schema="test", db=db)
log_utils.log_transaction(
1 change: 1 addition & 0 deletions tests/test_static_file.py
@@ -18,6 +18,7 @@ def test_static_file(tmp_path):
)
)
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
cursor = db.cursor()
table_rows, cols = conftest.get_sorted_table_data(cursor, "study_static_file__table")
expected_cols = {"CUI", "TTY", "CODE", "SAB", "STR"}
1 change: 1 addition & 0 deletions tests/testbed_utils.py
@@ -247,6 +247,7 @@ def build(self, study="core") -> duckdb.DuckDBPyConnection:
"db_type": "duckdb",
"database": db_file,
"load_ndjson_dir": str(self.path),
"prepare": False,
}
)
config = base_utils.StudyConfig(
