diff --git a/lms/services/upsert.py b/lms/services/upsert.py index 8db15925e6..c02754e229 100644 --- a/lms/services/upsert.py +++ b/lms/services/upsert.py @@ -10,7 +10,7 @@ def bulk_upsert( model_class, values: list[dict], index_elements: list[str], - update_columns: list[str], + update_columns: list[str | tuple], ): """ Create or update the specified values in a table. @@ -50,7 +50,16 @@ def bulk_upsert( # The columns to use to find matching rows. index_elements=index_elements, # The columns to update. - set_={element: getattr(base.excluded, element) for element in update_columns}, + set_={ + # For tuples include the two elements as the key and value of the dict + # For strings use value: excluded.value by default + (element[0] if isinstance(element, tuple) else element): ( + element[1] + if isinstance(element, tuple) + else getattr(base.excluded, element) + ) + for element in update_columns + }, ).returning(*index_elements_columns) result = db.execute(stmt) diff --git a/lms/services/user.py b/lms/services/user.py index 53e57dd632..6a7dbca286 100644 --- a/lms/services/user.py +++ b/lms/services/user.py @@ -1,6 +1,6 @@ from functools import lru_cache -from sqlalchemy import select +from sqlalchemy import func, select, text from sqlalchemy.exc import NoResultFound from sqlalchemy.sql import Select @@ -89,7 +89,18 @@ def upsert_lms_user(self, user: User, lti_params: LTIParams) -> LMSUser: } ], index_elements=["h_userid"], - update_columns=["updated", "display_name", "email", "lti_v13_user_id"], + update_columns=[ + "updated", + "display_name", + "email", + ( + "lti_v13_user_id", + func.coalesce( + text('"excluded"."lti_v13_user_id"'), + text('"lms_user"."lti_v13_user_id"'), + ), + ), + ], ).one() bulk_upsert( self._db,