Skip to content

Commit

Permalink
ref: type of in_obj in create and update methods
Browse files Browse the repository at this point in the history
  • Loading branch information
e-kondr01 committed Nov 9, 2023
1 parent 8072739 commit 193887a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 40 deletions.
71 changes: 32 additions & 39 deletions fastapi_sqlalchemy_toolkit/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
async def create(
self,
session: AsyncSession,
in_obj: CreateSchemaType | ModelDict,
in_obj: CreateSchemaType | None = None,
refresh_attribute_names: Iterable[str] | None = None,
commit: bool = True,
**attrs: Any,
Expand All @@ -108,23 +108,24 @@ async def create(
:param session: сессия SQLAlchemy
:param in_obj: значения полей создаваемого экземпляра модели в словаре
:param in_obj: модель Pydantic для создания объекта
:param refresh_attribute_names: названия полей, которые нужно обновить
(может использоваться для подгрузки связанных полей)
:param commit: нужно ли вызывать `session.commit()`
:param attrs: дополнительные значения полей создаваемого экземпляра
(чтобы какие-то поля можно было установить напрямую из кода,
(какие-то поля можно установить напрямую,
например, пользователя запроса)
:returns: созданный экземпляр модели
"""
if isinstance(in_obj, dict):
create_data = in_obj
else:
if in_obj:
create_data = in_obj.model_dump()
else:
create_data = {}

create_data.update(attrs)
await self.run_db_validation(session, in_obj=create_data)
db_obj = self.model(**create_data)
Expand Down Expand Up @@ -153,8 +154,8 @@ async def get(
:param select_: объект Select для SQL запроса. Если передан, то метод вернёт
экземпляр Row, а не ModelType.
Примечание: фильтрация и сортировка по связанным моделям скорее всего не будут работать
вместе с этим параметром.
Примечание: фильтрация и сортировка по связанным моделям скорее всего
не будут работать вместе с этим параметром.
:param attrs: параметры для выборки объекта. Название параметра используется как
название поля модели. Значение параметра может быть примитивным типом для
Expand Down Expand Up @@ -182,9 +183,7 @@ async def get(
return result.scalars().first()
return result.first()

async def get_or_404(
self, session: AsyncSession, **attrs: FieldFilter | Any
) -> ModelType | Row:
async def get_or_404(self, session: AsyncSession, **attrs: Any) -> ModelType | Row:
"""
Получение одного экземпляра модели или возвращение HTTP ответа 404.
Expand Down Expand Up @@ -491,7 +490,7 @@ async def update(
self,
session: AsyncSession,
db_obj: ModelType,
in_obj: UpdateSchemaType | ModelDict | None = None,
in_obj: UpdateSchemaType | None = None,
refresh_attribute_names: Iterable[str] | None = None,
commit: bool = True,
exclude_unset: bool = True,
Expand All @@ -505,10 +504,10 @@ async def update(
:param db_obj: обновляемый объект
:param in_obj: значения обновляемых полей в словаре
:param in_obj: модель Pydantic для обновления значений полей объекта
:param attrs: дополнительные значения обновляемых полей
(чтобы какие-то поля можно было установить напрямую из кода,
(какие-то поля можно установить напрямую,
например, пользователя запроса)
:param refresh_attribute_names: названия полей, которые нужно обновить
Expand All @@ -524,15 +523,13 @@ async def update(
:returns: обновлённый экземпляр модели
"""
if in_obj is None:
in_obj = {}
if isinstance(in_obj, dict):
update_data = in_obj
else:
if in_obj:
update_data = in_obj.model_dump(exclude_unset=exclude_unset)
else:
update_data = {}

update_data.update(attrs)
await self.run_db_validation(session=session, db_obj=db_obj, in_obj=update_data)
await self.run_db_validation(session, db_obj=db_obj, in_obj=update_data)
for field in update_data:
setattr(db_obj, field, update_data[field])
session.add(db_obj)
Expand Down Expand Up @@ -586,7 +583,7 @@ async def bulk_create(
if not isinstance(in_obj, dict):
in_obj = in_obj.model_dump()
in_obj.update(**attrs)
await self.run_db_validation(in_obj=in_obj, session=session)
await self.run_db_validation(session, in_obj=in_obj)
db_obj = self.model(**in_obj)
objs.append(db_obj)
session.add_all(objs)
Expand Down Expand Up @@ -635,9 +632,7 @@ async def bulk_update(
else:
update_data = in_obj.model_dump(exclude_unset=exclude_unset)
update_data.update(attrs)
await self.run_db_validation(
session=session, db_obj=obj, in_obj=update_data
)
await self.run_db_validation(session, db_obj=obj, in_obj=update_data)
for field in update_data:
setattr(obj, field, update_data[field])
session.add(obj)
Expand All @@ -655,8 +650,8 @@ async def bulk_update(
async def run_db_validation(
self,
session: AsyncSession,
in_obj: ModelDict,
db_obj: ModelType | None = None,
in_obj: ModelDict | None = None,
) -> ModelDict:
"""
Выполнить валидацию на соответствие ограничениям БД.
Expand Down Expand Up @@ -698,9 +693,9 @@ def get_joins(
and order_by.field.parent.class_ != self.model
):
models_to_join.add(order_by.field.parent.class_)
for filter in kwargs.values():
if isinstance(filter, FieldFilter) and filter.model:
models_to_join.add(filter.model)
for field_filter in kwargs.values():
if isinstance(field_filter, FieldFilter) and field_filter.model:
models_to_join.add(field_filter.model)
for model in models_to_join:
if model in self.models_to_relationship_attrs:
joined_query = joined_query.outerjoin(
Expand Down Expand Up @@ -736,16 +731,16 @@ def handle_optional_filters(
`null_query_values` и установленным `nullable_q=True`.
"""
result_filters = {}
for name, filter in filters.items():
for name, field_filter in filters.items():
if (
isinstance(filter, FieldFilter)
and filter.nullable_q
and filter.value in null_query_values
isinstance(field_filter, FieldFilter)
and field_filter.nullable_q
and field_filter.value in null_query_values
):
filter.value = None
result_filters[name] = filter
if filter != None:
result_filters[name] = filter
field_filter.value = None
result_filters[name] = field_filter
if field_filter != None:
result_filters[name] = field_filter
return result_filters

def add_reverse_relation_filter_expression(
Expand Down Expand Up @@ -797,8 +792,8 @@ def get_list_query(
self,
base_query: Select,
order_by: OrderingField | None,
options: List[Any],
filter_by: dict | None = None,
options: Any | None = None,
where: Any | None = None,
**attrs: FieldFilter | Any,
):
Expand All @@ -819,8 +814,6 @@ def get_list_query(
query = query.options(option)
if where is not None:
query = query.where(where)
elif hasattr(self.model, "created_at"):
query = query.order_by(self.model.created_at.desc())
return query

async def validate_fk_exists(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "fastapi_sqlalchemy_toolkit"
version = "0.5.1"
version = "0.5.2"
authors = [
{ name="Egor Kondrashov", email="[email protected]" },
]
Expand Down

0 comments on commit 193887a

Please sign in to comment.