Skip to content

Commit

Permalink
Merge pull request #11 from dka58/main
Browse files Browse the repository at this point in the history
Fix dynamic joining in repository
  • Loading branch information
iam-abbas authored Oct 5, 2023
2 parents bf9de0b + b3a4a0c commit e6aee9a
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
5 changes: 4 additions & 1 deletion app/repositories/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ async def get_by_author_id(
query = await self._query(join_)
query = await self._get_by(query, "task_author_id", author_id)

if join_ is not None:
return await self.all_unique(query)

return await self._all(query)

async def _join_author(self, query: Select) -> Select:
def _join_author(self, query: Select) -> Select:
"""
Join the author relationship.
Expand Down
10 changes: 9 additions & 1 deletion app/repositories/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ async def get_by_username(
"""
query = await self._query(join_)
query = query.filter(User.username == username)

if join_ is not None:
return await self.all_unique(query)

return await self._one_or_none(query)

async def get_by_email(
Expand All @@ -36,9 +40,13 @@ async def get_by_email(
"""
query = await self._query(join_)
query = query.filter(User.email == email)

if join_ is not None:
return await self.all_unique(query)

return await self._one_or_none(query)

async def _join_tasks(self, query: Select) -> Select:
def _join_tasks(self, query: Select) -> Select:
"""
Join tasks.
Expand Down
25 changes: 17 additions & 8 deletions core/repository/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,12 @@ async def get_all(
:param join_: The joins to make.
:return: A list of model instances.
"""
query = await self._query(join_)
query = self._query(join_)
query = query.offset(skip).limit(limit)

if join_ is not None:
return await self.all_unique(query)

return await self._all(query)

async def get_by(
Expand All @@ -61,9 +64,11 @@ async def get_by(
:param join_: The joins to make.
:return: The model instance.
"""
query = await self._query(join_)
query = self._query(join_)
query = await self._get_by(query, field, value)

if join_ is not None:
return await self.all_unique(query)
if unique:
return await self._one(query)

Expand All @@ -78,7 +83,7 @@ async def delete(self, model: ModelType) -> None:
"""
self.session.delete(model)

async def _query(
def _query(
self,
join_: set[str] | None = None,
order_: dict | None = None,
Expand All @@ -91,8 +96,8 @@ async def _query(
:return: A callable that can be used to query the model.
"""
query = select(self.model_class)
query = await self._maybe_join(query, join_)
query = await self._maybe_ordered(query, order_)
query = self._maybe_join(query, join_)
query = self._maybe_ordered(query, order_)

return query

Expand All @@ -106,6 +111,10 @@ async def _all(self, query: Select) -> list[ModelType]:
query = await self.session.scalars(query)
return query.all()

async def _all_unique(self, query: Select) -> list[ModelType]:
result = await self.session.execute(query)
return result.unique().scalars().all()

async def _first(self, query: Select) -> ModelType | None:
"""
Returns the first result from the query.
Expand Down Expand Up @@ -184,7 +193,7 @@ async def _get_by(self, query: Select, field: str, value: Any) -> Select:
"""
return query.where(getattr(self.model_class, field) == value)

async def _maybe_join(self, query: Select, join_: set[str] | None = None) -> Select:
def _maybe_join(self, query: Select, join_: set[str] | None = None) -> Select:
"""
Returns the query with the given joins.
Expand All @@ -200,7 +209,7 @@ async def _maybe_join(self, query: Select, join_: set[str] | None = None) -> Sel

return reduce(self._add_join_to_query, join_, query)

async def _maybe_ordered(self, query: Select, order_: dict | None = None) -> Select:
def _maybe_ordered(self, query: Select, order_: dict | None = None) -> Select:
"""
Returns the query ordered by the given column.
Expand All @@ -218,7 +227,7 @@ async def _maybe_ordered(self, query: Select, order_: dict | None = None) -> Sel

return query

async def _add_join_to_query(self, query: Select, join_: set[str]) -> Select:
def _add_join_to_query(self, query: Select, join_: set[str]) -> Select:
"""
Returns the query with the given join.
Expand Down

0 comments on commit e6aee9a

Please sign in to comment.