diff --git a/docs/persistence.rst b/docs/persistence.rst index da7f488c..10f36871 100644 --- a/docs/persistence.rst +++ b/docs/persistence.rst @@ -147,9 +147,10 @@ following table: - ``Date(format="yyyy-MM-dd", required=True)`` To type a field as optional, the standard ``Optional`` modifier from the Python -``typing`` package can be used. The ``List`` modifier can be added to a field -to convert it to an array, similar to using the ``multi=True`` argument on the -field object. +``typing`` package can be used. When using Python 3.10 or newer, "pipe" syntax +can also be used, by adding ``| None`` to a type. The ``List`` modifier can be +added to a field to convert it to an array, similar to using the ``multi=True`` +argument on the field object. .. code:: python @@ -157,6 +158,7 @@ field object. class MyDoc(Document): pub_date: Optional[datetime] # same as pub_date = Date() + middle_name: str | None # same as middle_name = Text() authors: List[str] # same as authors = Text(multi=True, required=True) comments: Optional[List[str]] # same as comments = Text(multi=True) diff --git a/elasticsearch_dsl/document_base.py b/elasticsearch_dsl/document_base.py index 67eae0ab..a7026778 100644 --- a/elasticsearch_dsl/document_base.py +++ b/elasticsearch_dsl/document_base.py @@ -29,9 +29,15 @@ Tuple, TypeVar, Union, + get_args, overload, ) +try: + from types import UnionType # type: ignore[attr-defined] +except ImportError: + UnionType = None + from typing_extensions import dataclass_transform from .exceptions import ValidationException @@ -203,6 +209,14 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]): if skip or type_ == ClassVar: # skip ClassVar attributes continue + if type(type_) is UnionType: + # a union given with the pipe syntax + args = get_args(type_) + if len(args) == 2 and args[1] is type(None): + required = False + type_ = type_.__args__[0] + else: + raise TypeError("Unsupported union") field = None field_args: List[Any] = [] field_kwargs: Dict[str, Any] = {} diff --git a/tests/_async/test_document.py b/tests/_async/test_document.py index 00a570b2..933d1809 100644 --- a/tests/_async/test_document.py +++ b/tests/_async/test_document.py @@ -24,6 +24,7 @@ import codecs import ipaddress import pickle +import sys from datetime import datetime from hashlib import md5 from typing import Any, ClassVar, Dict, List, Optional @@ -791,6 +792,36 @@ class TypedDoc(AsyncDocument): } +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires Python 3.10") +def test_doc_with_pipe_type_hints() -> None: + with pytest.raises(TypeError): + + class BadlyTypedDoc(AsyncDocument): + s: str + f: str | int | None # type: ignore[syntax] + + class TypedDoc(AsyncDocument): + s: str + f1: str | None # type: ignore[syntax] + f2: M[int | None] # type: ignore[syntax] + f3: M[datetime | None] # type: ignore[syntax] + + props = TypedDoc._doc_type.mapping.to_dict()["properties"] + assert props == { + "s": {"type": "text"}, + "f1": {"type": "text"}, + "f2": {"type": "integer"}, + "f3": {"type": "date"}, + } + + doc = TypedDoc() + with raises(ValidationException) as exc_info: + doc.full_clean() + assert set(exc_info.value.args[0].keys()) == {"s"} + doc.s = "s" + doc.full_clean() + + def test_instrumented_field() -> None: class Child(InnerDoc): st: M[str] diff --git a/tests/_sync/test_document.py b/tests/_sync/test_document.py index dddfc688..919afdf7 100644 --- a/tests/_sync/test_document.py +++ b/tests/_sync/test_document.py @@ -24,6 +24,7 @@ import codecs import ipaddress import pickle +import sys from datetime import datetime from hashlib import md5 from typing import Any, ClassVar, Dict, List, Optional @@ -791,6 +792,36 @@ class TypedDoc(Document): } +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires Python 3.10") +def test_doc_with_pipe_type_hints() -> None: + with pytest.raises(TypeError): + + class BadlyTypedDoc(Document): + s: str + f: str | int | None # type: ignore[syntax] + + class TypedDoc(Document): + s: str + f1: str | None # type: ignore[syntax] + f2: M[int | None] # type: ignore[syntax] + f3: M[datetime | None] # type: ignore[syntax] + + props = TypedDoc._doc_type.mapping.to_dict()["properties"] + assert props == { + "s": {"type": "text"}, + "f1": {"type": "text"}, + "f2": {"type": "integer"}, + "f3": {"type": "date"}, + } + + doc = TypedDoc() + with raises(ValidationException) as exc_info: + doc.full_clean() + assert set(exc_info.value.args[0].keys()) == {"s"} + doc.s = "s" + doc.full_clean() + + def test_instrumented_field() -> None: class Child(InnerDoc): st: M[str]