From 25386dda8471357cc9e469275aceb79f30e5e130 Mon Sep 17 00:00:00 2001 From: Lunarmagpie Date: Sun, 16 Oct 2022 03:58:37 -0400 Subject: [PATCH] Add `sigparse.utils.unwrap` (#8) * add the utils function * add tests * fix tests * add tests * add tests * support 3.7 * args 3.7 * mypy --- sigparse/__init__.py | 2 ++ sigparse/utils.py | 54 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_utils.py | 13 +++++++++++ 3 files changed, 69 insertions(+) create mode 100644 sigparse/utils.py create mode 100644 tests/test_utils.py diff --git a/sigparse/__init__.py b/sigparse/__init__.py index 084b446..877ecbd 100644 --- a/sigparse/__init__.py +++ b/sigparse/__init__.py @@ -27,6 +27,7 @@ from sigparse._sigparse import sigparse, Parameter from sigparse._classparse import classparse, ClassVar from sigparse._pep604 import global_PEP604 +from sigparse import utils __all__: typing.Sequence[str] = ( "classparse", @@ -34,4 +35,5 @@ "Parameter", "ClassVar", "global_PEP604", + "utils", ) diff --git a/sigparse/utils.py b/sigparse/utils.py new file mode 100644 index 0000000..179aa63 --- /dev/null +++ b/sigparse/utils.py @@ -0,0 +1,54 @@ +import types +import typing + +__all__: typing.Sequence[str] = ("unwrap",) + +try: + UnionType = types.UnionType # type: ignore +except AttributeError: + UnionType = ... + +NoneType = type(None) + + +def _get_origin(typehint: typing.Any) -> typing.Any: + if hasattr(typehint, "__origin__"): + return typehint.__origin__ + + return None + + +def _get_args(typehint: typing.Any) -> typing.Any: + if hasattr(typehint, "__args__"): + return typehint.__args__ + + return None + + +def unwrap(typehint: typing.Any) -> typing.Any: + """ + Remove the `None` values from a `Union[T, U]` or `Optional[T]`. + + If one of `T` or `U` is `None`, return the non-none value. + If `T` and `U` are `None`, return `None`. + If neither `T` or `U` are `None`, return `Union[T, U]`. + If `typehint` is not a `Union` or `Option` return `typehint`. + """ + + if typehint is NoneType: + return None + + if _get_origin(typehint) not in {typing.Union, UnionType}: + return typehint + + args = _get_args(typehint) + + if not args: + return None + + hints = list(filter(lambda x: x not in {NoneType, None}, args)) + + if len(hints) == 1: + return hints[0] + + return typing.Union[hints[0], hints[1]] diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..7e4774f --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,13 @@ +import typing +import sigparse + +def test_unwrap(): + assert sigparse.utils.unwrap(typing.Union[int, None]) == int + assert sigparse.utils.unwrap(typing.Union[None, None]) == None + assert sigparse.utils.unwrap(typing.Union[int, str]) == typing.Union[int, str] + + assert sigparse.utils.unwrap(typing.Optional[int]) == int + assert sigparse.utils.unwrap(typing.Optional[None]) == None + + assert sigparse.utils.unwrap(int) == int + assert sigparse.utils.unwrap(None) == None