Skip to content

Commit

Permalink
Do not require SQL URIs to be prefixed with SQLAlchemy driver (#810)
Browse files Browse the repository at this point in the history
* Automatically set SQL driver if unset.

* Handle special SQLite URIs

* Consistently use database URI with schema.

* Interpret filepaths as SQLite URIs.

* Parse uri earlier.

* Update CHANGELOG

* Fix missing arg in refactor.

* Make utility accept Path-like objects.

* Deal with Windows paths in test case
  • Loading branch information
danielballan authored Nov 5, 2024
1 parent 0ad610f commit e52bab3
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Write the date in place of the "Unreleased" in the case a new version is release

- Drop support for Python 3.8, which is reached end of life
upstream on 7 October 2024.
- Do not require SQL database URIs to specify a "driver" (Python
library to be used for connecting).

## v0.1.0b10 (2024-10-11)

Expand Down
3 changes: 2 additions & 1 deletion tiled/_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ..catalog import from_uri, in_memory
from ..client.base import BaseClient
from ..server.settings import get_settings
from ..utils import ensure_specified_sql_driver
from .utils import enter_username_password as utils_enter_uname_passwd
from .utils import temp_postgres

Expand Down Expand Up @@ -152,7 +153,7 @@ async def postgresql_with_example_data_adapter(request, tmpdir):
if uri.endswith("/"):
uri = uri[:-1]
uri_with_database_name = f"{uri}/{DATABASE_NAME}"
engine = create_async_engine(uri_with_database_name)
engine = create_async_engine(ensure_specified_sql_driver(uri_with_database_name))
try:
async with engine.connect():
pass
Expand Down
75 changes: 75 additions & 0 deletions tiled/_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from pathlib import Path

from ..utils import ensure_specified_sql_driver


def test_ensure_specified_sql_driver():
# Postgres
# Default driver is added if missing.
assert (
ensure_specified_sql_driver(
"postgresql://user:password@localhost:5432/database"
)
== "postgresql+asyncpg://user:password@localhost:5432/database"
)
# Default driver passes through if specified.
assert (
ensure_specified_sql_driver(
"postgresql+asyncpg://user:password@localhost:5432/database"
)
== "postgresql+asyncpg://user:password@localhost:5432/database"
)
# Do not override user-provided.
assert (
ensure_specified_sql_driver(
"postgresql+custom://user:password@localhost:5432/database"
)
== "postgresql+custom://user:password@localhost:5432/database"
)

# SQLite
# Default driver is added if missing.
assert (
ensure_specified_sql_driver("sqlite:////test.db")
== "sqlite+aiosqlite:////test.db"
)
# Default driver passes through if specified.
assert (
ensure_specified_sql_driver("sqlite+aiosqlite:////test.db")
== "sqlite+aiosqlite:////test.db"
)
# Do not override user-provided.
assert (
ensure_specified_sql_driver("sqlite+custom:////test.db")
== "sqlite+custom:////test.db"
)
# Handle SQLite :memory: URIs
assert (
ensure_specified_sql_driver("sqlite+aiosqlite://:memory:")
== "sqlite+aiosqlite://:memory:"
)
assert (
ensure_specified_sql_driver("sqlite://:memory:")
== "sqlite+aiosqlite://:memory:"
)
# Handle SQLite relative URIs
assert (
ensure_specified_sql_driver("sqlite+aiosqlite:///test.db")
== "sqlite+aiosqlite:///test.db"
)
assert (
ensure_specified_sql_driver("sqlite:///test.db")
== "sqlite+aiosqlite:///test.db"
)
# Filepaths are implicitly SQLite databases.
# Relative path
assert ensure_specified_sql_driver("test.db") == "sqlite+aiosqlite:///test.db"
# Path object
assert ensure_specified_sql_driver(Path("test.db")) == "sqlite+aiosqlite:///test.db"
# Relative path anchored to .
assert ensure_specified_sql_driver("./test.db") == "sqlite+aiosqlite:///test.db"
# Absolute path
assert (
ensure_specified_sql_driver(Path("/tmp/test.db"))
== f"sqlite+aiosqlite:///{Path('/tmp/test.db')}"
)
3 changes: 2 additions & 1 deletion tiled/_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from ..client import context
from ..client.base import BaseClient
from ..utils import ensure_specified_sql_driver

if sys.version_info < (3, 9):
import importlib_resources as resources
Expand All @@ -33,7 +34,7 @@ async def temp_postgres(uri):
if uri.endswith("/"):
uri = uri[:-1]
# Create a fresh database.
engine = create_async_engine(uri)
engine = create_async_engine(ensure_specified_sql_driver(uri))
database_name = f"tiled_test_disposable_{uuid.uuid4().hex}"
async with engine.connect() as connection:
await connection.execute(
Expand Down
5 changes: 4 additions & 1 deletion tiled/authn_database/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

from ..server.settings import get_settings
from ..utils import ensure_specified_sql_driver

# A given process probably only has one of these at a time, but we
# key on database_settings just case in some testing context or something
Expand All @@ -16,7 +17,9 @@ def open_database_connection_pool(database_settings):
# kwargs["pool_pre_ping"] = database_settings.pool_pre_ping
# kwargs["max_overflow"] = database_settings.max_overflow
engine = create_async_engine(
database_settings.uri, connect_args=connect_args, **kwargs
ensure_specified_sql_driver(database_settings.uri),
connect_args=connect_args,
**kwargs,
)
_connection_pools[database_settings] = engine
return engine
Expand Down
12 changes: 6 additions & 6 deletions tiled/catalog/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@
from ..server.schemas import Asset, DataSource, Management, Revision, Spec
from ..structures.core import StructureFamily
from ..utils import (
SCHEME_PATTERN,
UNCHANGED,
Conflicts,
OneShotCachedMap,
UnsupportedQueryType,
ensure_awaitable,
ensure_specified_sql_driver,
ensure_uri,
import_object,
path_from_uri,
Expand Down Expand Up @@ -1347,7 +1347,7 @@ def from_uri(
echo=DEFAULT_ECHO,
adapters_by_mimetype=None,
):
uri = str(uri)
uri = ensure_specified_sql_driver(uri)
if init_if_not_exists:
# The alembic stamping can only be does synchronously.
# The cleanest option available is to start a subprocess
Expand All @@ -1366,9 +1366,6 @@ def from_uri(
stderr = process.stderr.decode()
logging.info(f"Subprocess stdout: {stdout}")
logging.error(f"Subprocess stderr: {stderr}")
if not SCHEME_PATTERN.match(uri):
# Interpret URI as filepath.
uri = f"sqlite+aiosqlite:///{uri}"

parsed_url = make_url(uri)
if (parsed_url.get_dialect().name == "sqlite") and (
Expand All @@ -1381,7 +1378,10 @@ def from_uri(
else:
poolclass = None # defer to sqlalchemy default
engine = create_async_engine(
uri, echo=echo, json_serializer=json_serializer, poolclass=poolclass
uri,
echo=echo,
json_serializer=json_serializer,
poolclass=poolclass,
)
if engine.dialect.name == "sqlite":
event.listens_for(engine.sync_engine, "connect")(_set_sqlite_pragma)
Expand Down
9 changes: 6 additions & 3 deletions tiled/commandline/_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ def initialize_database(database_uri: str):
REQUIRED_REVISION,
initialize_database,
)
from ..utils import ensure_specified_sql_driver

async def do_setup():
engine = create_async_engine(database_uri)
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
redacted_url = engine.url._replace(password="[redacted]")
try:
await check_database(engine, REQUIRED_REVISION, ALL_REVISIONS)
Expand Down Expand Up @@ -71,9 +72,10 @@ def upgrade_database(
ALEMBIC_INI_TEMPLATE_PATH,
)
from ..authn_database.core import ALL_REVISIONS
from ..utils import ensure_specified_sql_driver

async def do_setup():
engine = create_async_engine(database_uri)
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
redacted_url = engine.url._replace(password="[redacted]")
current_revision = await get_current_revision(engine, ALL_REVISIONS)
await engine.dispose()
Expand Down Expand Up @@ -107,9 +109,10 @@ def downgrade_database(
ALEMBIC_INI_TEMPLATE_PATH,
)
from ..authn_database.core import ALL_REVISIONS
from ..utils import ensure_specified_sql_driver

async def do_setup():
engine = create_async_engine(database_uri)
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
redacted_url = engine.url._replace(password="[redacted]")
current_revision = await get_current_revision(engine, ALL_REVISIONS)
if current_revision is None:
Expand Down
12 changes: 8 additions & 4 deletions tiled/commandline/_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@ def init(
from ..alembic_utils import UninitializedDatabase, check_database, stamp_head
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import ALL_REVISIONS, REQUIRED_REVISION, initialize_database
from ..utils import SCHEME_PATTERN
from ..utils import ensure_specified_sql_driver

if not SCHEME_PATTERN.match(database):
# Interpret URI as filepath.
database = f"sqlite+aiosqlite:///{database}"
database = ensure_specified_sql_driver(database)

async def do_setup():
engine = create_async_engine(database)
Expand Down Expand Up @@ -94,6 +92,9 @@ def upgrade_database(
from ..alembic_utils import get_current_revision, upgrade
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import ALL_REVISIONS
from ..utils import ensure_specified_sql_driver

database_uri = ensure_specified_sql_driver(database_uri)

async def do_setup():
engine = create_async_engine(database_uri)
Expand Down Expand Up @@ -127,6 +128,9 @@ def downgrade_database(
from ..alembic_utils import downgrade, get_current_revision
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import ALL_REVISIONS
from ..utils import ensure_specified_sql_driver

database_uri = ensure_specified_sql_driver(database_uri)

async def do_setup():
engine = create_async_engine(database_uri)
Expand Down
6 changes: 4 additions & 2 deletions tiled/commandline/_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ def serve_directory(
from ..alembic_utils import stamp_head
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import initialize_database
from ..utils import ensure_specified_sql_driver

engine = create_async_engine(database)
engine = create_async_engine(ensure_specified_sql_driver(database))
asyncio.run(initialize_database(engine))
stamp_head(ALEMBIC_INI_TEMPLATE_PATH, ALEMBIC_DIR, database)

Expand Down Expand Up @@ -389,8 +390,9 @@ def serve_catalog(
from ..alembic_utils import stamp_head
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import initialize_database
from ..utils import ensure_specified_sql_driver

engine = create_async_engine(database)
engine = create_async_engine(ensure_specified_sql_driver(database))
asyncio.run(initialize_database(engine))
stamp_head(ALEMBIC_INI_TEMPLATE_PATH, ALEMBIC_DIR, database)

Expand Down
27 changes: 27 additions & 0 deletions tiled/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,33 @@ def ensure_uri(uri_or_path) -> str:
return str(uri_str)


SCHEME_TO_SCHEME_PLUS_DRIVER = {
"postgresql": "postgresql+asyncpg",
"sqlite": "sqlite+aiosqlite",
}


def ensure_specified_sql_driver(uri: str) -> str:
"""
Given a URI without a driver in the scheme, add Tiled's preferred driver.
If a driver is already specified, the specified one will be used; it
will NOT be overriden by this function.
'postgresql://...' -> 'postgresql+asynpg://...'
'sqlite://...' -> 'sqlite+aiosqlite://...'
'postgresql+asyncpg://...' -> 'postgresql+asynpg://...'
'postgresql+my_custom_driver://...' -> 'postgresql+my_custom_driver://...'
'/path/to/file.db' -> 'sqlite+aiosqlite:////path/to/file.db'
"""
if not SCHEME_PATTERN.match(str(uri)):
# Interpret URI as filepath.
uri = f"sqlite+aiosqlite:///{Path(uri)}"
scheme, rest = uri.split(":", 1)
new_scheme = SCHEME_TO_SCHEME_PLUS_DRIVER.get(scheme, scheme)
return ":".join([new_scheme, rest])


class catch_warning_msg(warnings.catch_warnings):
"""Backward compatible version of catch_warnings for python <3.11.
Expand Down

0 comments on commit e52bab3

Please sign in to comment.