diff --git a/app/src/db/bulk_ops.py b/app/src/db/bulk_ops.py index 8c6c6e7..1d7bd2a 100644 --- a/app/src/db/bulk_ops.py +++ b/app/src/db/bulk_ops.py @@ -2,7 +2,7 @@ Provides a bulk_upsert function """ -from typing import Any, Sequence +from typing import Any, Optional, Sequence import psycopg from psycopg import rows, sql @@ -11,8 +11,6 @@ Cursor = psycopg.Cursor kwargs_row = rows.kwargs_row -EMPTY_SQL = sql.SQL("") - def bulk_upsert( cur: psycopg.Cursor, @@ -20,7 +18,7 @@ def bulk_upsert( attributes: Sequence[str], objects: Sequence[Any], constraint: str, - update_condition: sql.SQL = EMPTY_SQL, + update_condition: Optional[sql.SQL] = None, ) -> None: """Bulk insert or update a sequence of objects. @@ -37,6 +35,9 @@ def bulk_upsert( update_condition: optional WHERE clause to limit updates for a conflicting row """ + if not update_condition: + update_condition = sql.SQL("") + temp_table = f"temp_{table}" create_temp_table(cur, temp_table=temp_table, src_table=table) bulk_insert(cur, table=temp_table, columns=attributes, objects=objects) @@ -103,7 +104,7 @@ def write_from_table_to_table( dest_table: str, columns: Sequence[str], constraint: str, - update_condition: sql.SQL = EMPTY_SQL, + update_condition: Optional[sql.SQL] = None, ) -> None: """ Write data from one table to another. @@ -117,6 +118,9 @@ def write_from_table_to_table( update_condition: optional WHERE clause to limit updates for a conflicting row """ + if not update_condition: + update_condition = sql.SQL("") + columns_sql = sql.SQL(",").join(map(sql.Identifier, columns)) update_sql = sql.SQL(",").join( [