Add bulk operations utilities (#224)
1 parent cbcc8d8, commit 0f5619c
Showing 4 changed files with 258 additions and 0 deletions.
@@ -0,0 +1 @@
__all__ = ["bulk_ops"]
@@ -0,0 +1,151 @@
"""Bulk database operations for performance. | ||
Provides a bulk_upsert function for use with | ||
Postgres and the psycopg library. | ||
""" | ||
from typing import Any, Sequence | ||
|
||
import psycopg | ||
from psycopg import rows, sql | ||
|
||
Connection = psycopg.Connection | ||
Cursor = psycopg.Cursor | ||
kwargs_row = rows.kwargs_row | ||
|
||
|
||
def bulk_upsert(
    cur: psycopg.Cursor,
    table: str,
    attributes: Sequence[str],
    objects: Sequence[Any],
    constraint: str,
    update_condition: sql.SQL | None = None,
) -> None:
    """Bulk insert or update a sequence of objects.

    Insert a sequence of objects, updating rows that conflict on the
    named unique constraint with the incoming values.

    Args:
        cur: the Cursor object from the psycopg library
        table: the name of the table to insert into or update
        attributes: a sequence of attribute names to copy from each object
        objects: a sequence of objects to upsert
        constraint: the table unique constraint to use to determine conflicts
        update_condition: optional WHERE clause to limit updates for a
            conflicting row
    """
    if not update_condition:
        update_condition = sql.SQL("")

    # Upsert strategy: stage the rows in a transaction-scoped temp table via
    # COPY, then merge them into the target table with INSERT ... ON CONFLICT.
    temp_table = f"temp_{table}"
    _create_temp_table(cur, temp_table=temp_table, src_table=table)
    _bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
    _write_from_table_to_table(
        cur,
        src_table=temp_table,
        dest_table=table,
        columns=attributes,
        constraint=constraint,
        update_condition=update_condition,
    )


def _create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> None:
    """Create a table that lives only for the current transaction.

    Use an existing table to determine the table structure.
    Once the transaction is committed, the temp table is dropped.

    Args:
        cur: the Cursor object from the psycopg library
        temp_table: the name of the temporary table to create
        src_table: the name of the existing table
    """
    cur.execute(
        sql.SQL(
            "CREATE TEMP TABLE {temp_table} "
            "(LIKE {src_table}) "
            "ON COMMIT DROP"
        ).format(
            temp_table=sql.Identifier(temp_table),
            src_table=sql.Identifier(src_table),
        )
    )


def _bulk_insert(
    cur: psycopg.Cursor,
    table: str,
    columns: Sequence[str],
    objects: Sequence[Any],
) -> None:
    """Write data from a sequence of objects to a temp table.

    This function uses the PostgreSQL COPY command, which is highly performant.

    Args:
        cur: the Cursor object from the psycopg library
        table: the name of the temporary table
        columns: a sequence of column names that are attributes of each object
        objects: a sequence of objects with attributes defined by columns
    """
    columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
    query = sql.SQL("COPY {table}({columns}) FROM STDIN").format(
        table=sql.Identifier(table),
        columns=columns_sql,
    )
    with cur.copy(query) as copy:
        for obj in objects:
            values = [getattr(obj, column) for column in columns]
            copy.write_row(values)


def _write_from_table_to_table(
    cur: psycopg.Cursor,
    src_table: str,
    dest_table: str,
    columns: Sequence[str],
    constraint: str,
    update_condition: sql.SQL | None = None,
) -> None:
    """Write data from one table to another.

    If there are conflicts due to unique constraints, overwrite existing data.

    Args:
        cur: the Cursor object from the psycopg library
        src_table: the name of the table that will be copied from
        dest_table: the name of the table that will be written to
        columns: a sequence of column names to copy over
        constraint: the arbiter constraint to use to determine conflicts
        update_condition: optional WHERE clause to limit updates for a
            conflicting row
    """
    if not update_condition:
        update_condition = sql.SQL("")

    columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
    # Build the SET list, skipping the key-like columns "id" and "number"
    update_sql = sql.SQL(",").join(
        [
            sql.SQL("{column} = EXCLUDED.{column}").format(
                column=sql.Identifier(column),
            )
            for column in columns
            if column not in ["id", "number"]
        ]
    )
    query = sql.SQL(
        "INSERT INTO {dest_table}({columns}) "
        "SELECT {columns} FROM {src_table} "
        "ON CONFLICT ON CONSTRAINT {constraint} DO UPDATE SET {update_sql} "
        "{update_condition}"
    ).format(
        dest_table=sql.Identifier(dest_table),
        columns=columns_sql,
        src_table=sql.Identifier(src_table),
        constraint=sql.Identifier(constraint),
        update_sql=update_sql,
        update_condition=update_condition,
    )
    cur.execute(query)

__all__ = ["bulk_upsert"]
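
For reference, a minimal usage sketch of bulk_upsert. The User dataclass, the "users" table, the "users_pkey" constraint, and the connection string are hypothetical, not part of this commit:

# Hypothetical example: a "users" table with columns (id, name) and a
# primary-key constraint named "users_pkey".
from dataclasses import dataclass

import psycopg

from src.db import bulk_ops


@dataclass
class User:
    id: str
    name: str


users = [User(id="1", name="Ada"), User(id="2", name="Grace")]

with psycopg.connect("dbname=example") as conn:
    with conn.cursor() as cur:
        # New ids are inserted; ids that conflict on users_pkey are
        # overwritten with the incoming values.
        bulk_ops.bulk_upsert(
            cur,
            table="users",
            attributes=["id", "name"],
            objects=users,
            constraint="users_pkey",
        )
    conn.commit()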
@@ -0,0 +1,101 @@
"""Tests for bulk_ops module""" | ||
import operator | ||
import random | ||
from dataclasses import dataclass | ||
|
||
from psycopg import rows, sql | ||
|
||
import src.adapters.db as db | ||
from src.db import bulk_ops | ||
|
||
|
||
@dataclass | ||
class Number: | ||
id: str | ||
num: int | ||
|
||
|
||
def get_random_number_object() -> Number: | ||
return Number( | ||
id=str(random.randint(1000000, 9999999)), | ||
num=random.randint(1, 10000), | ||
) | ||
|
||
|
||

def test_bulk_upsert(db_session: db.Session):
    db_client = db.PostgresDBClient()
    conn = db_client.get_raw_connection()

    # Override mypy, because SQLAlchemy's DBAPICursor type doesn't specify
    # the row_factory attribute, or that it functions as a context manager
    with conn.cursor(row_factory=rows.class_row(Number)) as cur:  # type: ignore
        table = "temp_table"
        attributes = ["id", "num"]
        objects = [get_random_number_object() for _ in range(100)]
        constraint = "temp_table_pkey"

        # Create a table for testing bulk upsert
        cur.execute(
            sql.SQL(
                "CREATE TEMP TABLE {table}"
                "("
                "id TEXT NOT NULL,"
                "num INT,"
                "CONSTRAINT {constraint} PRIMARY KEY (id)"
                ")"
            ).format(
                table=sql.Identifier(table),
                constraint=sql.Identifier(constraint),
            )
        )

        bulk_ops.bulk_upsert(
            cur,
            table,
            attributes,
            objects,
            constraint,
        )
        conn.commit()

        # Check that all the objects were inserted
        cur.execute(
            sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
                table=sql.Identifier(table)
            )
        )
        records = cur.fetchall()
        objects.sort(key=operator.attrgetter("id"))
        assert records == objects

        # Now modify half of the objects
        updated_indexes = random.sample(range(100), 50)
        original_objects = [objects[i] for i in range(100) if i not in updated_indexes]
        updated_objects = [objects[i] for i in updated_indexes]
        for obj in updated_objects:
            obj.num = random.randint(1, 10000)

        # And insert additional objects
        inserted_objects = [get_random_number_object() for _ in range(50)]

        # Shuffle the combined list in place (shuffling the temporary
        # concatenation would have no effect)
        updated_and_inserted_objects = updated_objects + inserted_objects
        random.shuffle(updated_and_inserted_objects)

        bulk_ops.bulk_upsert(
            cur,
            table,
            attributes,
            updated_and_inserted_objects,
            constraint,
        )
        conn.commit()

        # Check that the existing objects were updated and new objects were inserted
        cur.execute(
            sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
                table=sql.Identifier(table)
            )
        )
        records = cur.fetchall()
        expected_objects = original_objects + updated_and_inserted_objects
        expected_objects.sort(key=operator.attrgetter("id"))
        assert records == expected_objects
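
The test exercises only the unconditional overwrite path; update_condition can further restrict which conflicting rows are rewritten. A sketch reusing the test's fixtures, not exercised by the test above, that updates a conflicting row only when the incoming num is larger:

# Hypothetical sketch reusing cur/table/attributes/objects/constraint from
# the test above; only conflicting rows whose stored num is smaller than
# the incoming num would be updated.
bulk_ops.bulk_upsert(
    cur,
    table,
    attributes,
    objects,
    constraint,
    update_condition=sql.SQL("WHERE {table}.num < EXCLUDED.num").format(
        table=sql.Identifier(table)
    ),
)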