Skip to content

Commit

Permalink
Address reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinJBoyer committed Jun 6, 2024
1 parent 3ddb27d commit fd15ef9
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 25 deletions.
5 changes: 5 additions & 0 deletions app/src/adapters/db/clients/postgres_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def check_db_connection(self) -> None:
# if check_migrations_current:
# have_all_migrations_run(engine)

def get_raw_connection(self) -> sqlalchemy.PoolProxiedConnection:
# For low-level operations not supported by SQLAlchemy.
# Unless you specifically need this, you should use get_connection().
return self._engine.raw_connection()


def get_connection_parameters(db_config: PostgresDBConfig) -> dict[str, Any]:
connect_args: dict[str, Any] = {}
Expand Down
21 changes: 11 additions & 10 deletions app/src/db/bulk_ops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Bulk database operations for performance.
Provides a bulk_upsert function
Provides a bulk_upsert function for use with
Postgres and the psycopg library.
"""
from typing import Any, Optional, Sequence
from typing import Any, Sequence

import psycopg
from psycopg import rows, sql
Expand All @@ -18,7 +19,7 @@ def bulk_upsert(
attributes: Sequence[str],
objects: Sequence[Any],
constraint: str,
update_condition: Optional[sql.SQL] = None,
update_condition: sql.SQL | None = None,
) -> None:
"""Bulk insert or update a sequence of objects.
Expand All @@ -39,9 +40,9 @@ def bulk_upsert(
update_condition = sql.SQL("")

temp_table = f"temp_{table}"
create_temp_table(cur, temp_table=temp_table, src_table=table)
bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
write_from_table_to_table(
_create_temp_table(cur, temp_table=temp_table, src_table=table)
_bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
_write_from_table_to_table(
cur,
src_table=temp_table,
dest_table=table,
Expand All @@ -51,7 +52,7 @@ def bulk_upsert(
)


def create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> None:
def _create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> None:
"""
Create table that lives only for the current transaction.
Use an existing table to determine the table structure.
Expand All @@ -72,7 +73,7 @@ def create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> N
)


def bulk_insert(
def _bulk_insert(
cur: psycopg.Cursor,
table: str,
columns: Sequence[str],
Expand All @@ -98,13 +99,13 @@ def bulk_insert(
copy.write_row(values)


def write_from_table_to_table(
def _write_from_table_to_table(
cur: psycopg.Cursor,
src_table: str,
dest_table: str,
columns: Sequence[str],
constraint: str,
update_condition: Optional[sql.SQL] = None,
update_condition: sql.SQL | None = None,
) -> None:
"""
Write data from one table to another.
Expand Down
35 changes: 20 additions & 15 deletions app/tests/src/db/test_bulk_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,23 @@
from src.db import bulk_ops


@dataclass
class Number:
id: str
num: int


def get_random_number_object() -> Number:
return Number(
id=str(random.randint(1000000, 9999999)),
num=random.randint(1, 10000),
)


def test_bulk_upsert(db_session: db.Session):
conn = db_session.connection().connection
db_client = db.PostgresDBClient()
conn = db_client.get_raw_connection()
# conn = db_session.connection().connection
# Override mypy, because SQLAlchemy's DBAPICursor type doesn't specify the row_factory attribute, or that it functions as a context manager
with conn.cursor(row_factory=rows.class_row(Number)) as cur: # type: ignore
table = "temp_table"
Expand Down Expand Up @@ -56,6 +71,9 @@ def test_bulk_upsert(db_session: db.Session):
for obj in objects[: int(len(objects) / 2)]:
obj.num = random.randint(1, 10000)

# And insert additional objects
objects.extend([get_random_number_object() for i in range(50)])

bulk_ops.bulk_upsert(
cur,
table,
Expand All @@ -65,7 +83,7 @@ def test_bulk_upsert(db_session: db.Session):
)
conn.commit()

# Check that the objects were updated
# Check that the existing objects were updated and new objects were inserted
cur.execute(
sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
table=sql.Identifier(table)
Expand All @@ -74,16 +92,3 @@ def test_bulk_upsert(db_session: db.Session):
records = cur.fetchall()
objects.sort(key=operator.attrgetter("id"))
assert records == objects


@dataclass
class Number:
id: str
num: int


def get_random_number_object() -> Number:
return Number(
id=str(random.randint(1000000, 9999999)),
num=random.randint(1, 10000),
)

0 comments on commit fd15ef9

Please sign in to comment.