
Add bulk operations utilities #224

Merged · 5 commits · Jun 10, 2024
5 changes: 5 additions & 0 deletions app/src/adapters/db/clients/postgres_client.py
@@ -74,6 +74,11 @@ def check_db_connection(self) -> None:
# if check_migrations_current:
# have_all_migrations_run(engine)

def get_raw_connection(self) -> sqlalchemy.PoolProxiedConnection:
# For low-level operations not supported by SQLAlchemy.
# Unless you specifically need this, you should use get_connection().
return self._engine.raw_connection()


def get_connection_parameters(db_config: PostgresDBConfig) -> dict[str, Any]:
connect_args: dict[str, Any] = {}
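A minimal usage sketch of get_raw_connection(), assuming the client is configured against a psycopg-backed engine; the query is illustrative, and the client/import names follow the test file below:

import src.adapters.db as db

db_client = db.PostgresDBClient()
conn = db_client.get_raw_connection()  # DBAPI connection, psycopg underneath
with conn.cursor() as cur:
    cur.execute("SELECT 1")
conn.commit()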
1 change: 1 addition & 0 deletions app/src/db/__init__.py
@@ -0,0 +1 @@
__all__ = ["bulk_ops"]
151 changes: 151 additions & 0 deletions app/src/db/bulk_ops.py
@@ -0,0 +1,151 @@
"""Bulk database operations for performance.

Provides a bulk_upsert function for use with
Postgres and the psycopg library.
"""
from typing import Any, Sequence

import psycopg
from psycopg import rows, sql

Connection = psycopg.Connection
Cursor = psycopg.Cursor
kwargs_row = rows.kwargs_row


def bulk_upsert(
cur: psycopg.Cursor,
table: str,
attributes: Sequence[str],
objects: Sequence[Any],
constraint: str,
update_condition: sql.SQL | None = None,
) -> None:
"""Bulk insert or update a sequence of objects.

    Insert a sequence of objects; rows that conflict on the given unique
    constraint are updated instead. The data is staged in a temporary
    table and then written to the destination table.

    Args:
        cur: the Cursor object from the psycopg library
table: the name of the table to insert into or update
attributes: a sequence of attribute names to copy from each object
objects: a sequence of objects to upsert
constraint: the table unique constraint to use to determine conflicts
update_condition: optional WHERE clause to limit updates for a
conflicting row
"""
if not update_condition:
update_condition = sql.SQL("")

temp_table = f"temp_{table}"
_create_temp_table(cur, temp_table=temp_table, src_table=table)
_bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
_write_from_table_to_table(
cur,
src_table=temp_table,
dest_table=table,
columns=attributes,
constraint=constraint,
update_condition=update_condition,
)


def _create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> None:
"""
Create table that lives only for the current transaction.
Use an existing table to determine the table structure.
Once the transaction is committed the temp table will be deleted.
Args:
temp_table: the name of the temporary table to create
src_table: the name of the existing table
"""
cur.execute(
sql.SQL(
"CREATE TEMP TABLE {temp_table}\
(LIKE {src_table})\
ON COMMIT DROP"
).format(
temp_table=sql.Identifier(temp_table),
src_table=sql.Identifier(src_table),
)
)


def _bulk_insert(
cur: psycopg.Cursor,
table: str,
columns: Sequence[str],
objects: Sequence[Any],
) -> None:
"""
Write data from a sequence of objects to a temp table.
This function uses the PostgreSQL COPY command which is highly performant.
Args:
cur: the Cursor object from the pyscopg library
table: the name of the temporary table
columns: a sequence of column names that are attributes of each object
objects: a sequence of objects with attributes defined by columns
"""
columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
query = sql.SQL("COPY {table}({columns}) FROM STDIN").format(
table=sql.Identifier(table),
columns=columns_sql,
)
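    # copy.write_row streams each row over the COPY protocol; psycopg
    # adapts each Python value to the corresponding Postgres column type.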
with cur.copy(query) as copy:
for obj in objects:
values = [getattr(obj, column) for column in columns]
copy.write_row(values)


def _write_from_table_to_table(
cur: psycopg.Cursor,
src_table: str,
dest_table: str,
columns: Sequence[str],
constraint: str,
update_condition: sql.SQL | None = None,
) -> None:
"""
Write data from one table to another.
If there are conflicts due to unique constraints, overwrite existing data.
Args:
cur: the Cursor object from the pyscopg library
src_table: the name of the table that will be copied from
dest_table: the name of the table that will be written to
columns: a sequence of column names to copy over
constraint: the arbiter constraint to use to determine conflicts
update_condition: optional WHERE clause to limit updates for a
conflicting row
"""
if not update_condition:
update_condition = sql.SQL("")

columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
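    # Build the SET clause for conflicting rows; "id" and "number" are
    # assumed here to be identifying columns, so they are never
    # overwritten on conflict.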
update_sql = sql.SQL(",").join(
[
sql.SQL("{column} = EXCLUDED.{column}").format(
column=sql.Identifier(column),
)
for column in columns
if column not in ["id", "number"]
]
)
query = sql.SQL(
"INSERT INTO {dest_table}({columns})\
SELECT {columns} FROM {src_table}\
ON CONFLICT ON CONSTRAINT {constraint} DO UPDATE SET {update_sql}\
{update_condition}"
).format(
dest_table=sql.Identifier(dest_table),
columns=columns_sql,
src_table=sql.Identifier(src_table),
constraint=sql.Identifier(constraint),
update_sql=update_sql,
update_condition=update_condition,
)
cur.execute(query)


__all__ = ["bulk_upsert"]
101 changes: 101 additions & 0 deletions app/tests/src/db/test_bulk_ops.py
@@ -0,0 +1,101 @@
"""Tests for bulk_ops module"""
import operator
import random
from dataclasses import dataclass

from psycopg import rows, sql

import src.adapters.db as db
from src.db import bulk_ops


@dataclass
class Number:
id: str
num: int


def get_random_number_object() -> Number:
return Number(
id=str(random.randint(1000000, 9999999)),
num=random.randint(1, 10000),
)


def test_bulk_upsert(db_session: db.Session):
db_client = db.PostgresDBClient()
conn = db_client.get_raw_connection()

# Override mypy, because SQLAlchemy's DBAPICursor type doesn't specify the row_factory attribute, or that it functions as a context manager
with conn.cursor(row_factory=rows.class_row(Number)) as cur: # type: ignore
table = "temp_table"
attributes = ["id", "num"]
        objects = [get_random_number_object() for _ in range(100)]
constraint = "temp_table_pkey"

# Create a table for testing bulk upsert
cur.execute(
sql.SQL(
"CREATE TEMP TABLE {table}"
"("
"id TEXT NOT NULL,"
"num INT,"
"CONSTRAINT {constraint} PRIMARY KEY (id)"
")"
).format(
table=sql.Identifier(table),
constraint=sql.Identifier(constraint),
)
)

bulk_ops.bulk_upsert(
cur,
table,
attributes,
objects,
constraint,
)
conn.commit()

# Check that all the objects were inserted
cur.execute(
sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
table=sql.Identifier(table)
)
)
records = cur.fetchall()
objects.sort(key=operator.attrgetter("id"))
assert records == objects

# Now modify half of the objects
updated_indexes = random.sample(range(100), 50)
original_objects = [objects[i] for i in range(100) if i not in updated_indexes]
updated_objects = [objects[i] for i in updated_indexes]
for obj in updated_objects:
obj.num = random.randint(1, 10000)

# And insert additional objects
        inserted_objects = [get_random_number_object() for _ in range(50)]

        updated_and_inserted_objects = updated_objects + inserted_objects
        random.shuffle(updated_and_inserted_objects)

bulk_ops.bulk_upsert(
cur,
table,
attributes,
updated_and_inserted_objects,
constraint,
)
conn.commit()

# Check that the existing objects were updated and new objects were inserted
cur.execute(
sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
table=sql.Identifier(table)
)
)
records = cur.fetchall()
expected_objects = original_objects + updated_and_inserted_objects
expected_objects.sort(key=operator.attrgetter("id"))
assert records == expected_objects