Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the incompatibility between pydantic-duality and __init__() #9

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 69 additions & 19 deletions pydantic_duality/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def _resolve_annotation(annotation, attr: str) -> Any:
tuple(_resolve_annotation(a, attr) for a in get_args(annotation)),
)
elif isinstance(annotation, UnionType):
return Union.__getitem__(tuple(_resolve_annotation(a, attr) for a in get_args(annotation)))
return Union.__getitem__(
tuple(_resolve_annotation(a, attr) for a in get_args(annotation))
)
elif get_origin(annotation) is Annotated:
return Annotated.__class_getitem__(
tuple(_resolve_annotation(a, attr) for a in get_args(annotation)),
Expand All @@ -58,7 +60,9 @@ def _alter_attrs(attrs: dict[str, object], name: str, attr: str):
if attr == PATCH_REQUEST_ATTR:
if get_origin(annotations[key]) is Annotated:
args = get_args(annotations[key])
annotations[key] = Annotated.__class_getitem__(tuple([args[0] | None, *args[1:]]))
annotations[key] = Annotated.__class_getitem__(
tuple([args[0] | None, *args[1:]])
)
elif isinstance(annotations[key], str):
annotations[key] += " | None"
else:
Expand All @@ -67,12 +71,18 @@ def _alter_attrs(attrs: dict[str, object], name: str, attr: str):
return attrs


def _lazily_initalize_models(request_cls: type, own_attr_name: str, constructor: Callable[[], Any]):
def _lazily_initalize_models(
request_cls: type, own_attr_name: str, constructor: Callable[[], Any]
):
def constructor_wrapper(*a, **kw) -> object:
obj = constructor()
obj.__request__ = request_cls
obj.__response__ = cached_classproperty(lambda cls: request_cls.__response__, RESPONSE_ATTR)
obj.__patch_request__ = cached_classproperty(lambda cls: request_cls.__patch_request__, PATCH_REQUEST_ATTR)
obj.__response__ = cached_classproperty(
lambda cls: request_cls.__response__, RESPONSE_ATTR
)
obj.__patch_request__ = cached_classproperty(
lambda cls: request_cls.__patch_request__, PATCH_REQUEST_ATTR
)
return obj

return cached_classproperty(constructor_wrapper, own_attr_name)
Expand All @@ -96,7 +106,9 @@ def __new__(
**kwargs,
) -> Self:
new_class = type.__new__(cls, name, bases, attrs)
if not bases or not any(isinstance(b, (ModelMetaclass, DualBaseModelMeta)) for b in bases):
if not bases or not any(
isinstance(b, (ModelMetaclass, DualBaseModelMeta)) for b in bases
):
raise TypeError(
f"ModelDuplicatorMeta's instances must be created with a DualBaseModel base class or a BaseModel base class."
)
Expand All @@ -108,19 +120,31 @@ def __new__(
)
elif not inspect.isclass(kwargs["__config__"]):
raise TypeError("The __config__ argument must be a class.")
elif request_suffix is None or response_suffix is None or patch_request_suffix is None:
elif (
request_suffix is None
or response_suffix is None
or patch_request_suffix is None
):
raise TypeError(
"The first instance of DualBaseModel must pass suffixes for the request, response, and patch request models."
)
new_class._generate_base_alternative_classes(request_suffix, response_suffix, kwargs)
new_class._generate_base_alternative_classes(
request_suffix, response_suffix, kwargs
)
else:
request_suffix, response_suffix, patch_request_suffix = (
request_suffix or new_class.request_suffix,
response_suffix or new_class.response_suffix,
patch_request_suffix or new_class.patch_request_suffix,
)
new_class._generate_alternative_classes(
name, bases, attrs, request_suffix, response_suffix, patch_request_suffix, kwargs
name,
bases,
attrs,
request_suffix,
response_suffix,
patch_request_suffix,
kwargs,
)

new_class.__request__.request_suffix = request_suffix # type: ignore
Expand All @@ -129,30 +153,45 @@ def __new__(

return new_class

def _generate_base_alternative_classes(self, request_suffix, response_suffix, kwargs):
def _generate_base_alternative_classes(
self, request_suffix, response_suffix, kwargs
):
class Config(kwargs["__config__"]): # type: ignore
extra = Extra.forbid

BaseRequest = ModelMetaclass(f"Base{request_suffix}", (BaseModel,), {"Config": Config})
BaseRequest = ModelMetaclass(
f"Base{request_suffix}", (BaseModel,), {"Config": Config}
)

class Config(kwargs["__config__"]):
extra = Extra.ignore

BaseResponse = ModelMetaclass(f"Base{response_suffix}", (BaseModel,), {"Config": Config})
BaseResponse = ModelMetaclass(
f"Base{response_suffix}", (BaseModel,), {"Config": Config}
)

type.__setattr__(self, "__request__", BaseRequest)
BaseRequest.__request__ = BaseRequest # type: ignore
BaseRequest.__response__ = BaseResponse # type: ignore
BaseRequest.__patch_request__ = BaseRequest # type: ignore

def _generate_alternative_classes(
self, name, bases, attrs, request_suffix, response_suffix, patch_request_suffix, kwargs
self,
name,
bases,
attrs,
request_suffix,
response_suffix,
patch_request_suffix,
kwargs,
):
anonymized_attrs = attrs.copy()
anonymized_attrs.pop("__classcell__", None)
request_bases = tuple(_resolve_annotation(b, REQUEST_ATTR) for b in bases)
request_class = ModelMetaclass(
name + request_suffix,
request_bases,
_alter_attrs(attrs, name + request_suffix, REQUEST_ATTR),
_alter_attrs(anonymized_attrs, name + request_suffix, REQUEST_ATTR),
**kwargs,
)
request_class.__response__ = _lazily_initalize_models(
Expand All @@ -161,7 +200,7 @@ def _generate_alternative_classes(
lambda: ModelMetaclass(
name + response_suffix,
tuple(_resolve_annotation(b, RESPONSE_ATTR) for b in bases),
_alter_attrs(attrs, name + response_suffix, RESPONSE_ATTR),
_alter_attrs(anonymized_attrs, name + response_suffix, RESPONSE_ATTR),
**kwargs,
),
)
Expand All @@ -171,7 +210,9 @@ def _generate_alternative_classes(
lambda: ModelMetaclass(
name + patch_request_suffix,
tuple(_resolve_annotation(b, PATCH_REQUEST_ATTR) for b in bases),
_alter_attrs(attrs, name + patch_request_suffix, PATCH_REQUEST_ATTR),
_alter_attrs(
anonymized_attrs, name + patch_request_suffix, PATCH_REQUEST_ATTR
),
**kwargs,
),
)
Expand All @@ -182,7 +223,12 @@ def _generate_alternative_classes(

def __getattribute__(self, attr: str):
# Note here that RESPONSE_ATTR and PATCH_REQUEST_ATTR goes into REQUEST_ATTR's __getattribute__ method
if attr in {REQUEST_ATTR, "__new__", "_generate_base_alternative_classes", "_generate_alternative_classes"}:
if attr in {
REQUEST_ATTR,
"__new__",
"_generate_base_alternative_classes",
"_generate_alternative_classes",
}:
return type.__getattribute__(self, attr)
return getattr(type.__getattribute__(self, REQUEST_ATTR), attr)

Expand All @@ -202,10 +248,14 @@ def __hash__(self) -> int:
return hash(self.__request__)

def __instancecheck__(cls, instance) -> bool:
return type.__instancecheck__(cls, instance) or isinstance(instance, cls.__request__)
return type.__instancecheck__(cls, instance) or isinstance(
instance, cls.__request__
)

def __subclasscheck__(cls, subclass: type):
return type.__subclasscheck__(cls, subclass) or issubclass(subclass, cls.__request__)
return type.__subclasscheck__(cls, subclass) or issubclass(
subclass, cls.__request__
)


def generate_dual_base_model(
Expand Down
77 changes: 66 additions & 11 deletions tests/test_duality.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,21 @@ class SubSchema(DualBaseModel):

def test_ignore_forbid_attrs(schemas):
assert (
schemas["A"].__request__.__response__.__response__.__request__.__response__.__request__.Config.extra
schemas[
"A"
].__request__.__response__.__response__.__request__.__response__.__request__.Config.extra
== Extra.forbid
)
assert (
schemas["A"].__request__.__response__.__response__.__request__.__response__.__patch_request__.Config.extra
schemas[
"A"
].__request__.__response__.__response__.__request__.__response__.__patch_request__.Config.extra
== Extra.forbid
)
assert (
schemas["A"].__request__.__response__.__response__.__request__.__response__.__response__.Config.extra
schemas[
"A"
].__request__.__response__.__response__.__request__.__response__.__response__.Config.extra
== Extra.ignore
)

Expand Down Expand Up @@ -218,16 +224,33 @@ class ChildSchema2(DualBaseModel):
obj: str

class Schema(DualBaseModel):
child: Annotated[ChildSchema1 | ChildSchema2, Field(discriminator="object_type")]
child: Annotated[
ChildSchema1 | ChildSchema2, Field(discriminator="object_type")
]

for object_type in (1, 2):
child_schema = Schema.parse_obj({"child": {"object_type": object_type, "obj": object_type}})
child_req_schema = Schema.__request__.parse_obj({"child": {"object_type": object_type, "obj": object_type}})
child_resp_schema = Schema.__response__.parse_obj({"child": {"object_type": object_type, "obj": object_type}})
child_schema = Schema.parse_obj(
{"child": {"object_type": object_type, "obj": object_type}}
)
child_req_schema = Schema.__request__.parse_obj(
{"child": {"object_type": object_type, "obj": object_type}}
)
child_resp_schema = Schema.__response__.parse_obj(
{"child": {"object_type": object_type, "obj": object_type}}
)

assert type(child_schema.child) is locals()[f"ChildSchema{object_type}"].__request__
assert type(child_req_schema.child) is locals()[f"ChildSchema{object_type}"].__request__
assert type(child_resp_schema.child) is locals()[f"ChildSchema{object_type}"].__response__
assert (
type(child_schema.child)
is locals()[f"ChildSchema{object_type}"].__request__
)
assert (
type(child_req_schema.child)
is locals()[f"ChildSchema{object_type}"].__request__
)
assert (
type(child_resp_schema.child)
is locals()[f"ChildSchema{object_type}"].__response__
)
with pytest.raises(ValidationError):
Schema.parse_obj(
{
Expand Down Expand Up @@ -269,7 +292,9 @@ class Schema(DualBaseModel):
)


@pytest.mark.parametrize("field_type", [Annotated[int, "Hello"], Annotated[int, "Hello", "Darkness"]])
@pytest.mark.parametrize(
"field_type", [Annotated[int, "Hello"], Annotated[int, "Hello", "Darkness"]]
)
def test_annotated_model_creation_with_regular_metadata(field_type):
class Schema(DualBaseModel):
field: field_type
Expand Down Expand Up @@ -344,3 +369,33 @@ class Schema(DualBaseModel, extra=Extra.ignore):
assert Schema.__patch_request__.Config.extra == Extra.forbid

Schema(field=1, extra=2)


def test_model_can_be_created_with_super_init_in_init():
class MyModel(DualBaseModel):
one: str

def __init__(self, **kwargs):
super().__init__(**kwargs)


def test_model_can_be_created_with_init_subclass():
class MyModel(DualBaseModel):
one: str

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)

class MyModelChild(MyModel):
pass


def test_model_can_be_created_with_classmethod():
class MyModel(DualBaseModel):
one: str

@classmethod
def get_stuff(cls):
return super().parse_obj

MyModel.get_stuff()
Loading