From 741c5b7c25a889443d2225885d8eb14c9d167c18 Mon Sep 17 00:00:00 2001
From: Kevin Boyer
Date: Wed, 8 May 2024 10:16:23 -0400
Subject: [PATCH 1/5] Add bulk_ops

---
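Notes (not part of the commit): a minimal usage sketch under assumed names.
The "opportunity" table, its "opportunity_pkey" constraint, the connection
string, and the Record dataclass are all hypothetical; bulk_upsert only
requires that each object expose the listed attributes and that the named
unique constraint exist on the target table.

    from dataclasses import dataclass

    import psycopg

    from src.db import bulk_ops


    @dataclass
    class Record:  # hypothetical record type
        id: str
        title: str


    records = [Record(id="1", title="first"), Record(id="2", title="second")]

    # Assumes a table opportunity(id TEXT PRIMARY KEY, title TEXT) whose
    # primary key constraint is named "opportunity_pkey".
    with psycopg.connect("postgresql://localhost/app") as conn:
        with conn.cursor() as cur:
            bulk_ops.bulk_upsert(
                cur,
                table="opportunity",
                attributes=["id", "title"],
                objects=records,
                constraint="opportunity_pkey",
            )
        conn.commit()
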
 app/src/db/__init__.py            |   1 +
 app/src/db/bulk_ops.py            | 143 ++++++++++++++++++++++++++++++
 app/tests/src/db/test_bulk_ops.py |  89 +++++++++++++++++++
 3 files changed, 233 insertions(+)
 create mode 100644 app/src/db/bulk_ops.py
 create mode 100644 app/tests/src/db/test_bulk_ops.py

diff --git a/app/src/db/__init__.py b/app/src/db/__init__.py
index e69de29b..0af77203 100644
--- a/app/src/db/__init__.py
+++ b/app/src/db/__init__.py
@@ -0,0 +1 @@
+__all__ = ["bulk_ops"]
diff --git a/app/src/db/bulk_ops.py b/app/src/db/bulk_ops.py
new file mode 100644
index 00000000..df818f45
--- /dev/null
+++ b/app/src/db/bulk_ops.py
@@ -0,0 +1,143 @@
+"""Bulk database operations for performance.
+
+Provides a bulk_upsert function
+"""
+from typing import Any, Sequence
+
+import psycopg
+from psycopg import rows, sql
+
+Connection = psycopg.Connection
+Cursor = psycopg.Cursor
+kwargs_row = rows.kwargs_row
+
+
+def bulk_upsert(
+    cur: psycopg.Cursor,
+    table: str,
+    attributes: Sequence[str],
+    objects: Sequence[Any],
+    constraint: str,
+    update_condition: sql.SQL = sql.SQL(""),
+):
+    """Bulk insert or update a sequence of objects.
+
+    Insert a sequence of objects, or update on conflict.
+    Data is staged in a temporary table, then written to the target table.
+    If there are conflicts due to unique constraints, overwrite existing data.
+
+    Args:
+        cur: the Cursor object from the psycopg library
+        table: the name of the table to insert into or update
+        attributes: a sequence of attribute names to copy from each object
+        objects: a sequence of objects to upsert
+        constraint: the table unique constraint to use to determine conflicts
+        update_condition: optional WHERE clause to limit updates for a
+            conflicting row
+    """
+    temp_table = f"temp_{table}"
+    create_temp_table(cur, temp_table=temp_table, src_table=table)
+    bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
+    write_from_table_to_table(
+        cur,
+        src_table=temp_table,
+        dest_table=table,
+        columns=attributes,
+        constraint=constraint,
+        update_condition=update_condition,
+    )
+
+
+def create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str):
+    """
+    Create a table that lives only for the current transaction.
+    Use an existing table to determine the table structure.
+    Once the transaction is committed the temp table will be deleted.
+
+    Args:
+        temp_table: the name of the temporary table to create
+        src_table: the name of the existing table
+    """
+    cur.execute(
+        sql.SQL(
+            "CREATE TEMP TABLE {temp_table}\
+            (LIKE {src_table})\
+            ON COMMIT DROP"
+        ).format(
+            temp_table=sql.Identifier(temp_table),
+            src_table=sql.Identifier(src_table),
+        )
+    )
+
+
+def bulk_insert(
+    cur: psycopg.Cursor,
+    table: str,
+    columns: Sequence[str],
+    objects: Sequence[Any],
+):
+    """
+    Write data from a sequence of objects to a temp table.
+    This function uses the PostgreSQL COPY command which is highly performant.
+
+    Args:
+        cur: the Cursor object from the psycopg library
+        table: the name of the temporary table
+        columns: a sequence of column names that are attributes of each object
+        objects: a sequence of objects with attributes defined by columns
+    """
+    columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
+    query = sql.SQL("COPY {table}({columns}) FROM STDIN").format(
+        table=sql.Identifier(table),
+        columns=columns_sql,
+    )
+    with cur.copy(query) as copy:
+        for obj in objects:
+            values = [getattr(obj, column) for column in columns]
+            copy.write_row(values)
+
+
+def write_from_table_to_table(
+    cur: psycopg.Cursor,
+    src_table: str,
+    dest_table: str,
+    columns: Sequence[str],
+    constraint: str,
+    update_condition: sql.SQL = sql.SQL(""),
+):
+    """
+    Write data from one table to another.
+    If there are conflicts due to unique constraints, overwrite existing data.
+
+    Args:
+        cur: the Cursor object from the psycopg library
+        src_table: the name of the table that will be copied from
+        dest_table: the name of the table that will be written to
+        columns: a sequence of column names to copy over
+        constraint: the arbiter constraint to use to determine conflicts
+        update_condition: optional WHERE clause to limit updates for a
+            conflicting row
+    """
+    columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
+    update_sql = sql.SQL(",").join(
+        [
+            sql.SQL("{column} = EXCLUDED.{column}").format(
+                column=sql.Identifier(column),
+            )
+            for column in columns
+        ]
+    )
+    query = sql.SQL(
+        "INSERT INTO {dest_table}({columns})\
+        SELECT {columns} FROM {src_table}\
+        ON CONFLICT ON CONSTRAINT {constraint} DO UPDATE SET {update_sql}\
+        {update_condition}"
+    ).format(
+        dest_table=sql.Identifier(dest_table),
+        columns=columns_sql,
+        src_table=sql.Identifier(src_table),
+        constraint=sql.Identifier(constraint),
+        update_sql=update_sql,
+        update_condition=update_condition,
+    )
+    cur.execute(query)
+
+
+__all__ = ["bulk_upsert"]
diff --git a/app/tests/src/db/test_bulk_ops.py b/app/tests/src/db/test_bulk_ops.py
new file mode 100644
index 00000000..57d68263
--- /dev/null
+++ b/app/tests/src/db/test_bulk_ops.py
@@ -0,0 +1,89 @@
+"""Tests for bulk_ops module"""
+import operator
+import random
+from dataclasses import dataclass
+
+from psycopg import rows, sql
+
+import src.adapters.db as db
+from src.db import bulk_ops
+
+
+def test_bulk_upsert(db_session: db.Session):
+    conn = db_session.connection().connection
+    # Override mypy, because SQLAlchemy's DBAPICursor type doesn't specify the row_factory attribute, or that it functions as a context manager
+    with conn.cursor(row_factory=rows.class_row(Number)) as cur:  # type: ignore
+        table = "temp_table"
+        attributes = ["id", "num"]
+        objects = [get_random_number_object() for i in range(100)]
+        constraint = "temp_table_pkey"
+
+        # Create a table for testing bulk upsert
+        cur.execute(
+            sql.SQL(
+                "CREATE TEMP TABLE {table}"
+                "("
+                "id TEXT NOT NULL,"
+                "num INT,"
+                "CONSTRAINT {constraint} PRIMARY KEY (id)"
+                ")"
+            ).format(
+                table=sql.Identifier(table),
+                constraint=sql.Identifier(constraint),
+            )
+        )
+
+        bulk_ops.bulk_upsert(
+            cur,
+            table,
+            attributes,
+            objects,
+            constraint,
+        )
+        conn.commit()
+
+        # Check that all the objects were inserted
+        cur.execute(
+            sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
+                table=sql.Identifier(table)
+            )
+        )
+        records = cur.fetchall()
+        objects.sort(key=operator.attrgetter("id"))
+        assert records == objects
+
+        # Now modify half of the objects
+        for obj in objects[: int(len(objects) / 2)]:
+            obj.num = random.randint(1, 10000)
+
+        bulk_ops.bulk_upsert(
+            cur,
+            table,
+            attributes,
+            objects,
+            constraint,
+        )
+        conn.commit()
+
+        # Check that the objects were updated
+        cur.execute(
+            sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
+                table=sql.Identifier(table)
+            )
+        )
+        records = cur.fetchall()
+        objects.sort(key=operator.attrgetter("id"))
+        assert records == objects
+
+
+@dataclass
+class Number:
+    id: str
+    num: int
+
+
+def get_random_number_object() -> Number:
+    return Number(
+        id=str(random.randint(1000000, 9999999)),
+        num=random.randint(1, 10000),
+    )

From 272da42299256daac7d4319ca5e15de96b5535bf Mon Sep 17 00:00:00 2001
From: Kevin Boyer
Date: Wed, 8 May 2024 10:25:23 -0400
Subject: [PATCH 2/5] Add return type and move function call out of default parameter

---
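Notes (not part of the commit): hoisting sql.SQL("") into EMPTY_SQL is about
definition-time evaluation — Python evaluates a default argument expression
once, when the def statement runs, and every call then shares that one object.
A generic sketch of the behavior (hypothetical functions, not from this repo):

    def make_default() -> list:
        print("evaluated once, at definition time")
        return []


    def append(item, bucket=make_default()):  # default created at import
        bucket.append(item)
        return bucket


    append(1)
    print(append(2))  # [1, 2] -- both calls mutated the shared default

An immutable sql.SQL("") is safe to share, so this reads as a clarity change
rather than a bug fix; patch 3 below replaces the sentinel with None anyway.
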
 app/src/db/bulk_ops.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/app/src/db/bulk_ops.py b/app/src/db/bulk_ops.py
index df818f45..8c6c6e78 100644
--- a/app/src/db/bulk_ops.py
+++ b/app/src/db/bulk_ops.py
@@ -11,6 +11,8 @@
 Cursor = psycopg.Cursor
 kwargs_row = rows.kwargs_row
 
+EMPTY_SQL = sql.SQL("")
+
 
 def bulk_upsert(
     cur: psycopg.Cursor,
@@ -18,8 +20,8 @@ def bulk_upsert(
     attributes: Sequence[str],
     objects: Sequence[Any],
     constraint: str,
-    update_condition: sql.SQL = sql.SQL(""),
-):
+    update_condition: sql.SQL = EMPTY_SQL,
+) -> None:
     """Bulk insert or update a sequence of objects.
 
     Insert a sequence of objects, or update on conflict.
@@ -48,7 +50,7 @@ def bulk_upsert(
     )
 
 
-def create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str):
+def create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> None:
     """
     Create a table that lives only for the current transaction.
     Use an existing table to determine the table structure.
@@ -74,7 +76,7 @@ def bulk_insert(
     table: str,
     columns: Sequence[str],
     objects: Sequence[Any],
-):
+) -> None:
     """
     Write data from a sequence of objects to a temp table.
     This function uses the PostgreSQL COPY command which is highly performant.
@@ -101,8 +103,8 @@ def write_from_table_to_table(
     dest_table: str,
     columns: Sequence[str],
     constraint: str,
-    update_condition: sql.SQL = sql.SQL(""),
-):
+    update_condition: sql.SQL = EMPTY_SQL,
+) -> None:
     """
     Write data from one table to another.
     If there are conflicts due to unique constraints, overwrite existing data.

From 3ddb27dbf4770e92ceed0677639ae8a338f6b0b4 Mon Sep 17 00:00:00 2001
From: Kevin Boyer
Date: Wed, 8 May 2024 10:30:22 -0400
Subject: [PATCH 3/5] Set type of update_condition to Optional

---
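Notes (not part of the commit): a hedged sketch of how a caller might use
update_condition now that it defaults to None. The "updated_at" column and the
cur/records values are hypothetical (reusing the names from the patch 1 note);
the clause is interpolated directly after the DO UPDATE SET list, so it should
be a complete WHERE clause:

    from psycopg import sql

    from src.db import bulk_ops

    # Only overwrite a conflicting row when the incoming copy is newer.
    condition = sql.SQL("WHERE opportunity.updated_at < EXCLUDED.updated_at")
    bulk_ops.bulk_upsert(
        cur,  # an open psycopg cursor, as in the patch 1 sketch
        table="opportunity",
        attributes=["id", "title", "updated_at"],
        objects=records,
        constraint="opportunity_pkey",
        update_condition=condition,
    )
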
 app/src/db/bulk_ops.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/app/src/db/bulk_ops.py b/app/src/db/bulk_ops.py
index 8c6c6e78..1d7bd2a1 100644
--- a/app/src/db/bulk_ops.py
+++ b/app/src/db/bulk_ops.py
@@ -2,7 +2,7 @@
 
 Provides a bulk_upsert function
 """
-from typing import Any, Sequence
+from typing import Any, Optional, Sequence
 
 import psycopg
 from psycopg import rows, sql
@@ -11,8 +11,6 @@
 Cursor = psycopg.Cursor
 kwargs_row = rows.kwargs_row
 
-EMPTY_SQL = sql.SQL("")
-
 
 def bulk_upsert(
     cur: psycopg.Cursor,
@@ -20,7 +18,7 @@ def bulk_upsert(
     attributes: Sequence[str],
     objects: Sequence[Any],
     constraint: str,
-    update_condition: sql.SQL = EMPTY_SQL,
+    update_condition: Optional[sql.SQL] = None,
 ) -> None:
     """Bulk insert or update a sequence of objects.
 
@@ -37,6 +35,9 @@ def bulk_upsert(
         update_condition: optional WHERE clause to limit updates for a
            conflicting row
     """
+    if not update_condition:
+        update_condition = sql.SQL("")
+
     temp_table = f"temp_{table}"
     create_temp_table(cur, temp_table=temp_table, src_table=table)
     bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
@@ -103,8 +104,8 @@ def write_from_table_to_table(
     dest_table: str,
     columns: Sequence[str],
     constraint: str,
-    update_condition: sql.SQL = EMPTY_SQL,
+    update_condition: Optional[sql.SQL] = None,
 ) -> None:
     """
     Write data from one table to another.
     If there are conflicts due to unique constraints, overwrite existing data.
@@ -117,6 +118,9 @@ def write_from_table_to_table(
         update_condition: optional WHERE clause to limit updates for a
            conflicting row
     """
+    if not update_condition:
+        update_condition = sql.SQL("")
+
     columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
     update_sql = sql.SQL(",").join(
         [

From fd15ef902a50f761657f316241d695b766ff7ad9 Mon Sep 17 00:00:00 2001
From: Kevin Boyer
Date: Thu, 6 Jun 2024 13:36:01 -0400
Subject: [PATCH 4/5] Address reviewer comments

---
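Notes (not part of the commit): bulk_ops drives psycopg directly (cur.copy for
COPY FROM STDIN, row_factory for result mapping), which SQLAlchemy sessions do
not expose — hence the new get_raw_connection() escape hatch. A minimal sketch
of what it wraps, assuming a postgresql+psycopg engine URL:

    import sqlalchemy

    engine = sqlalchemy.create_engine("postgresql+psycopg://localhost/app")

    # raw_connection() returns a PoolProxiedConnection: a pooled wrapper whose
    # cursor() is the underlying psycopg cursor, so psycopg-only features
    # (row_factory=..., cur.copy(...)) are available through it.
    conn = engine.raw_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
        conn.commit()
    finally:
        conn.close()  # returns the connection to the pool
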
 .../adapters/db/clients/postgres_client.py |  5 +++
 app/src/db/bulk_ops.py                     | 21 +++++------
 app/tests/src/db/test_bulk_ops.py          | 35 +++++++++++--------
 3 files changed, 36 insertions(+), 25 deletions(-)

diff --git a/app/src/adapters/db/clients/postgres_client.py b/app/src/adapters/db/clients/postgres_client.py
index c70ceefa..332a537a 100644
--- a/app/src/adapters/db/clients/postgres_client.py
+++ b/app/src/adapters/db/clients/postgres_client.py
@@ -74,6 +74,11 @@ def check_db_connection(self) -> None:
         # if check_migrations_current:
         #     have_all_migrations_run(engine)
 
+    def get_raw_connection(self) -> sqlalchemy.PoolProxiedConnection:
+        # For low-level operations not supported by SQLAlchemy.
+        # Unless you specifically need this, you should use get_connection().
+        return self._engine.raw_connection()
+
 
 def get_connection_parameters(db_config: PostgresDBConfig) -> dict[str, Any]:
     connect_args: dict[str, Any] = {}
diff --git a/app/src/db/bulk_ops.py b/app/src/db/bulk_ops.py
index 1d7bd2a1..fa10005d 100644
--- a/app/src/db/bulk_ops.py
+++ b/app/src/db/bulk_ops.py
@@ -1,8 +1,9 @@
 """Bulk database operations for performance.
 
-Provides a bulk_upsert function
+Provides a bulk_upsert function for use with
+Postgres and the psycopg library.
 """
-from typing import Any, Optional, Sequence
+from typing import Any, Sequence
 
 import psycopg
 from psycopg import rows, sql
@@ -18,7 +19,7 @@ def bulk_upsert(
     attributes: Sequence[str],
     objects: Sequence[Any],
     constraint: str,
-    update_condition: Optional[sql.SQL] = None,
+    update_condition: sql.SQL | None = None,
 ) -> None:
     """Bulk insert or update a sequence of objects.
 
@@ -39,9 +40,9 @@ def bulk_upsert(
         update_condition = sql.SQL("")
 
     temp_table = f"temp_{table}"
-    create_temp_table(cur, temp_table=temp_table, src_table=table)
-    bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
-    write_from_table_to_table(
+    _create_temp_table(cur, temp_table=temp_table, src_table=table)
+    _bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
+    _write_from_table_to_table(
         cur,
         src_table=temp_table,
         dest_table=table,
@@ -51,7 +52,7 @@ def bulk_upsert(
     )
 
 
-def create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> None:
+def _create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> None:
     """
     Create a table that lives only for the current transaction.
     Use an existing table to determine the table structure.
@@ -72,7 +73,7 @@ def create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> N
     )
 
 
-def bulk_insert(
+def _bulk_insert(
     cur: psycopg.Cursor,
     table: str,
     columns: Sequence[str],
@@ -98,13 +99,13 @@ def bulk_insert(
             copy.write_row(values)
 
 
-def write_from_table_to_table(
+def _write_from_table_to_table(
     cur: psycopg.Cursor,
     src_table: str,
     dest_table: str,
     columns: Sequence[str],
     constraint: str,
-    update_condition: Optional[sql.SQL] = None,
+    update_condition: sql.SQL | None = None,
 ) -> None:
     """
     Write data from one table to another.
diff --git a/app/tests/src/db/test_bulk_ops.py b/app/tests/src/db/test_bulk_ops.py
index 57d68263..b6ac40e5 100644
--- a/app/tests/src/db/test_bulk_ops.py
+++ b/app/tests/src/db/test_bulk_ops.py
@@ -9,8 +9,23 @@
 from src.db import bulk_ops
 
 
+@dataclass
+class Number:
+    id: str
+    num: int
+
+
+def get_random_number_object() -> Number:
+    return Number(
+        id=str(random.randint(1000000, 9999999)),
+        num=random.randint(1, 10000),
+    )
+
+
 def test_bulk_upsert(db_session: db.Session):
-    conn = db_session.connection().connection
+    db_client = db.PostgresDBClient()
+    conn = db_client.get_raw_connection()
+    # conn = db_session.connection().connection
     # Override mypy, because SQLAlchemy's DBAPICursor type doesn't specify the row_factory attribute, or that it functions as a context manager
     with conn.cursor(row_factory=rows.class_row(Number)) as cur:  # type: ignore
         table = "temp_table"
@@ -56,6 +71,9 @@ def test_bulk_upsert(db_session: db.Session):
         for obj in objects[: int(len(objects) / 2)]:
             obj.num = random.randint(1, 10000)
 
+        # And insert additional objects
+        objects.extend([get_random_number_object() for i in range(50)])
+
         bulk_ops.bulk_upsert(
             cur,
             table,
@@ -65,7 +83,7 @@ def test_bulk_upsert(db_session: db.Session):
         )
         conn.commit()
 
-        # Check that the objects were updated
+        # Check that the existing objects were updated and new objects were inserted
         cur.execute(
             sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
                 table=sql.Identifier(table)
@@ -74,16 +92,3 @@ def test_bulk_upsert(db_session: db.Session):
         records = cur.fetchall()
         objects.sort(key=operator.attrgetter("id"))
         assert records == objects
-
-
-@dataclass
-class Number:
-    id: str
-    num: int
-
-
-def get_random_number_object() -> Number:
-    return Number(
-        id=str(random.randint(1000000, 9999999)),
-        num=random.randint(1, 10000),
-    )

From 87e5e33b305acfbd7557d319c92511813d75fc7c Mon Sep 17 00:00:00 2001
From: Kevin Boyer
Date: Mon, 10 Jun 2024 15:42:44 -0400
Subject: [PATCH 5/5] Update test

---
 app/tests/src/db/test_bulk_ops.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/app/tests/src/db/test_bulk_ops.py b/app/tests/src/db/test_bulk_ops.py
index b6ac40e5..8b72d144 100644
--- a/app/tests/src/db/test_bulk_ops.py
+++ b/app/tests/src/db/test_bulk_ops.py
@@ -25,7 +25,7 @@ def get_random_number_object() -> Number:
 def test_bulk_upsert(db_session: db.Session):
     db_client = db.PostgresDBClient()
     conn = db_client.get_raw_connection()
-    # conn = db_session.connection().connection
+
     # Override mypy, because SQLAlchemy's DBAPICursor type doesn't specify the row_factory attribute, or that it functions as a context manager
     with conn.cursor(row_factory=rows.class_row(Number)) as cur:  # type: ignore
         table = "temp_table"
@@ -68,17 +68,23 @@ def test_bulk_upsert(db_session: db.Session):
         assert records == objects
 
         # Now modify half of the objects
-        for obj in objects[: int(len(objects) / 2)]:
+        updated_indexes = random.sample(range(100), 50)
+        original_objects = [objects[i] for i in range(100) if i not in updated_indexes]
+        updated_objects = [objects[i] for i in updated_indexes]
+        for obj in updated_objects:
             obj.num = random.randint(1, 10000)
 
         # And insert additional objects
-        objects.extend([get_random_number_object() for i in range(50)])
+        inserted_objects = [get_random_number_object() for i in range(50)]
+
+        updated_and_inserted_objects = updated_objects + inserted_objects
+        random.shuffle(updated_and_inserted_objects)
 
         bulk_ops.bulk_upsert(
             cur,
             table,
             attributes,
-            objects,
+            updated_and_inserted_objects,
             constraint,
         )
         conn.commit()
@@ -90,5 +96,6 @@ def test_bulk_upsert(db_session: db.Session):
             )
         )
         records = cur.fetchall()
-        objects.sort(key=operator.attrgetter("id"))
-        assert records == objects
+        expected_objects = original_objects + updated_and_inserted_objects
+        expected_objects.sort(key=operator.attrgetter("id"))
+        assert records == expected_objects