diff --git a/app/src/adapters/db/clients/postgres_client.py b/app/src/adapters/db/clients/postgres_client.py
index c70ceef..332a537 100644
--- a/app/src/adapters/db/clients/postgres_client.py
+++ b/app/src/adapters/db/clients/postgres_client.py
@@ -74,6 +74,11 @@ def check_db_connection(self) -> None:
         # if check_migrations_current:
         #     have_all_migrations_run(engine)
 
+    def get_raw_connection(self) -> sqlalchemy.PoolProxiedConnection:
+        # For low-level operations not supported by SQLAlchemy.
+        # Unless you specifically need this, you should use get_connection().
+        return self._engine.raw_connection()
+
 
 def get_connection_parameters(db_config: PostgresDBConfig) -> dict[str, Any]:
     connect_args: dict[str, Any] = {}
diff --git a/app/src/db/bulk_ops.py b/app/src/db/bulk_ops.py
index 1d7bd2a..fa10005 100644
--- a/app/src/db/bulk_ops.py
+++ b/app/src/db/bulk_ops.py
@@ -1,8 +1,9 @@
 """Bulk database operations for performance.
 
-Provides a bulk_upsert function
+Provides a bulk_upsert function for use with
+Postgres and the psycopg library.
 """
-from typing import Any, Optional, Sequence
+from typing import Any, Sequence
 
 import psycopg
 from psycopg import rows, sql
@@ -18,7 +19,7 @@ def bulk_upsert(
     attributes: Sequence[str],
     objects: Sequence[Any],
     constraint: str,
-    update_condition: Optional[sql.SQL] = None,
+    update_condition: sql.SQL | None = None,
 ) -> None:
     """Bulk insert or update a sequence of objects.
 
@@ -39,9 +40,9 @@ def bulk_upsert(
         update_condition = sql.SQL("")
 
     temp_table = f"temp_{table}"
-    create_temp_table(cur, temp_table=temp_table, src_table=table)
-    bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
-    write_from_table_to_table(
+    _create_temp_table(cur, temp_table=temp_table, src_table=table)
+    _bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
+    _write_from_table_to_table(
         cur,
         src_table=temp_table,
         dest_table=table,
@@ -51,7 +52,7 @@
     )
 
 
-def create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> None:
+def _create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> None:
     """
     Create table that lives only for the current transaction.
     Use an existing table to determine the table structure.
@@ -72,7 +73,7 @@ def create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> N
     )
 
 
-def bulk_insert(
+def _bulk_insert(
     cur: psycopg.Cursor,
     table: str,
     columns: Sequence[str],
@@ -98,13 +99,13 @@ def bulk_insert(
             copy.write_row(values)
 
 
-def write_from_table_to_table(
+def _write_from_table_to_table(
     cur: psycopg.Cursor,
     src_table: str,
     dest_table: str,
     columns: Sequence[str],
     constraint: str,
-    update_condition: Optional[sql.SQL] = None,
+    update_condition: sql.SQL | None = None,
 ) -> None:
     """
     Write data from one table to another.
diff --git a/app/tests/src/db/test_bulk_ops.py b/app/tests/src/db/test_bulk_ops.py
index 57d6826..b6ac40e 100644
--- a/app/tests/src/db/test_bulk_ops.py
+++ b/app/tests/src/db/test_bulk_ops.py
@@ -9,8 +9,23 @@
 from src.db import bulk_ops
 
 
+@dataclass
+class Number:
+    id: str
+    num: int
+
+
+def get_random_number_object() -> Number:
+    return Number(
+        id=str(random.randint(1000000, 9999999)),
+        num=random.randint(1, 10000),
+    )
+
+
 def test_bulk_upsert(db_session: db.Session):
-    conn = db_session.connection().connection
+    db_client = db.PostgresDBClient()
+    conn = db_client.get_raw_connection()
+    # conn = db_session.connection().connection
     # Override mypy, because SQLAlchemy's DBAPICursor type doesn't specify the row_factory attribute, or that it functions as a context manager
     with conn.cursor(row_factory=rows.class_row(Number)) as cur:  # type: ignore
         table = "temp_table"
@@ -56,6 +71,9 @@ def test_bulk_upsert(db_session: db.Session):
         for obj in objects[: int(len(objects) / 2)]:
             obj.num = random.randint(1, 10000)
 
+        # And insert additional objects
+        objects.extend([get_random_number_object() for i in range(50)])
+
         bulk_ops.bulk_upsert(
             cur,
             table,
@@ -65,7 +83,7 @@ def test_bulk_upsert(db_session: db.Session):
         )
         conn.commit()
 
-        # Check that the objects were updated
+        # Check that the existing objects were updated and new objects were inserted
         cur.execute(
             sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
                 table=sql.Identifier(table)
@@ -74,16 +92,3 @@ def test_bulk_upsert(db_session: db.Session):
         records = cur.fetchall()
         objects.sort(key=operator.attrgetter("id"))
         assert records == objects
-
-
-@dataclass
-class Number:
-    id: str
-    num: int
-
-
-def get_random_number_object() -> Number:
-    return Number(
-        id=str(random.randint(1000000, 9999999)),
-        num=random.randint(1, 10000),
-    )