Add bulk_ops
KevinJBoyer committed May 8, 2024
1 parent 5730b46 commit 741c5b7
Showing 3 changed files with 233 additions and 0 deletions.
1 change: 1 addition & 0 deletions app/src/db/__init__.py
@@ -0,0 +1 @@
__all__ = ["bulk_ops"]
143 changes: 143 additions & 0 deletions app/src/db/bulk_ops.py
@@ -0,0 +1,143 @@
"""Bulk database operations for performance.
Provides a bulk_upsert function
"""
from typing import Any, Sequence

import psycopg
from psycopg import rows, sql

Connection = psycopg.Connection
Cursor = psycopg.Cursor
kwargs_row = rows.kwargs_row

# Module-level default for optional SQL clauses. Defining this once avoids
# calling sql.SQL("") in an argument default (flake8 B008: the call would run
# once at function definition time and be shared across calls, which is fine
# for an immutable sql.SQL but trips the linter).
_EMPTY_SQL = sql.SQL("")


def bulk_upsert(
    cur: psycopg.Cursor,
    table: str,
    attributes: Sequence[str],
    objects: Sequence[Any],
    constraint: str,
    update_condition: sql.SQL = _EMPTY_SQL,
) -> None:
"""Bulk insert or update a sequence of objects.
Insert a sequence of objects, or update on conflict.
Write data from one table to another.
If there are conflicts due to unique constraints, overwrite existing data.
Args:
cur: the Cursor object from the pyscopg library
table: the name of the table to insert into or update
attributes: a sequence of attribute names to copy from each object
objects: a sequence of objects to upsert
constraint: the table unique constraint to use to determine conflicts
update_condition: optional WHERE clause to limit updates for a
conflicting row
"""
    temp_table = f"temp_{table}"
    create_temp_table(cur, temp_table=temp_table, src_table=table)
    bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
    write_from_table_to_table(
        cur,
        src_table=temp_table,
        dest_table=table,
        columns=attributes,
        constraint=constraint,
        update_condition=update_condition,
    )


def create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> None:
    """Create a table that lives only for the current transaction.

    Uses an existing table to determine the table structure. Once the
    transaction is committed, the temp table is dropped.

    Args:
        cur: the Cursor object from the psycopg library
        temp_table: the name of the temporary table to create
        src_table: the name of the existing table
    """
    cur.execute(
        sql.SQL(
            "CREATE TEMP TABLE {temp_table} (LIKE {src_table}) ON COMMIT DROP"
        ).format(
            temp_table=sql.Identifier(temp_table),
            src_table=sql.Identifier(src_table),
        )
    )


def bulk_insert(
    cur: psycopg.Cursor,
    table: str,
    columns: Sequence[str],
    objects: Sequence[Any],
) -> None:
    """Write data from a sequence of objects to a temp table.

    Uses the PostgreSQL COPY command, which is highly performant.

    Args:
        cur: the Cursor object from the psycopg library
        table: the name of the temporary table
        columns: a sequence of column names that are attributes of each object
        objects: a sequence of objects with attributes defined by columns
    """
    columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
    query = sql.SQL("COPY {table}({columns}) FROM STDIN").format(
        table=sql.Identifier(table),
        columns=columns_sql,
    )
    with cur.copy(query) as copy:
        for obj in objects:
            values = [getattr(obj, column) for column in columns]
            copy.write_row(values)


def write_from_table_to_table(
    cur: psycopg.Cursor,
    src_table: str,
    dest_table: str,
    columns: Sequence[str],
    constraint: str,
    update_condition: sql.SQL = _EMPTY_SQL,
) -> None:
"""
Write data from one table to another.
If there are conflicts due to unique constraints, overwrite existing data.
Args:
cur: the Cursor object from the pyscopg library
src_table: the name of the table that will be copied from
dest_table: the name of the table that will be written to
columns: a sequence of column names to copy over
constraint: the arbiter constraint to use to determine conflicts
update_condition: optional WHERE clause to limit updates for a
conflicting row
"""
    columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
    update_sql = sql.SQL(",").join(
        [
            sql.SQL("{column} = EXCLUDED.{column}").format(
                column=sql.Identifier(column),
            )
            for column in columns
        ]
    )
    query = sql.SQL(
        "INSERT INTO {dest_table}({columns})"
        " SELECT {columns} FROM {src_table}"
        " ON CONFLICT ON CONSTRAINT {constraint} DO UPDATE SET {update_sql}"
        " {update_condition}"
    ).format(
        dest_table=sql.Identifier(dest_table),
        columns=columns_sql,
        src_table=sql.Identifier(src_table),
        constraint=sql.Identifier(constraint),
        update_sql=update_sql,
        update_condition=update_condition,
    )
    cur.execute(query)


__all__ = ["bulk_upsert"]
89 changes: 89 additions & 0 deletions app/tests/src/db/test_bulk_ops.py
@@ -0,0 +1,89 @@
"""Tests for bulk_ops module"""
import operator
import random
from dataclasses import dataclass

from psycopg import rows, sql

import src.adapters.db as db
from src.db import bulk_ops


def test_bulk_upsert(db_session: db.Session):
    conn = db_session.connection().connection
    # Override mypy, because SQLAlchemy's DBAPICursor type doesn't specify the
    # row_factory attribute or that it functions as a context manager.
    with conn.cursor(row_factory=rows.class_row(Number)) as cur:  # type: ignore
        table = "temp_table"
        attributes = ["id", "num"]
        objects = [get_random_number_object() for _ in range(100)]
        constraint = "temp_table_pkey"

        # Create a table for testing bulk upsert
        cur.execute(
            sql.SQL(
                "CREATE TEMP TABLE {table}"
                "("
                "id TEXT NOT NULL,"
                "num INT,"
                "CONSTRAINT {constraint} PRIMARY KEY (id)"
                ")"
            ).format(
                table=sql.Identifier(table),
                constraint=sql.Identifier(constraint),
            )
        )

        bulk_ops.bulk_upsert(
            cur,
            table,
            attributes,
            objects,
            constraint,
        )
        conn.commit()

        # Check that all the objects were inserted
        cur.execute(
            sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
                table=sql.Identifier(table)
            )
        )
        records = cur.fetchall()
        objects.sort(key=operator.attrgetter("id"))
        assert records == objects

        # Now modify half of the objects
        for obj in objects[: len(objects) // 2]:
            obj.num = random.randint(1, 10000)

        bulk_ops.bulk_upsert(
            cur,
            table,
            attributes,
            objects,
            constraint,
        )
        conn.commit()

        # Check that the objects were updated
        cur.execute(
            sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
                table=sql.Identifier(table)
            )
        )
        records = cur.fetchall()
        objects.sort(key=operator.attrgetter("id"))
        assert records == objects


@dataclass
class Number:
    id: str
    num: int


def get_random_number_object() -> Number:
    return Number(
        id=str(random.randint(1000000, 9999999)),
        num=random.randint(1, 10000),
    )
