Add bulk operations utilities (#224)
KevinJBoyer authored Jun 10, 2024
1 parent cbcc8d8 commit 0f5619c
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 0 deletions.
5 changes: 5 additions & 0 deletions app/src/adapters/db/clients/postgres_client.py
@@ -74,6 +74,11 @@ def check_db_connection(self) -> None:
# if check_migrations_current:
# have_all_migrations_run(engine)

def get_raw_connection(self) -> sqlalchemy.PoolProxiedConnection:
# For low-level operations not supported by SQLAlchemy.
# Unless you specifically need this, you should use get_connection().
return self._engine.raw_connection()


def get_connection_parameters(db_config: PostgresDBConfig) -> dict[str, Any]:
connect_args: dict[str, Any] = {}
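For context, a minimal sketch of how get_raw_connection() is meant to be used (not part of this diff; the client setup follows the tests below):

import src.adapters.db as db

db_client = db.PostgresDBClient()
conn = db_client.get_raw_connection()  # underlying psycopg connection from the pool
with conn.cursor() as cur:
    cur.execute("SELECT 1")  # low-level operations such as COPY go here
conn.commit()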
1 change: 1 addition & 0 deletions app/src/db/__init__.py
@@ -0,0 +1 @@
__all__ = ["bulk_ops"]
151 changes: 151 additions & 0 deletions app/src/db/bulk_ops.py
@@ -0,0 +1,151 @@
"""Bulk database operations for performance.
Provides a bulk_upsert function for use with
Postgres and the psycopg library.
"""
from typing import Any, Sequence

import psycopg
from psycopg import rows, sql

Connection = psycopg.Connection
Cursor = psycopg.Cursor
kwargs_row = rows.kwargs_row


def bulk_upsert(
cur: psycopg.Cursor,
table: str,
attributes: Sequence[str],
objects: Sequence[Any],
constraint: str,
update_condition: sql.SQL | None = None,
) -> None:
"""Bulk insert or update a sequence of objects.
Insert a sequence of objects, or update on conflict.
Write data from one table to another.
If there are conflicts due to unique constraints, overwrite existing data.
Args:
cur: the Cursor object from the pyscopg library
table: the name of the table to insert into or update
attributes: a sequence of attribute names to copy from each object
objects: a sequence of objects to upsert
constraint: the table unique constraint to use to determine conflicts
update_condition: optional WHERE clause to limit updates for a
conflicting row
"""
if not update_condition:
update_condition = sql.SQL("")

temp_table = f"temp_{table}"
_create_temp_table(cur, temp_table=temp_table, src_table=table)
_bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
_write_from_table_to_table(
cur,
src_table=temp_table,
dest_table=table,
columns=attributes,
constraint=constraint,
update_condition=update_condition,
)


def _create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> None:
"""
Create table that lives only for the current transaction.
Use an existing table to determine the table structure.
Once the transaction is committed the temp table will be deleted.
Args:
temp_table: the name of the temporary table to create
src_table: the name of the existing table
"""
    cur.execute(
        sql.SQL(
            "CREATE TEMP TABLE {temp_table} "
            "(LIKE {src_table}) "
            "ON COMMIT DROP"
        ).format(
            temp_table=sql.Identifier(temp_table),
            src_table=sql.Identifier(src_table),
        )
    )


def _bulk_insert(
cur: psycopg.Cursor,
table: str,
columns: Sequence[str],
objects: Sequence[Any],
) -> None:
"""
Write data from a sequence of objects to a temp table.
This function uses the PostgreSQL COPY command which is highly performant.
Args:
cur: the Cursor object from the pyscopg library
table: the name of the temporary table
columns: a sequence of column names that are attributes of each object
objects: a sequence of objects with attributes defined by columns
"""
columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
query = sql.SQL("COPY {table}({columns}) FROM STDIN").format(
table=sql.Identifier(table),
columns=columns_sql,
)
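    # Stream the rows through COPY; psycopg adapts each Python value to the
    # corresponding column type.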
with cur.copy(query) as copy:
for obj in objects:
values = [getattr(obj, column) for column in columns]
copy.write_row(values)


def _write_from_table_to_table(
cur: psycopg.Cursor,
src_table: str,
dest_table: str,
columns: Sequence[str],
constraint: str,
update_condition: sql.SQL | None = None,
) -> None:
"""
Write data from one table to another.
If there are conflicts due to unique constraints, overwrite existing data.
Args:
cur: the Cursor object from the pyscopg library
src_table: the name of the table that will be copied from
dest_table: the name of the table that will be written to
columns: a sequence of column names to copy over
constraint: the arbiter constraint to use to determine conflicts
update_condition: optional WHERE clause to limit updates for a
conflicting row
"""
if not update_condition:
update_condition = sql.SQL("")

columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
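    # Build "column = EXCLUDED.column" assignments for the conflict update;
    # "id" and "number" are treated as identifying columns here and excluded
    # so an update never rewrites a row's key.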
update_sql = sql.SQL(",").join(
[
sql.SQL("{column} = EXCLUDED.{column}").format(
column=sql.Identifier(column),
)
for column in columns
if column not in ["id", "number"]
]
)
    query = sql.SQL(
        "INSERT INTO {dest_table}({columns}) "
        "SELECT {columns} FROM {src_table} "
        "ON CONFLICT ON CONSTRAINT {constraint} DO UPDATE SET {update_sql} "
        "{update_condition}"
    ).format(
).format(
dest_table=sql.Identifier(dest_table),
columns=columns_sql,
src_table=sql.Identifier(src_table),
constraint=sql.Identifier(constraint),
update_sql=update_sql,
update_condition=update_condition,
)
cur.execute(query)


__all__ = ["bulk_upsert"]
101 changes: 101 additions & 0 deletions app/tests/src/db/test_bulk_ops.py
@@ -0,0 +1,101 @@
"""Tests for bulk_ops module"""
import operator
import random
from dataclasses import dataclass

from psycopg import rows, sql

import src.adapters.db as db
from src.db import bulk_ops


@dataclass
class Number:
id: str
num: int


def get_random_number_object() -> Number:
return Number(
id=str(random.randint(1000000, 9999999)),
num=random.randint(1, 10000),
)


def test_bulk_upsert(db_session: db.Session):
db_client = db.PostgresDBClient()
conn = db_client.get_raw_connection()

    # Ignore mypy here: SQLAlchemy's DBAPICursor type doesn't declare the
    # row_factory parameter, nor that the cursor can be used as a context
    # manager.
with conn.cursor(row_factory=rows.class_row(Number)) as cur: # type: ignore
table = "temp_table"
attributes = ["id", "num"]
objects = [get_random_number_object() for i in range(100)]
constraint = "temp_table_pkey"

# Create a table for testing bulk upsert
cur.execute(
sql.SQL(
"CREATE TEMP TABLE {table}"
"("
"id TEXT NOT NULL,"
"num INT,"
"CONSTRAINT {constraint} PRIMARY KEY (id)"
")"
).format(
table=sql.Identifier(table),
constraint=sql.Identifier(constraint),
)
)

bulk_ops.bulk_upsert(
cur,
table,
attributes,
objects,
constraint,
)
conn.commit()

# Check that all the objects were inserted
cur.execute(
sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
table=sql.Identifier(table)
)
)
records = cur.fetchall()
objects.sort(key=operator.attrgetter("id"))
assert records == objects

# Now modify half of the objects
updated_indexes = random.sample(range(100), 50)
original_objects = [objects[i] for i in range(100) if i not in updated_indexes]
updated_objects = [objects[i] for i in updated_indexes]
for obj in updated_objects:
obj.num = random.randint(1, 10000)

# And insert additional objects
inserted_objects = [get_random_number_object() for i in range(50)]

updated_and_inserted_objects = updated_objects + inserted_objects
    random.shuffle(updated_and_inserted_objects)

bulk_ops.bulk_upsert(
cur,
table,
attributes,
updated_and_inserted_objects,
constraint,
)
conn.commit()

# Check that the existing objects were updated and new objects were inserted
cur.execute(
sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
table=sql.Identifier(table)
)
)
records = cur.fetchall()
expected_objects = original_objects + updated_and_inserted_objects
expected_objects.sort(key=operator.attrgetter("id"))
assert records == expected_objects
