From b3a4a0c3e2c2cc6d9269c65035c377198d7de5fb Mon Sep 17 00:00:00 2001 From: dka58 <116434336+dka58@users.noreply.github.com> Date: Wed, 4 Oct 2023 11:21:41 +0000 Subject: [PATCH] Fix dynamic joining in repository --- app/repositories/task.py | 5 ++++- app/repositories/user.py | 10 +++++++++- core/repository/base.py | 25 +++++++++++++++++-------- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/app/repositories/task.py b/app/repositories/task.py index ad2c054..c78aad6 100644 --- a/app/repositories/task.py +++ b/app/repositories/task.py @@ -23,9 +23,12 @@ async def get_by_author_id( query = await self._query(join_) query = await self._get_by(query, "task_author_id", author_id) + if join_ is not None: + return await self.all_unique(query) + return await self._all(query) - async def _join_author(self, query: Select) -> Select: + def _join_author(self, query: Select) -> Select: """ Join the author relationship. diff --git a/app/repositories/user.py b/app/repositories/user.py index 26dd904..92ba03f 100644 --- a/app/repositories/user.py +++ b/app/repositories/user.py @@ -22,6 +22,10 @@ async def get_by_username( """ query = await self._query(join_) query = query.filter(User.username == username) + + if join_ is not None: + return await self.all_unique(query) + return await self._one_or_none(query) async def get_by_email( @@ -36,9 +40,13 @@ async def get_by_email( """ query = await self._query(join_) query = query.filter(User.email == email) + + if join_ is not None: + return await self.all_unique(query) + return await self._one_or_none(query) - async def _join_tasks(self, query: Select) -> Select: + def _join_tasks(self, query: Select) -> Select: """ Join tasks. diff --git a/core/repository/base.py b/core/repository/base.py index e6f7866..0d6db61 100644 --- a/core/repository/base.py +++ b/core/repository/base.py @@ -41,9 +41,12 @@ async def get_all( :param join_: The joins to make. :return: A list of model instances. """ - query = await self._query(join_) + query = self._query(join_) query = query.offset(skip).limit(limit) + if join_ is not None: + return await self.all_unique(query) + return await self._all(query) async def get_by( @@ -61,9 +64,11 @@ async def get_by( :param join_: The joins to make. :return: The model instance. """ - query = await self._query(join_) + query = self._query(join_) query = await self._get_by(query, field, value) + if join_ is not None: + return await self.all_unique(query) if unique: return await self._one(query) @@ -78,7 +83,7 @@ async def delete(self, model: ModelType) -> None: """ self.session.delete(model) - async def _query( + def _query( self, join_: set[str] | None = None, order_: dict | None = None, @@ -91,8 +96,8 @@ async def _query( :return: A callable that can be used to query the model. """ query = select(self.model_class) - query = await self._maybe_join(query, join_) - query = await self._maybe_ordered(query, order_) + query = self._maybe_join(query, join_) + query = self._maybe_ordered(query, order_) return query @@ -106,6 +111,10 @@ async def _all(self, query: Select) -> list[ModelType]: query = await self.session.scalars(query) return query.all() + async def _all_unique(self, query: Select) -> list[ModelType]: + result = await self.session.execute(query) + return result.unique().scalars().all() + async def _first(self, query: Select) -> ModelType | None: """ Returns the first result from the query. @@ -184,7 +193,7 @@ async def _get_by(self, query: Select, field: str, value: Any) -> Select: """ return query.where(getattr(self.model_class, field) == value) - async def _maybe_join(self, query: Select, join_: set[str] | None = None) -> Select: + def _maybe_join(self, query: Select, join_: set[str] | None = None) -> Select: """ Returns the query with the given joins. @@ -200,7 +209,7 @@ async def _maybe_join(self, query: Select, join_: set[str] | None = None) -> Sel return reduce(self._add_join_to_query, join_, query) - async def _maybe_ordered(self, query: Select, order_: dict | None = None) -> Select: + def _maybe_ordered(self, query: Select, order_: dict | None = None) -> Select: """ Returns the query ordered by the given column. @@ -218,7 +227,7 @@ async def _maybe_ordered(self, query: Select, order_: dict | None = None) -> Sel return query - async def _add_join_to_query(self, query: Select, join_: set[str]) -> Select: + def _add_join_to_query(self, query: Select, join_: set[str]) -> Select: """ Returns the query with the given join.