-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5730b46
commit 741c5b7
Showing
3 changed files
with
233 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Names exported by a star-import of this package (presumably the package
# __init__ that contains the bulk_ops module -- confirm against the repo layout).
__all__ = ["bulk_ops"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
"""Bulk database operations for performance. | ||
Provides a bulk_upsert function | ||
""" | ||
from typing import Any, Sequence | ||
|
||
import psycopg | ||
from psycopg import rows, sql | ||
|
||
# Convenience aliases for the psycopg connection/cursor classes and the
# kwargs_row row factory. None of the functions visible in this module use
# them directly; presumably they are imported by callers -- confirm.
Connection = psycopg.Connection
Cursor = psycopg.Cursor
kwargs_row = rows.kwargs_row
|
||
|
||
def bulk_upsert(
    cur: psycopg.Cursor,
    table: str,
    attributes: Sequence[str],
    objects: Sequence[Any],
    constraint: str,
    update_condition: sql.SQL = sql.SQL(""),
):
    """Bulk insert a sequence of objects, updating rows that already exist.

    The objects are first copied into a temporary staging table named
    ``temp_<table>`` and then merged into ``table`` with a single
    ``INSERT ... ON CONFLICT ... DO UPDATE`` statement. Rows that collide on
    ``constraint`` are overwritten with the staged values.

    Args:
        cur: the Cursor object from the psycopg library
        table: the name of the table to insert into or update
        attributes: a sequence of attribute names to copy from each object;
            each name is used both as the column name and as the attribute
            read from the object
        objects: a sequence of objects to upsert
        constraint: the table unique constraint to use to determine conflicts
        update_condition: optional WHERE clause to limit updates for a
            conflicting row
    """
    temp_table = f"temp_{table}"
    create_temp_table(cur, temp_table=temp_table, src_table=table)
    bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
    write_from_table_to_table(
        cur,
        src_table=temp_table,
        dest_table=table,
        columns=attributes,
        constraint=constraint,
        update_condition=update_condition,
    )
    # The staging table is otherwise only removed at COMMIT. Dropping it now
    # lets bulk_upsert be called again for the same table inside the same
    # transaction without CREATE TEMP TABLE failing because it already exists.
    # pg_temp pins the DROP to the session's temporary schema.
    cur.execute(
        sql.SQL("DROP TABLE {temp_table}").format(
            temp_table=sql.Identifier("pg_temp", temp_table),
        )
    )
|
||
|
||
def create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str):
    """Create a temporary table that lives only for the current transaction.

    The structure is taken from an existing table via ``LIKE``. Because no
    ``INCLUDING`` option is given, PostgreSQL copies the column names, data
    types and NOT NULL constraints only. Once the transaction is committed
    the temp table is dropped (``ON COMMIT DROP``).

    Args:
        cur: the Cursor object from the psycopg library
        temp_table: the name of the temporary table to create
        src_table: the name of the existing table
    """
    # Adjacent string literals with explicit single spaces keep the statement
    # text independent of how this source file happens to be indented
    # (a backslash continuation inside a literal embeds the next line's
    # leading whitespace, or none at all, into the SQL).
    query = sql.SQL(
        "CREATE TEMP TABLE {temp_table} "
        "(LIKE {src_table}) "
        "ON COMMIT DROP"
    ).format(
        temp_table=sql.Identifier(temp_table),
        src_table=sql.Identifier(src_table),
    )
    cur.execute(query)
|
||
|
||
def bulk_insert(
    cur: psycopg.Cursor,
    table: str,
    columns: Sequence[str],
    objects: Sequence[Any],
):
    """Stream a sequence of objects into a table with PostgreSQL COPY.

    One row is written per object; the row's values are the object's
    attributes named by ``columns``, in that order.

    Args:
        cur: the Cursor object from the psycopg library
        table: the name of the (temporary) table to copy into
        columns: a sequence of column names that are attributes of each object
        objects: a sequence of objects with attributes defined by columns
    """
    quoted_columns = [sql.Identifier(name) for name in columns]
    copy_statement = sql.SQL("COPY {table}({columns}) FROM STDIN").format(
        table=sql.Identifier(table),
        columns=sql.SQL(",").join(quoted_columns),
    )
    with cur.copy(copy_statement) as copy:
        for item in objects:
            row = []
            for name in columns:
                row.append(getattr(item, name))
            copy.write_row(row)
|
||
|
||
def write_from_table_to_table(
    cur: psycopg.Cursor,
    src_table: str,
    dest_table: str,
    columns: Sequence[str],
    constraint: str,
    update_condition: sql.SQL = sql.SQL(""),
):
    """Write data from one table to another, overwriting on conflict.

    Runs a single ``INSERT INTO dest SELECT ... FROM src`` statement. When a
    row conflicts on ``constraint``, every column in ``columns`` of the
    existing row is set to the incoming (``EXCLUDED``) value.

    Args:
        cur: the Cursor object from the psycopg library
        src_table: the name of the table that will be copied from
        dest_table: the name of the table that will be written to
        columns: a sequence of column names to copy over
        constraint: the arbiter constraint to use to determine conflicts
        update_condition: optional WHERE clause to limit updates for a
            conflicting row
    """
    columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
    update_sql = sql.SQL(",").join(
        sql.SQL("{column} = EXCLUDED.{column}").format(
            column=sql.Identifier(column),
        )
        for column in columns
    )
    # Adjacent string literals with explicit single spaces keep the statement
    # text independent of source indentation (a backslash continuation inside
    # a literal embeds the next line's leading whitespace into the SQL).
    query = sql.SQL(
        "INSERT INTO {dest_table}({columns}) "
        "SELECT {columns} FROM {src_table} "
        "ON CONFLICT ON CONSTRAINT {constraint} DO UPDATE SET {update_sql} "
        "{update_condition}"
    ).format(
        dest_table=sql.Identifier(dest_table),
        columns=columns_sql,
        src_table=sql.Identifier(src_table),
        constraint=sql.Identifier(constraint),
        update_sql=update_sql,
        update_condition=update_condition,
    )
    cur.execute(query)
|
||
|
||
# Public API of this module: only bulk_upsert is exported by a star-import;
# the other functions defined above are left out of the export list.
__all__ = ["bulk_upsert"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
"""Tests for bulk_ops module""" | ||
import operator | ||
import random | ||
from dataclasses import dataclass | ||
|
||
from psycopg import rows, sql | ||
|
||
import src.adapters.db as db | ||
from src.db import bulk_ops | ||
|
||
|
||
def test_bulk_upsert(db_session: db.Session):
    """Check bulk_upsert end to end: first as a pure insert, then as an update."""
    conn = db_session.connection().connection
    # Override mypy, because SQLAlchemy's DBAPICursor type doesn't specify the row_factory attribute, or that it functions as a context manager
    with conn.cursor(row_factory=rows.class_row(Number)) as cur:  # type: ignore
        table = "temp_table"
        attributes = ["id", "num"]
        constraint = "temp_table_pkey"

        # Ids are drawn at random, so two objects can share one. Two staged
        # rows with the same key make ON CONFLICT DO UPDATE fail (a statement
        # may not affect the same row twice), so keep generating until there
        # are 100 distinct ids.
        objects_by_id = {}
        while len(objects_by_id) < 100:
            candidate = get_random_number_object()
            objects_by_id[candidate.id] = candidate
        objects = list(objects_by_id.values())

        # Create a table for testing bulk upsert
        cur.execute(
            sql.SQL(
                "CREATE TEMP TABLE {table}"
                "("
                "id TEXT NOT NULL,"
                "num INT,"
                "CONSTRAINT {constraint} PRIMARY KEY (id)"
                ")"
            ).format(
                table=sql.Identifier(table),
                constraint=sql.Identifier(constraint),
            )
        )

        # The same read-back query is used after each upsert
        select_all = sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
            table=sql.Identifier(table)
        )

        bulk_ops.bulk_upsert(
            cur,
            table,
            attributes,
            objects,
            constraint,
        )
        conn.commit()

        # Check that all the objects were inserted
        cur.execute(select_all)
        records = cur.fetchall()
        objects.sort(key=operator.attrgetter("id"))
        assert records == objects

        # Now modify half of the objects
        for obj in objects[: len(objects) // 2]:
            obj.num = random.randint(1, 10000)

        bulk_ops.bulk_upsert(
            cur,
            table,
            attributes,
            objects,
            constraint,
        )
        conn.commit()

        # Check that the objects were updated
        cur.execute(select_all)
        records = cur.fetchall()
        objects.sort(key=operator.attrgetter("id"))
        assert records == objects
|
||
|
||
@dataclass
class Number:
    """Row object for the test table; also used as the cursor's row class."""

    # Text primary key of the test table
    id: str
    # Integer payload that the test changes between upserts
    num: int
|
||
|
||
def get_random_number_object() -> Number:
    """Build a Number with a random 7-digit string id and a random num."""
    # The id is drawn before num, matching the keyword order of the constructor call
    identifier = str(random.randint(1000000, 9999999))
    value = random.randint(1, 10000)
    return Number(id=identifier, num=value)