diff --git a/app/src/adapters/db/clients/postgres_client.py b/app/src/adapters/db/clients/postgres_client.py
index c70ceef..332a537 100644
--- a/app/src/adapters/db/clients/postgres_client.py
+++ b/app/src/adapters/db/clients/postgres_client.py
@@ -74,6 +74,11 @@ def check_db_connection(self) -> None:
         # if check_migrations_current:
         #     have_all_migrations_run(engine)
 
+    def get_raw_connection(self) -> sqlalchemy.PoolProxiedConnection:
+        # For low-level operations not supported by SQLAlchemy.
+        # Unless you specifically need this, you should use get_connection().
+        return self._engine.raw_connection()
+
 
 def get_connection_parameters(db_config: PostgresDBConfig) -> dict[str, Any]:
     connect_args: dict[str, Any] = {}
diff --git a/app/src/db/__init__.py b/app/src/db/__init__.py
index e69de29..0af7720 100644
--- a/app/src/db/__init__.py
+++ b/app/src/db/__init__.py
@@ -0,0 +1 @@
+__all__ = ["bulk_ops"]
diff --git a/app/src/db/bulk_ops.py b/app/src/db/bulk_ops.py
new file mode 100644
index 0000000..fa10005
--- /dev/null
+++ b/app/src/db/bulk_ops.py
@@ -0,0 +1,151 @@
+"""Bulk database operations for performance.
+
+Provides a bulk_upsert function for use with
+Postgres and the psycopg library.
+"""
+from typing import Any, Sequence
+
+import psycopg
+from psycopg import rows, sql
+
+Connection = psycopg.Connection
+Cursor = psycopg.Cursor
+kwargs_row = rows.kwargs_row
+
+
+def bulk_upsert(
+    cur: psycopg.Cursor,
+    table: str,
+    attributes: Sequence[str],
+    objects: Sequence[Any],
+    constraint: str,
+    update_condition: sql.SQL | None = None,
+) -> None:
+    """Bulk insert or update a sequence of objects.
+
+    Insert a sequence of objects in bulk, or update rows on conflict.
+    Data is first copied into a temporary table, then written to the target table.
+    If there are conflicts due to unique constraints, overwrite existing data.
+
+    Args:
+        cur: the Cursor object from the psycopg library
+        table: the name of the table to insert into or update
+        attributes: a sequence of attribute names to copy from each object
+        objects: a sequence of objects to upsert
+        constraint: the table unique constraint to use to determine conflicts
+        update_condition: optional WHERE clause to limit updates for a
+            conflicting row
+    """
+    if not update_condition:
+        update_condition = sql.SQL("")
+
+    temp_table = f"temp_{table}"
+    _create_temp_table(cur, temp_table=temp_table, src_table=table)
+    _bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
+    _write_from_table_to_table(
+        cur,
+        src_table=temp_table,
+        dest_table=table,
+        columns=attributes,
+        constraint=constraint,
+        update_condition=update_condition,
+    )
+
+
+def _create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> None:
+    """
+    Create a table that lives only for the current transaction.
+    Use an existing table to determine the table structure.
+    Once the transaction is committed, the temp table will be deleted.
+    Args:
+        temp_table: the name of the temporary table to create
+        src_table: the name of the existing table
+    """
+    cur.execute(
+        sql.SQL(
+            "CREATE TEMP TABLE {temp_table}\
+            (LIKE {src_table})\
+            ON COMMIT DROP"
+        ).format(
+            temp_table=sql.Identifier(temp_table),
+            src_table=sql.Identifier(src_table),
+        )
+    )
+
+
+def _bulk_insert(
+    cur: psycopg.Cursor,
+    table: str,
+    columns: Sequence[str],
+    objects: Sequence[Any],
+) -> None:
+    """
+    Write data from a sequence of objects to a temp table.
+    This function uses the PostgreSQL COPY command, which is highly performant.
+    Args:
+        cur: the Cursor object from the psycopg library
+        table: the name of the temporary table
+        columns: a sequence of column names that are attributes of each object
+        objects: a sequence of objects with attributes defined by columns
+    """
+    columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
+    query = sql.SQL("COPY {table}({columns}) FROM STDIN").format(
+        table=sql.Identifier(table),
+        columns=columns_sql,
+    )
+    with cur.copy(query) as copy:
+        for obj in objects:
+            values = [getattr(obj, column) for column in columns]
+            copy.write_row(values)
+
+
+def _write_from_table_to_table(
+    cur: psycopg.Cursor,
+    src_table: str,
+    dest_table: str,
+    columns: Sequence[str],
+    constraint: str,
+    update_condition: sql.SQL | None = None,
+) -> None:
+    """
+    Write data from one table to another.
+    If there are conflicts due to unique constraints, overwrite existing data.
+    Args:
+        cur: the Cursor object from the psycopg library
+        src_table: the name of the table that will be copied from
+        dest_table: the name of the table that will be written to
+        columns: a sequence of column names to copy over
+        constraint: the arbiter constraint to use to determine conflicts
+        update_condition: optional WHERE clause to limit updates for a
+            conflicting row
+    """
+    if not update_condition:
+        update_condition = sql.SQL("")
+
+    columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
+    update_sql = sql.SQL(",").join(
+        [
+            sql.SQL("{column} = EXCLUDED.{column}").format(
+                column=sql.Identifier(column),
+            )
+            for column in columns
+            if column not in ["id", "number"]
+        ]
+    )
+    query = sql.SQL(
+        "INSERT INTO {dest_table}({columns})\
+        SELECT {columns} FROM {src_table}\
+        ON CONFLICT ON CONSTRAINT {constraint} DO UPDATE SET {update_sql}\
+        {update_condition}"
+    ).format(
+        dest_table=sql.Identifier(dest_table),
+        columns=columns_sql,
+        src_table=sql.Identifier(src_table),
+        constraint=sql.Identifier(constraint),
+        update_sql=update_sql,
+        update_condition=update_condition,
+    )
+    cur.execute(query)
+
+
+__all__ = ["bulk_upsert"]
diff --git a/app/tests/src/db/test_bulk_ops.py b/app/tests/src/db/test_bulk_ops.py
new file mode 100644
index 0000000..8b72d14
--- /dev/null
+++ b/app/tests/src/db/test_bulk_ops.py
@@ -0,0 +1,101 @@
+"""Tests for bulk_ops module"""
+import operator
+import random
+from dataclasses import dataclass
+
+from psycopg import rows, sql
+
+import src.adapters.db as db
+from src.db import bulk_ops
+
+
+@dataclass
+class Number:
+    id: str
+    num: int
+
+
+def get_random_number_object() -> Number:
+    return Number(
+        id=str(random.randint(1000000, 9999999)),
+        num=random.randint(1, 10000),
+    )
+
+
+def test_bulk_upsert(db_session: db.Session):
+    db_client = db.PostgresDBClient()
+    conn = db_client.get_raw_connection()
+
+    # Override mypy, because SQLAlchemy's DBAPICursor type doesn't specify the row_factory attribute, or that it functions as a context manager
+    with conn.cursor(row_factory=rows.class_row(Number)) as cur:  # type: ignore
+        table = "temp_table"
+        attributes = ["id", "num"]
+        objects = [get_random_number_object() for i in range(100)]
+        constraint = "temp_table_pkey"
+
+        # Create a table for testing bulk upsert
+        cur.execute(
+            sql.SQL(
+                "CREATE TEMP TABLE {table}"
+                "("
+                "id TEXT NOT NULL,"
+                "num INT,"
+                "CONSTRAINT {constraint} PRIMARY KEY (id)"
+                ")"
+            ).format(
+                table=sql.Identifier(table),
+                constraint=sql.Identifier(constraint),
+            )
+        )
+
+        bulk_ops.bulk_upsert(
+            cur,
+            table,
+            attributes,
+            objects,
+            constraint,
+        )
+        conn.commit()
+
+        # Check that all the objects were inserted
+        cur.execute(
+            sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
+                table=sql.Identifier(table)
+            )
+        )
+        records = cur.fetchall()
+        objects.sort(key=operator.attrgetter("id"))
+        assert records == objects
+
+        # Now modify half of the objects
+        updated_indexes = random.sample(range(100), 50)
+        original_objects = [objects[i] for i in range(100) if i not in updated_indexes]
+        updated_objects = [objects[i] for i in updated_indexes]
+        for obj in updated_objects:
+            obj.num = random.randint(1, 10000)
+
+        # And insert additional objects
+        inserted_objects = [get_random_number_object() for i in range(50)]
+
+        updated_and_inserted_objects = updated_objects + inserted_objects
+        random.shuffle(updated_and_inserted_objects)
+
+        bulk_ops.bulk_upsert(
+            cur,
+            table,
+            attributes,
+            updated_and_inserted_objects,
+            constraint,
+        )
+        conn.commit()
+
+        # Check that the existing objects were updated and new objects were inserted
+        cur.execute(
+            sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
+                table=sql.Identifier(table)
+            )
+        )
+        records = cur.fetchall()
+        expected_objects = original_objects + updated_and_inserted_objects
+        expected_objects.sort(key=operator.attrgetter("id"))
+        assert records == expected_objects
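
Usage sketch (illustrative commentary, not part of the diff): the snippet below shows how the new PostgresDBClient.get_raw_connection() method and bulk_ops.bulk_upsert() are meant to fit together, mirroring the pattern in test_bulk_ops.py above. The Person dataclass, "person" table, and "person_pkey" constraint are hypothetical placeholders; the target table and its unique constraint must already exist.

    from dataclasses import dataclass

    import src.adapters.db as db
    from src.db import bulk_ops

    @dataclass
    class Person:  # hypothetical: any object whose attributes match the column names works
        id: str
        name: str

    people = [Person(id="1", name="Ada"), Person(id="2", name="Grace")]

    db_client = db.PostgresDBClient()
    conn = db_client.get_raw_connection()  # raw psycopg connection, for operations SQLAlchemy doesn't cover
    with conn.cursor() as cur:
        bulk_ops.bulk_upsert(
            cur,
            table="person",             # hypothetical table name
            attributes=["id", "name"],  # attributes read from each object / columns written
            objects=people,
            constraint="person_pkey",   # unique constraint used as the ON CONFLICT arbiter
        )
    conn.commit()  # COPY and upsert run in the current transaction; the temp table is dropped on commit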