Skip to content

Commit

Permalink
Merge pull request #13 from UW-Macrostrat/database-v3
Browse files (browse the repository at this point in the history)
Database v3
  • Loading branch information
davenquinn authored Jan 4, 2024
2 parents 8ae400d + e937494 commit 89bfe6e
Show file tree
Hide file tree
Showing 16 changed files with 289 additions and 158 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ publish:
poetry run mono publish

test:
poetry run pytest -s
poetry run pytest -s -x --failed-first
1 change: 0 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def pytest_addoption(parser):

@fixture(scope="session")
def docker_client():
    """Provide a Docker client, configured from the environment, once per test session."""
    return DockerClient.from_env()

Expand Down
7 changes: 3 additions & 4 deletions database/macrostrat/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert


from .utils import (
run_sql,
get_or_create,
Expand All @@ -30,7 +31,7 @@ class Database(object):
session: Session
__inspector__ = None

def __init__(self, db_conn, app=None, echo_sql=False, **kwargs):
def __init__(self, db_conn, echo_sql=False, **kwargs):
"""
We can pass a connection string, a **Flask** application object
with the appropriate configuration, or nothing, in which
Expand All @@ -41,9 +42,7 @@ def __init__(self, db_conn, app=None, echo_sql=False, **kwargs):
compiles(Insert, "postgresql")(prefix_inserts)

log.info(f"Setting up database connection '{db_conn}'")
self.engine = create_engine(
db_conn, executemany_mode="batch", echo=echo_sql, **kwargs
)
self.engine = create_engine(db_conn, echo=echo_sql, **kwargs)
self.metadata = kwargs.get("metadata", metadata)

# Scoped session for database
Expand Down
15 changes: 3 additions & 12 deletions database/macrostrat/database/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class DatabaseMapper:
automap_error = None
_models = None
_tables = None
_prepared = False

def __init__(self, db, **kwargs):
# https://docs.sqlalchemy.org/en/13/orm/extensions/automap.html#sqlalchemy.ext.automap.AutomapBase.prepare
Expand Down Expand Up @@ -77,23 +76,15 @@ def _cache_database_map(self):
def reflect_schema(self, schema, use_cache=True):
if use_cache and self.automap_base.loaded_from_cache:
log.info("Database models for %s have been loaded from cache", schema)
self.automap_base.prepare(
self.db.engine, schema=schema, **self.reflection_kwargs
)
self.automap_base.prepare(schema=schema, **self.reflection_kwargs)
return
log.info(f"Reflecting schema {schema}")
if schema == "public":
schema = None
if not self._prepared:
self.automap_base.prepare(
self.db.engine, reflect=False, **self.reflection_kwargs
)
self._prepared = True
# Reflect tables in schemas we care about
# Note: this will not reflect views because they don't have primary keys.
self.automap_base.metadata.reflect(
bind=self.db.engine,
schema=schema,
self.automap_base.prepare(
autoload_with=self.db.engine, schema=schema, **self.reflection_kwargs
)
self._models = ModelCollection(self.automap_base.classes)
self._tables = TableCollection(self._models)
Expand Down
10 changes: 5 additions & 5 deletions database/macrostrat/database/mapper/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@

log = get_logger(__name__)


class AutomapError(Exception):
    """Exception type for database automapping errors.

    NOTE(review): the raise sites are outside this view -- confirm usage.
    """


class DatabaseModelCache(object):
cache_file = None

def __init__(self, cache_file=None):
self.cache_file = cache_file


@property
def _metadata_cache_filename(self):
return self.cache_file
Expand All @@ -36,9 +37,7 @@ def _cache_database_map(self, metadata):
log.info(f"Cached database models to {self.cache_file}")
except IOError:
# couldn't write the file for some reason
log.info(
f"Could not cache database models to {self.cache_file}"
)
log.info(f"Could not cache database models to {self.cache_file}")

def _load_database_map(self):
# We have hard-coded the cache file for now
Expand All @@ -54,6 +53,8 @@ def _load_database_map(self):
log.info(
f"Could not find database model cache ({self._metadata_cache_filename})"
)
except Exception as exc:
log.error(f"Error loading database model cache: {exc}")
return cached_metadata

def automap_base(self):
Expand All @@ -66,4 +67,3 @@ def automap_base(self):
base.loaded_from_cache = True
base.builder = self
return base

18 changes: 17 additions & 1 deletion database/macrostrat/database/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert, text
from sqlalchemy.dialects import postgresql
import psycopg2


_import_mode = ContextVar("import-mode", default="do-nothing")

Expand All @@ -15,7 +17,6 @@
if TYPE_CHECKING:
from ..database import Database


# https://stackoverflow.com/questions/33307250/postgresql-on-conflict-in-sqlalchemy/62305344#62305344
@contextmanager
def on_conflict(action="restrict"):
Expand Down Expand Up @@ -57,6 +58,20 @@ def prefix_inserts(insert, compiler, **kw):
return compiler.visit_insert(insert, **kw)


_psycopg2_setup_was_run = ContextVar("psycopg2-setup-was-run", default=False)


def _setup_psycopg2_wait_callback():
    """Install psycopg2's ``wait_select`` as the wait callback.

    This allows a running PostgreSQL query to be cancelled with Ctrl-C.
    The ``_psycopg2_setup_was_run`` context variable is used as a guard so
    that the callback is not installed again once the flag has been set.

    See https://github.com/psycopg/psycopg2/issues/333
    """
    # TODO: we might want to do this only once on engine creation
    if _psycopg2_setup_was_run.get():
        return

    # A bare ``import psycopg2`` does not load the ``extras`` submodule, so
    # import both names explicitly instead of relying on another module
    # having imported ``psycopg2.extras`` earlier in the process.
    from psycopg2.extensions import set_wait_callback
    from psycopg2.extras import wait_select

    set_wait_callback(wait_select)
    _psycopg2_setup_was_run.set(True)


def table_exists(db: Database, table_name: str, schema: str = "public") -> bool:
"""Check if a table exists in a PostgreSQL database."""
sql = """SELECT EXISTS (
Expand All @@ -68,3 +83,4 @@ def table_exists(db: Database, table_name: str, schema: str = "public") -> bool:
return db.session.execute(
text(sql), params=dict(schema=schema, table_name=table_name)
).scalar()

47 changes: 21 additions & 26 deletions database/macrostrat/database/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from click import echo, secho
from sqlalchemy.exc import ProgrammingError, IntegrityError, InternalError
from sqlparse import split, format
from sqlalchemy.sql import ClauseElement
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.engine import Engine, Connection, Transaction
from sqlalchemy.sql.elements import TextClause, ClauseElement
from sqlalchemy.orm import sessionmaker
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.schema import Table
from sqlalchemy import MetaData, create_engine, text
from contextlib import contextmanager
Expand All @@ -16,6 +16,8 @@
from warnings import warn
from psycopg2.sql import SQL, Composable, Composed
from re import search
from macrostrat.utils import get_logger
from .postgresql import _setup_psycopg2_wait_callback

log = get_logger(__name__)

Expand Down Expand Up @@ -84,15 +86,7 @@ def get_dataframe(connectable, filename_or_query, **kwargs):

def pretty_print(sql, **kwargs):
for line in sql.split("\n"):
for i in [
"SELECT",
"INSERT",
"UPDATE",
"CREATE",
"DROP",
"DELETE",
"ALTER",
]:
for i in ["SELECT", "INSERT", "UPDATE", "CREATE", "DROP", "DELETE", "ALTER"]:
if not line.startswith(i):
continue
start = line.split("(")[0].strip().rstrip(";").replace(" AS", "")
Expand Down Expand Up @@ -120,6 +114,8 @@ def _get_queries(sql, interpret_as_file=None):
for i in sql:
queries.extend(_get_queries(i, interpret_as_file=interpret_as_file))
return queries
if isinstance(sql, TextClause):
return [sql]
if isinstance(sql, SQL):
return [sql]

Expand Down Expand Up @@ -172,18 +168,20 @@ def _get_cursor(connectable):
conn = connectable
if hasattr(conn, "raw_connection"):
conn = conn.raw_connection()
while hasattr(conn, "connection"):
if callable(conn.connection):
conn = conn.connection()
while hasattr(conn, "driver_connection") or hasattr(conn, "connection"):
if hasattr(conn, "driver_connection"):
conn = conn.driver_connection
else:
conn = conn.connection
if callable(conn):
conn = conn()
if hasattr(conn, "cursor"):
conn = conn.cursor()

return conn


def _get_connection(connectable):
def _get_connection(connectable) -> Connection:
if isinstance(connectable, Engine):
return connectable.connect()
if isinstance(connectable, Connection):
Expand Down Expand Up @@ -220,6 +218,8 @@ def _run_sql(connectable, sql, **kwargs):
yield from _run_sql(conn, sql, **kwargs)
return

_setup_psycopg2_wait_callback()

params = kwargs.pop("params", None)
stop_on_error = kwargs.pop("stop_on_error", False)
raise_errors = kwargs.pop("raise_errors", False)
Expand Down Expand Up @@ -258,7 +258,7 @@ def _run_sql(connectable, sql, **kwargs):
if isinstance(query, (SQL, Composed)):
query = _render_query(query, connectable)

sql_text = query
sql_text = str(query)
if isinstance(query, str):
sql_text = format(query, strip_comments=True).strip()
if sql_text == "":
Expand All @@ -273,7 +273,9 @@ def _run_sql(connectable, sql, **kwargs):
conn = _get_connection(connectable)
res = conn.exec_driver_sql(query, params)
else:
res = connectable.execute(text(query), params=params)
if not isinstance(query, TextClause):
query = text(query)
res = connectable.execute(query, params)
yield res
if trans is not None:
trans.commit()
Expand Down Expand Up @@ -445,11 +447,4 @@ def reflect_table(engine, tablename, *column_args, **kwargs):
"""
schema = kwargs.pop("schema", "public")
meta = MetaData(schema=schema)
return Table(
tablename,
meta,
*column_args,
autoload=True,
autoload_with=engine,
**kwargs,
)
return Table(tablename, meta, *column_args, autoload_with=engine, **kwargs)
12 changes: 6 additions & 6 deletions database/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@ authors = ["Daven Quinn <[email protected]>"]
description = "A SQLAlchemy-based database toolkit."
name = "macrostrat.database"
packages = [{ include = "macrostrat" }]
version = "2.1.3"
version = "3.0.0-beta1"

[tool.poetry.dependencies]
GeoAlchemy2 = "^0.9.4"
SQLAlchemy = "^1.4.26"
SQLAlchemy-Utils = "^0.37.0"
GeoAlchemy2 = "^0.14.0"
SQLAlchemy = "^2.0.18"
SQLAlchemy-Utils = "^0.41.1"
click = "^8.1.3"
"macrostrat.utils" = "^1.0.0"
psycopg2-binary = "^2.9.1"
psycopg2-binary = "^2.9.6"
python = "^3.8"
sqlparse = "^0.4.0"
sqlparse = "^0.4.4"

[tool.poetry.dev-dependencies]
"macrostrat.utils" = { path = "../utils", develop = true }
Expand Down
21 changes: 21 additions & 0 deletions database/test-scripts/test-long-running-query
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/usr/bin/env python
"""
Run a long-running query to test if we can interrupt it.

Usage: test-long-running-query <database-url>

Exits with status 1 if the query is cancelled by the user (e.g. Ctrl-C).
"""

from macrostrat.database import Database
from sys import argv, exit, stderr
from sqlalchemy.exc import OperationalError


# Fail with a readable message instead of an IndexError traceback when the
# connection string is missing.
if len(argv) < 2:
    print(f"Usage: {argv[0]} <database-url>", file=stderr)
    exit(2)

db_conn = argv[1]
print(f"Connecting to {db_conn}")

db = Database(db_conn)

try:
    db.run_sql("SELECT pg_sleep(10);")
except OperationalError as e:
    # Only a user-requested cancellation is the expected outcome here; any
    # other operational error is re-raised rather than silently hidden.
    if "canceling statement due to user request" not in str(e):
        raise
    print("Query canceled due to user request.")
    exit(1)
Loading

0 comments on commit 89bfe6e

Please sign in to comment.