From 741c5b7c25a889443d2225885d8eb14c9d167c18 Mon Sep 17 00:00:00 2001
From: Kevin Boyer
Date: Wed, 8 May 2024 10:16:23 -0400
Subject: [PATCH] Add bulk_ops

---
 app/src/db/__init__.py            |   1 +
 app/src/db/bulk_ops.py            | 143 ++++++++++++++++++++++++++++++
 app/tests/src/db/test_bulk_ops.py |  89 +++++++++++++++++++
 3 files changed, 233 insertions(+)
 create mode 100644 app/src/db/bulk_ops.py
 create mode 100644 app/tests/src/db/test_bulk_ops.py

diff --git a/app/src/db/__init__.py b/app/src/db/__init__.py
index e69de29b..0af77203 100644
--- a/app/src/db/__init__.py
+++ b/app/src/db/__init__.py
@@ -0,0 +1 @@
+__all__ = ["bulk_ops"]
diff --git a/app/src/db/bulk_ops.py b/app/src/db/bulk_ops.py
new file mode 100644
index 00000000..df818f45
--- /dev/null
+++ b/app/src/db/bulk_ops.py
@@ -0,0 +1,143 @@
+"""Bulk database operations for performance.
+
+Provides a bulk_upsert function.
+"""
+from typing import Any, Sequence
+
+import psycopg
+from psycopg import rows, sql
+
+Connection = psycopg.Connection
+Cursor = psycopg.Cursor
+kwargs_row = rows.kwargs_row
+
+
+def bulk_upsert(
+    cur: psycopg.Cursor,
+    table: str,
+    attributes: Sequence[str],
+    objects: Sequence[Any],
+    constraint: str,
+    update_condition: sql.SQL = sql.SQL(""),
+):
+    """Bulk insert or update a sequence of objects.
+
+    Insert a sequence of objects into a table, staging them in a temporary
+    table with COPY first. If an object conflicts with an existing row on
+    the given unique constraint, the existing row is overwritten instead.
+
+    Args:
+        cur: the Cursor object from the psycopg library
+        table: the name of the table to insert into or update
+        attributes: a sequence of attribute names to copy from each object
+        objects: a sequence of objects to upsert
+        constraint: the table unique constraint to use to determine conflicts
+        update_condition: optional WHERE clause to limit updates for a
+            conflicting row
+    """
+    temp_table = f"temp_{table}"
+    create_temp_table(cur, temp_table=temp_table, src_table=table)
+    bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
+    write_from_table_to_table(
+        cur,
+        src_table=temp_table,
+        dest_table=table,
+        columns=attributes,
+        constraint=constraint,
+        update_condition=update_condition,
+    )
+
+
+def create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str):
+    """
+    Create a table that lives only for the current transaction.
+    Use an existing table to determine the table structure.
+    Once the transaction is committed the temp table will be deleted.
+
+    Args:
+        temp_table: the name of the temporary table to create
+        src_table: the name of the existing table
+    """
+    cur.execute(
+        sql.SQL(
+            "CREATE TEMP TABLE {temp_table}\
+            (LIKE {src_table})\
+            ON COMMIT DROP"
+        ).format(
+            temp_table=sql.Identifier(temp_table),
+            src_table=sql.Identifier(src_table),
+        )
+    )
+
+
+def bulk_insert(
+    cur: psycopg.Cursor,
+    table: str,
+    columns: Sequence[str],
+    objects: Sequence[Any],
+):
+    """
+    Write data from a sequence of objects to a temp table.
+    This function uses the PostgreSQL COPY command, which is highly performant.
+
+    Args:
+        cur: the Cursor object from the psycopg library
+        table: the name of the temporary table
+        columns: a sequence of column names that are attributes of each object
+        objects: a sequence of objects with attributes defined by columns
+    """
+    columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
+    query = sql.SQL("COPY {table}({columns}) FROM STDIN").format(
+        table=sql.Identifier(table),
+        columns=columns_sql,
+    )
+    with cur.copy(query) as copy:
+        for obj in objects:
+            values = [getattr(obj, column) for column in columns]
+            copy.write_row(values)
+
+
+def write_from_table_to_table(
+    cur: psycopg.Cursor,
+    src_table: str,
+    dest_table: str,
+    columns: Sequence[str],
+    constraint: str,
+    update_condition: sql.SQL = sql.SQL(""),
+):
+    """
+    Write data from one table to another.
+    If there are conflicts due to unique constraints, overwrite existing data.
+
+    Args:
+        cur: the Cursor object from the psycopg library
+        src_table: the name of the table that will be copied from
+        dest_table: the name of the table that will be written to
+        columns: a sequence of column names to copy over
+        constraint: the arbiter constraint to use to determine conflicts
+        update_condition: optional WHERE clause to limit updates for a
+            conflicting row
+    """
+    columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
+    update_sql = sql.SQL(",").join(
+        [
+            sql.SQL("{column} = EXCLUDED.{column}").format(
+                column=sql.Identifier(column),
+            )
+            for column in columns
+        ]
+    )
+    query = sql.SQL(
+        "INSERT INTO {dest_table}({columns})\
+        SELECT {columns} FROM {src_table}\
+        ON CONFLICT ON CONSTRAINT {constraint} DO UPDATE SET {update_sql}\
+        {update_condition}"
+    ).format(
+        dest_table=sql.Identifier(dest_table),
+        columns=columns_sql,
+        src_table=sql.Identifier(src_table),
+        constraint=sql.Identifier(constraint),
+        update_sql=update_sql,
+        update_condition=update_condition,
+    )
+    cur.execute(query)
+
+
+__all__ = ["bulk_upsert"]
diff --git a/app/tests/src/db/test_bulk_ops.py b/app/tests/src/db/test_bulk_ops.py
new file mode 100644
index 00000000..57d68263
--- /dev/null
+++ b/app/tests/src/db/test_bulk_ops.py
@@ -0,0 +1,89 @@
+"""Tests for the bulk_ops module."""
+import operator
+import random
+from dataclasses import dataclass
+
+from psycopg import rows, sql
+
+import src.adapters.db as db
+from src.db import bulk_ops
+
+
+def test_bulk_upsert(db_session: db.Session):
+    conn = db_session.connection().connection
+    # Ignore mypy here: SQLAlchemy's DBAPICursor type doesn't declare the
+    # row_factory attribute or that the cursor works as a context manager
+    with conn.cursor(row_factory=rows.class_row(Number)) as cur:  # type: ignore
+        table = "temp_table"
+        attributes = ["id", "num"]
+        objects = [get_random_number_object() for _ in range(100)]
+        constraint = "temp_table_pkey"
+
+        # Create a table for testing bulk upsert
+        cur.execute(
+            sql.SQL(
+                "CREATE TEMP TABLE {table}"
+                "("
+                "id TEXT NOT NULL,"
+                "num INT,"
+                "CONSTRAINT {constraint} PRIMARY KEY (id)"
+                ")"
+            ).format(
+                table=sql.Identifier(table),
+                constraint=sql.Identifier(constraint),
+            )
+        )
+
+        bulk_ops.bulk_upsert(
+            cur,
+            table,
+            attributes,
+            objects,
+            constraint,
+        )
+        conn.commit()
+
+        # Check that all the objects were inserted
+        cur.execute(
+            sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
+                table=sql.Identifier(table)
+            )
+        )
+        records = cur.fetchall()
+        objects.sort(key=operator.attrgetter("id"))
+        assert records == objects
+
+        # Now modify half of the objects
+        for obj in objects[: int(len(objects) / 2)]:
+            obj.num = random.randint(1, 10000)
+
+        bulk_ops.bulk_upsert(
+            cur,
+            table,
+            attributes,
+            objects,
+            constraint,
+        )
+        conn.commit()
+
+        # Check that the objects were updated
+        cur.execute(
+            sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
+                table=sql.Identifier(table)
+            )
+        )
+        records = cur.fetchall()
+        objects.sort(key=operator.attrgetter("id"))
+        assert records == objects
+
+
+@dataclass
+class Number:
+    id: str
+    num: int
+
+
+def get_random_number_object() -> Number:
+    return Number(
+        id=str(random.randint(1000000, 9999999)),
+        num=random.randint(1, 10000),
+    )
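
Usage sketch (not part of the patch): a minimal example of how a caller might drive bulk_ops.bulk_upsert. The opportunity table, its opportunity_pkey constraint, the Opportunity dataclass, and the connection string are illustrative assumptions; only bulk_ops.bulk_upsert and its signature come from the change above. Because the staging table is created with ON COMMIT DROP, the COPY and the INSERT ... ON CONFLICT run in one transaction that the caller commits afterwards.

    from dataclasses import dataclass

    import psycopg

    from src.db import bulk_ops


    # Hypothetical record type mirroring the destination table's columns.
    @dataclass
    class Opportunity:
        id: str
        title: str


    # Hypothetical conninfo; the real app wires connections through src.adapters.db.
    with psycopg.connect("dbname=app") as conn:
        with conn.cursor() as cur:
            records = [
                Opportunity(id="opp-1", title="First"),
                Opportunity(id="opp-2", title="Second"),
            ]
            bulk_ops.bulk_upsert(
                cur,
                table="opportunity",            # destination table, assumed to exist
                attributes=["id", "title"],     # staged via COPY into temp_opportunity
                objects=records,
                constraint="opportunity_pkey",  # arbiter constraint for ON CONFLICT
            )
        conn.commit()  # also drops the ON COMMIT DROP staging table

To update only some conflicting rows, a caller can pass a sql.SQL fragment such as sql.SQL("WHERE opportunity.title IS DISTINCT FROM EXCLUDED.title") as update_condition; write_from_table_to_table appends it after the DO UPDATE SET clause.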