Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to SQLAlchemy 2.0 #197

Merged
merged 10 commits into from
Sep 20, 2023
135 changes: 67 additions & 68 deletions app/poetry.lock

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions app/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ authors = ["Nava Engineering <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.10"
SQLAlchemy = {extras = ["mypy"], version = "^1.4.40"}
SQLAlchemy = {extras = ["mypy"], version = "2.0"}
alembic = "^1.8.1"
psycopg2-binary = "^2.9.3"
python-dotenv = "^0.20.0"
Expand Down Expand Up @@ -37,6 +37,7 @@ bandit = "^1.7.4"
pytest = "^6.0.0"
pytest-watch = "^4.2.0"
pytest-lazy-fixture = "^0.6.3"
types-pyyaml = "^6.0.12.11"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down Expand Up @@ -80,8 +81,6 @@ warn_redundant_casts = true
warn_unreachable = true
warn_unused_ignores = true

plugins = ["sqlalchemy.ext.mypy.plugin"]

[tool.bandit]
# Ignore audit logging test file since test audit logging requires a lot of operations that trigger bandit warnings
exclude_dirs = ["./tests/src/logging/test_audit.py"]
Expand Down
3 changes: 0 additions & 3 deletions app/src/adapters/db/clients/postgres_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ def get_conn() -> Any:
return sqlalchemy.create_engine(
"postgresql://",
pool=conn_pool,
# FYI, execute many mode handles how SQLAlchemy handles doing a bunch of inserts/updates/deletes at once
# https://docs.sqlalchemy.org/en/14/dialects/postgresql.html#psycopg2-fast-execution-helpers
executemany_mode="batch",
Comment on lines -49 to -51
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Batch mode no longer exists: https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#psycopg2-fast-execution-helpers

There's another way to configure it to do a similar operation, but I wasn't sure if we wanted to keep it

hide_parameters=db_config.hide_sql_parameter_logs,
# TODO: Don't think we need this as we aren't using JSON columns, but keeping for reference
# json_serializer=lambda o: json.dumps(o, default=pydantic.json.pydantic_encoder),
Expand Down
2 changes: 1 addition & 1 deletion app/src/db/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

def include_object(
object: sqlalchemy.schema.SchemaItem,
name: str,
name: str | None,
type_: str,
reflected: bool,
compare_to: Any,
Expand Down
41 changes: 0 additions & 41 deletions app/src/db/migrations/run.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
# Convenience script for running alembic migration commands through a pyscript
# rather than the command line. This allows poetry to package and alias it for
# running on the production docker image from any directory.
import itertools
import logging
import os
from typing import Optional

import alembic.command as command
import alembic.script as script
import sqlalchemy
from alembic.config import Config
from alembic.operations.ops import MigrationScript
from alembic.runtime import migration

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -53,41 +50,3 @@ def have_all_migrations_run(db_engine: sqlalchemy.engine.Engine) -> None:
logger.info(
f"The current migration head is up to date, {current_heads} and Alembic is expecting {expected_heads}"
)


def check_model_parity() -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't used and was giving me some type errors. There's a better way to do this anyways as Alembic has a utility now for checking the model parity for you which I'll create a ticket and add later.

revisions: list[MigrationScript] = []

def process_revision_directives(
context: migration.MigrationContext,
revision: Optional[str],
directives: list[MigrationScript],
) -> None:
nonlocal revisions
revisions = list(directives)
# Prevent actually generating a migration
directives[:] = []

command.revision(
config=alembic_cfg,
autogenerate=True,
process_revision_directives=process_revision_directives,
)
diff = list(
itertools.chain.from_iterable(
op.as_diffs() for script in revisions for op in script.upgrade_ops_list
)
)

message = (
"The application models are not in sync with the migrations. You should generate "
"a new automigration or update your local migration file. "
"If there are unexpected errors you may need to merge main into your branch."
)

if diff:
for line in diff:
print("::error title=Missing migration::Missing migration:", line)

logger.error(message, extra={"issues": str(diff)})
raise Exception(message)
50 changes: 35 additions & 15 deletions app/src/db/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from typing import Any
from uuid import UUID

from sqlalchemy import TIMESTAMP, Column, MetaData, inspect
from sqlalchemy import TIMESTAMP, MetaData, Text, inspect
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.declarative import as_declarative
from sqlalchemy.orm import declarative_mixin
from sqlalchemy.orm import DeclarativeBase, Mapped, declarative_mixin, mapped_column
from sqlalchemy.sql.functions import now as sqlnow

from src.util import datetime_util
Expand All @@ -26,10 +25,33 @@
)


@as_declarative(metadata=metadata)
class Base:
class Base(DeclarativeBase):
# Attach the metadata to the Base class so all tables automatically get added to the metadata
metadata = metadata

# Override the default type that SQLAlchemy will map python types to.
# This is used if you simply define a column like:
#
# my_column: Mapped[str]
#
# If you provide a mapped_column attribute you can override these values
#
# See: https://docs.sqlalchemy.org/en/20/orm/declarative_tables.html#mapped-column-derives-the-datatype-and-nullability-from-the-mapped-annotation
# for the default mappings
#
# See: https://docs.sqlalchemy.org/en/20/orm/declarative_tables.html#orm-declarative-mapped-column-type-map
# for details on setting up this configuration.
type_annotation_map = {
lorenyu marked this conversation as resolved.
Show resolved Hide resolved
# Explicitly use the Text column type for strings
str: Text,
lorenyu marked this conversation as resolved.
Show resolved Hide resolved
# Always include a timezone for datetimes
datetime: TIMESTAMP(timezone=True),
lorenyu marked this conversation as resolved.
Show resolved Hide resolved
# Always use the Postgres UUID column type
uuid.UUID: postgresql.UUID(as_uuid=True),
}

def _dict(self) -> dict:
return {c.key: getattr(self, c.key) for c in inspect(self).mapper.column_attrs}
return {c.key: getattr(self, c.key) for c in inspect(self).mapper.column_attrs} # type: ignore

def for_json(self) -> dict:
json_valid_dict = {}
Expand All @@ -46,9 +68,9 @@ def for_json(self) -> dict:

def copy(self, **kwargs: dict[str, Any]) -> "Base":
# TODO - Python 3.11 will let us make the return Self instead
table = self.__table__ # type: ignore
table = self.__table__
non_pk_columns = [
k for k in table.columns.keys() if k not in table.primary_key.columns.keys()
k for k in table.columns.keys() if k not in table.primary_key.columns.keys() # type: ignore
]
data = {c: getattr(self, c) for c in non_pk_columns}
data.update(kwargs)
Expand All @@ -59,10 +81,10 @@ def copy(self, **kwargs: dict[str, Any]) -> "Base":
@declarative_mixin
class IdMixin:
"""Mixin to add a UUID id primary key column to a model
https://docs.sqlalchemy.org/en/14/orm/declarative_mixins.html
https://docs.sqlalchemy.org/en/20/orm/declarative_mixins.html
"""

id: uuid.UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)


def same_as_created_at(context: Any) -> Any:
Expand All @@ -72,18 +94,16 @@ def same_as_created_at(context: Any) -> Any:
@declarative_mixin
class TimestampMixin:
"""Mixin to add created_at and updated_at columns to a model
https://docs.sqlalchemy.org/en/14/orm/declarative_mixins.html#mixing-in-columns
https://docs.sqlalchemy.org/en/20/orm/declarative_mixins.html#mixing-in-columns
"""

created_at: datetime = Column(
TIMESTAMP(timezone=True),
created_at: Mapped[datetime] = mapped_column(
nullable=False,
default=datetime_util.utcnow,
server_default=sqlnow(),
)

updated_at: datetime = Column(
TIMESTAMP(timezone=True),
updated_at: Mapped[datetime] = mapped_column(
nullable=False,
default=same_as_created_at,
onupdate=datetime_util.utcnow,
Expand Down
29 changes: 14 additions & 15 deletions app/src/db/models/user_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from typing import Optional
from uuid import UUID

from sqlalchemy import Boolean, Column, Date, Enum, ForeignKey, Text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, relationship
from sqlalchemy import Enum, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship

from src.db.models.base import Base, IdMixin, TimestampMixin

Expand All @@ -21,22 +20,22 @@ class RoleType(str, enum.Enum):
class User(Base, IdMixin, TimestampMixin):
__tablename__ = "user"

first_name: str = Column(Text, nullable=False)
middle_name: Optional[str] = Column(Text)
last_name: str = Column(Text, nullable=False)
phone_number: str = Column(Text, nullable=False)
date_of_birth: date = Column(Date, nullable=False)
is_active: bool = Column(Boolean, nullable=False)
first_name: Mapped[str]
middle_name: Mapped[Optional[str]]
last_name: Mapped[str]
phone_number: Mapped[str]
date_of_birth: Mapped[date]
is_active: Mapped[bool]

roles: list["Role"] = relationship(
roles: Mapped[list["Role"]] = relationship(
"Role", back_populates="user", cascade="all, delete", order_by="Role.type"
)


class Role(Base, TimestampMixin):
__tablename__ = "role"
user_id: Mapped[UUID] = Column(
postgresql.UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
)

# Set native_enum=False to use store enum values as VARCHAR/TEXT
Expand All @@ -48,6 +47,6 @@ class Role(Base, TimestampMixin):
# not yet functional
# (See https://github.com/sqlalchemy/alembic/issues/363)
#
# https://docs.sqlalchemy.org/en/14/core/type_basics.html#sqlalchemy.types.Enum.params.native_enum
type: RoleType = Column(Enum(RoleType, native_enum=False), primary_key=True)
user: User = relationship(User, back_populates="roles")
# https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Enum.params.native_enum
type: Mapped[RoleType] = mapped_column(Enum(RoleType, native_enum=False), primary_key=True)
user: Mapped[User] = relationship(User, back_populates="roles")
2 changes: 1 addition & 1 deletion app/src/services/users/get_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# https://github.com/navapbc/template-application-flask/issues/52
def get_user(db_session: Session, user_id: str) -> User:
# TODO: move this to service and/or persistence layer
result = db_session.query(User).options(orm.selectinload(User.roles)).get(user_id)
result = db_session.get(User, user_id, options=[orm.selectinload(User.roles)])

if result is None:
# TODO move HTTP related logic out of service layer to controller layer and just return None from here
Expand Down
2 changes: 1 addition & 1 deletion app/src/services/users/patch_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def patch_user(

with db_session.begin():
# TODO: move this to service and/or persistence layer
user = db_session.query(User).options(orm.selectinload(User.roles)).get(user_id)
user = db_session.get(User, user_id, options=[orm.selectinload(User.roles)])

if user is None:
# TODO move HTTP related logic out of service layer to controller layer and just return None from here
Expand Down
3 changes: 2 additions & 1 deletion app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def db_client(monkeypatch_session) -> db.DBClient:
"""

with db_testing.create_isolated_db(monkeypatch_session) as db_client:
models.metadata.create_all(bind=db_client.get_connection())
with db_client.get_connection() as conn, conn.begin():
models.metadata.create_all(bind=conn)
yield db_client


Expand Down
19 changes: 14 additions & 5 deletions app/tests/lib/db_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import logging
import uuid

from sqlalchemy import text

import src.adapters.db as db
from src.adapters.db.clients.postgres_config import get_db_config

Expand All @@ -25,21 +27,28 @@ def create_isolated_db(monkeypatch) -> db.DBClient:
db_client = db.PostgresDBClient()
with db_client.get_connection() as conn:
_create_schema(conn, schema_name)
try:
yield db_client
finally:

try:
yield db_client

finally:
with db_client.get_connection() as conn:
lorenyu marked this conversation as resolved.
Show resolved Hide resolved
_drop_schema(conn, schema_name)


def _create_schema(conn: db.Connection, schema_name: str):
"""Create a database schema."""
db_test_user = get_db_config().username

conn.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name} AUTHORIZATION {db_test_user};")
with conn.begin():
conn.execute(
text(f"CREATE SCHEMA IF NOT EXISTS {schema_name} AUTHORIZATION {db_test_user};")
)
logger.info("create schema %s", schema_name)


def _drop_schema(conn: db.Connection, schema_name: str):
"""Drop a database schema."""
conn.execute(f"DROP SCHEMA {schema_name} CASCADE;")
with conn.begin():
conn.execute(text(f"DROP SCHEMA {schema_name} CASCADE;"))
logger.info("drop schema %s", schema_name)
2 changes: 1 addition & 1 deletion app/tests/src/db/models/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_db_session() -> db.Session:

# The scopefunc ensures that the session gets cleaned up after each test
# it implicitly calls `remove()` on the session.
# see https://docs.sqlalchemy.org/en/14/orm/contextual.html
# see https://docs.sqlalchemy.org/en/20/orm/contextual.html
Session = scoped_session(lambda: get_db_session(), scopefunc=lambda: get_db_session())


Expand Down