From f823fdff82835765727d71b3d957e94304aa3cfd Mon Sep 17 00:00:00 2001 From: themylogin Date: Sun, 15 Dec 2024 23:29:47 +0100 Subject: [PATCH] Fix nested models still having default values in models that use `ForUpdateMetaclass` (#15202) --- src/middlewared/middlewared/api/base/model.py | 11 ++++++- .../pytest/unit/api/base/test_model.py | 30 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/middlewared/middlewared/api/base/model.py b/src/middlewared/middlewared/api/base/model.py index 327d238bc3d6c..03442cdb1190c 100644 --- a/src/middlewared/middlewared/api/base/model.py +++ b/src/middlewared/middlewared/api/base/model.py @@ -133,7 +133,16 @@ def __new__(mcls, name, bases, namespaces, **kwargs): class _ForUpdateSerializerMixin(PydanticBaseModel): @model_serializer(mode="wrap") def serialize_model(self, serializer): - return {k: v for k, v in serializer(self).items() if v != undefined} + aliases = {field.alias or name: name for name, field in self.model_fields.items()} + + return { + k: v + for k, v in serializer(self).items() + if ( + (getattr(self, aliases[k]) != undefined) if k in aliases and hasattr(self, aliases[k]) + else v != undefined + ) + } def _field_for_update(field): diff --git a/src/middlewared/middlewared/pytest/unit/api/base/test_model.py b/src/middlewared/middlewared/pytest/unit/api/base/test_model.py index 04d4d188b61ef..1f1de171d1d3c 100644 --- a/src/middlewared/middlewared/pytest/unit/api/base/test_model.py +++ b/src/middlewared/middlewared/pytest/unit/api/base/test_model.py @@ -1,9 +1,11 @@ +from pydantic import Field import pytest from middlewared.api.base import (BaseModel, Excluded, excluded_field, ForUpdateMetaclass, single_argument_args, single_argument_result) from middlewared.api.base.handler.accept import accept_params from middlewared.api.base.handler.result import serialize_result +from middlewared.api.v25_04_0.pool_snapshottask import PoolSnapshotTaskCron from middlewared.service_exception import ValidationErrors @@ -64,3 +66,31 @@ class MethodResult(BaseModel): count: int assert serialize_result(MethodResult, {"name": "ivan", "count": 1}, True) == {"name": "ivan", "count": 1} + + +def test_update_with_cron(): + class CreateObjectWithCron(BaseModel): + schedule: PoolSnapshotTaskCron = Field(default_factory=PoolSnapshotTaskCron) + + class UpdateObjectWithCron(CreateObjectWithCron, metaclass=ForUpdateMetaclass): + pass + + class UpdateWithCronArgs(BaseModel): + id: int + data: UpdateObjectWithCron + + assert accept_params(UpdateWithCronArgs, [1, {}]) == [1, {}] + + +def test_update_with_alias(): + class CreateObjectWithAlias(BaseModel): + pass_: str = Field(alias="pass") + + class UpdateObjectWithAlias(CreateObjectWithAlias, metaclass=ForUpdateMetaclass): + pass + + class UpdateWithAliasArgs(BaseModel): + id: int + data: UpdateObjectWithAlias + + assert accept_params(UpdateWithAliasArgs, [1, {"pass": "1"}]) == [1, {"pass": "1"}]