diff --git a/cdd/__init__.py b/cdd/__init__.py index 3e0395ae..6f0e3b7a 100644 --- a/cdd/__init__.py +++ b/cdd/__init__.py @@ -9,7 +9,7 @@ from logging import getLogger as get_logger __author__ = "Samuel Marks" # type: str -__version__ = "0.0.99rc43" # type: str +__version__ = "0.0.99rc44" # type: str __description__ = ( "Open API to/fro routes, models, and tests. " "Convert between docstrings, classes, methods, argparse, pydantic, and SQLalchemy." diff --git a/cdd/shared/parse/utils/parser_utils.py b/cdd/shared/parse/utils/parser_utils.py index e01d5781..00b87bc3 100644 --- a/cdd/shared/parse/utils/parser_utils.py +++ b/cdd/shared/parse/utils/parser_utils.py @@ -62,10 +62,7 @@ def ir_merge(target, other): target["params"] = other["params"] elif other["params"]: target_params, other_params = map(itemgetter("params"), (target, other)) - - merge_params(other_params, target_params) - - target["params"] = target_params + target["params"] = merge_params(other_params, target_params) if "return_type" not in (target.get("returns") or iter(())): target["returns"] = other["returns"] @@ -110,6 +107,7 @@ def merge_params(other_params, target_params): merge_present_params(other_params[name], target_params[name]) for name in other_params.keys() - target_params.keys(): target_params[name] = other_params[name] + return target_params def merge_present_params(other_param, target_param): diff --git a/cdd/sqlalchemy/utils/shared_utils.py b/cdd/sqlalchemy/utils/shared_utils.py index f8449831..beb062dc 100644 --- a/cdd/sqlalchemy/utils/shared_utils.py +++ b/cdd/sqlalchemy/utils/shared_utils.py @@ -3,7 +3,7 @@ """ import ast -from ast import Call, Expr, Load, Name, Subscript, Tuple, keyword +from ast import Call, Expr, Load, Name, Subscript, Tuple, expr, keyword from operator import attrgetter from typing import Optional, cast @@ -82,13 +82,14 @@ def update_args_infer_typ_sqlalchemy(_param, args, name, nullable, x_typ_sql): return _param.get("default") == cdd.shared.ast_utils.NoneStr, None elif _param["typ"].startswith("Optional["): _param["typ"] = _param["typ"][len("Optional[") : -1] - nullable = True + nullable: bool = True if "Literal[" in _param["typ"]: parsed_typ: Call = cast( Call, cdd.shared.ast_utils.get_value(ast.parse(_param["typ"]).body[0]) ) - if parsed_typ.value.id != "Literal": - return nullable, parsed_typ.value + assert parsed_typ.value.id == "Literal", "Expected `Literal` got: {!r}".format( + parsed_typ.value.id + ) val = cdd.shared.ast_utils.get_value(parsed_typ.slice) ( args.append( @@ -112,7 +113,7 @@ def update_args_infer_typ_sqlalchemy(_param, args, name, nullable, x_typ_sql): else _update_args_infer_typ_sqlalchemy_for_scalar(_param, args, x_typ_sql) ) elif _param["typ"].startswith("List["): - after_generic = _param["typ"][len("List[") :] + after_generic: str = _param["typ"][len("List[") :] if "struct" in after_generic: # "," in after_generic or name: Name = Name(id="JSON", ctx=Load(), lineno=None, col_offset=None) else: @@ -175,42 +176,53 @@ def update_args_infer_typ_sqlalchemy(_param, args, name, nullable, x_typ_sql): ) ) elif _param.get("typ").startswith("Union["): - # Hack to remove the union type. Enum parse seems to be incorrect? - union_typ: Subscript = cast(Subscript, ast.parse(_param["typ"]).body[0]) - assert isinstance( - union_typ.value, Subscript - ), "Expected `Subscript` got `{type_name}`".format( - type_name=type(union_typ.value).__name__ - ) - union_typ_tuple = ( - union_typ.value.slice if PY_GTE_3_9 else union_typ.value.slice.value - ) - assert isinstance( - union_typ_tuple, Tuple - ), "Expected `Tuple` got `{type_name}`".format( - type_name=type(union_typ_tuple).__name__ - ) - assert ( - len(union_typ_tuple.elts) == 2 - ), "Expected length of 2 got `{tuple_len}`".format( - tuple_len=len(union_typ_tuple.elts) - ) - left, right = map(attrgetter("id"), union_typ_tuple.elts) - args.append( - Name( - ( - cdd.sqlalchemy.utils.emit_utils.typ2column_type[right] - if right in cdd.sqlalchemy.utils.emit_utils.typ2column_type - else cdd.sqlalchemy.utils.emit_utils.typ2column_type.get(left, left) - ), - Load(), - lineno=None, - col_offset=None, - ) - ) + args.append(_handle_union_of_length_2(_param["typ"])) else: _update_args_infer_typ_sqlalchemy_for_scalar(_param, args, x_typ_sql) return nullable, None +def _handle_union_of_length_2(typ): + """ + Internal function to turn `str` to `Name` + + :param typ: `str` which evaluates to `ast.Subscript` + :type typ: ```str``` + + :return: Parsed out name + :rtype: ```Name``` + """ + # Hack to remove the union type. Enum parse seems to be incorrect? + union_typ: Subscript = cast(Subscript, ast.parse(typ).body[0]) + assert isinstance( + union_typ.value, Subscript + ), "Expected `Subscript` got `{type_name}`".format( + type_name=type(union_typ.value).__name__ + ) + union_typ_tuple: expr = ( + union_typ.value.slice if PY_GTE_3_9 else union_typ.value.slice.value + ) + assert isinstance( + union_typ_tuple, Tuple + ), "Expected `Tuple` got `{type_name}`".format( + type_name=type(union_typ_tuple).__name__ + ) + assert ( + len(union_typ_tuple.elts) == 2 + ), "Expected length of 2 got `{tuple_len}`".format( + tuple_len=len(union_typ_tuple.elts) + ) + left, right = map(attrgetter("id"), union_typ_tuple.elts) + return Name( + ( + cdd.sqlalchemy.utils.emit_utils.typ2column_type[right] + if right in cdd.sqlalchemy.utils.emit_utils.typ2column_type + else cdd.sqlalchemy.utils.emit_utils.typ2column_type.get(left, left) + ), + Load(), + lineno=None, + col_offset=None, + ) + + __all__ = ["update_args_infer_typ_sqlalchemy"] diff --git a/cdd/tests/test_parse/test_parser_utils.py b/cdd/tests/test_parse/test_parser_utils.py index 90be4824..38d11740 100644 --- a/cdd/tests/test_parse/test_parser_utils.py +++ b/cdd/tests/test_parse/test_parser_utils.py @@ -27,6 +27,23 @@ class TestParserUtils(TestCase): """Test class for parser_utils""" + def test_get_source_raises(self) -> None: + """Tests that `get_source` raises an exception""" + with self.assertRaises(TypeError): + get_source(None) + + def raise_os_error(_): + """raise_OSError""" + raise OSError + + with patch("inspect.getsourcelines", raise_os_error), self.assertRaises( + OSError + ): + get_source(min) + + with patch("inspect.getsourcefile", lambda _: None): + self.assertIsNone(get_source(raise_os_error)) + def test_ir_merge_empty(self) -> None: """Tests for `ir_merge` when both are empty""" target = {"params": OrderedDict(), "returns": None} @@ -250,22 +267,14 @@ def test_infer_raise(self) -> None: with self.assertRaises(NotImplementedError): cdd.shared.parse.utils.parser_utils.infer(None) - def test_get_source_raises(self) -> None: - """Tests that `get_source` raises an exception""" - with self.assertRaises(TypeError): - get_source(None) - - def raise_os_error(_): - """raise_OSError""" - raise OSError - - with patch("inspect.getsourcelines", raise_os_error), self.assertRaises( - OSError - ): - get_source(min) - - with patch("inspect.getsourcefile", lambda _: None): - self.assertIsNone(get_source(raise_os_error)) + def test_merge_params(self) -> None: + """Tests `merge_params` works""" + d0 = {"foo": "bar"} + d1 = {"can": "haz"} + self.assertDictEqual( + cdd.shared.parse.utils.parser_utils.merge_params(deepcopy(d0), d1), + {"foo": "bar", "can": "haz"}, + ) unittest_main() diff --git a/cdd/tests/test_sqlalchemy/test_emit_sqlalchemy_utils.py b/cdd/tests/test_sqlalchemy/test_emit_sqlalchemy_utils.py index b032d75c..6426182d 100644 --- a/cdd/tests/test_sqlalchemy/test_emit_sqlalchemy_utils.py +++ b/cdd/tests/test_sqlalchemy/test_emit_sqlalchemy_utils.py @@ -5,6 +5,7 @@ import ast import json from ast import ( + AST, Assign, Call, ClassDef, @@ -19,8 +20,10 @@ ) from collections import OrderedDict from copy import deepcopy +from functools import partial from os import mkdir, path from tempfile import TemporaryDirectory +from typing import Callable, List, Optional, Tuple, Union from unittest import TestCase from unittest.mock import patch @@ -29,7 +32,10 @@ from cdd.shared.ast_utils import set_value from cdd.shared.source_transformer import to_code from cdd.shared.types import IntermediateRepr -from cdd.sqlalchemy.utils.shared_utils import update_args_infer_typ_sqlalchemy +from cdd.sqlalchemy.utils.shared_utils import ( + _handle_union_of_length_2, + update_args_infer_typ_sqlalchemy, +) from cdd.tests.mocks.ir import ( intermediate_repr_empty, intermediate_repr_no_default_doc, @@ -296,6 +302,27 @@ def test_update_args_infer_typ_sqlalchemy_when_simple_array_in_typ(self) -> None # gold=Name(id="Small", ctx=Load(), lineno=None, col_offset=None), # ) + def test_update_args_infer_typ_sqlalchemy_early_exit(self) -> None: + """Tests that `update_args_infer_typ_sqlalchemy` exits early""" + _update_args_infer_typ_sqlalchemy: Callable[ + [dict], Tuple[bool, Optional[Union[List[AST], Tuple[AST]]]] + ] = partial( + update_args_infer_typ_sqlalchemy, + args=[], + name="", + nullable=True, + x_typ_sql={}, + ) + self.assertTupleEqual( + _update_args_infer_typ_sqlalchemy({"typ": None}), (False, None) + ) + self.assertTupleEqual( + _update_args_infer_typ_sqlalchemy( + {"typ": None, "default": cdd.shared.ast_utils.NoneStr}, + ), + (True, None), + ) + def test_update_with_imports_from_columns(self) -> None: """ Tests basic `cdd.sqlalchemy.utils.emit_utils.update_with_imports_from_columns` usage @@ -573,5 +600,18 @@ def test_rewrite_fk(self) -> None: gold=column_fk_gold, ) + def test__handle_union_of_length_2(self) -> None: + """Tests that `_handle_union_of_length_2` works""" + run_ast_test( + self, + gen_ast=_handle_union_of_length_2("Union[int, float]"), + gold=Name( + "Float", + Load(), + lineno=None, + col_offset=None, + ), + ) + unittest_main()