Skip to content

Commit

Permalink
[cdd/shared/ast_utils.py] Fix Literal support in get_types ; remo…
Browse files Browse the repository at this point in the history
…ve unnecessary type restriction in `infer_imports` and fix its `type_comment` implementation ; [cdd/tests/test_shared/test_ast_utils.py] Increase test coverage of this file to 100% ; [cdd/__init__.py] Bump version
  • Loading branch information
SamuelMarks committed Mar 17, 2024
1 parent 04b7ad4 commit fc4eea2
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 19 deletions.
2 changes: 1 addition & 1 deletion cdd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from logging import getLogger as get_logger

__author__ = "Samuel Marks" # type: str
__version__ = "0.0.99rc44" # type: str
__version__ = "0.0.99rc45" # type: str
__description__ = (
"Open API to/fro routes, models, and tests. "
"Convert between docstrings, classes, methods, argparse, pydantic, and SQLalchemy."
Expand Down
42 changes: 30 additions & 12 deletions cdd/shared/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
from json import dumps
from operator import attrgetter, contains, inv, itemgetter, neg, not_, pos
from os import path
from typing import FrozenSet, Generator, Optional
from typing import Callable, FrozenSet, Generator, MutableSet, Optional
from typing import Tuple as TTuple
from typing import __all__ as typing__all__

import cdd.shared.source_transformer
Expand Down Expand Up @@ -2215,7 +2216,14 @@ def get_types(node):
return iter((node.value.id, node.slice.id))
elif isinstance(node.slice, Tuple):
return chain.from_iterable(
((node.value.id,), map(get_value, map(get_value, node.slice.elts)))
(
(node.value.id,),
(
iter(())
if node.value.id == "Literal"
else map(get_value, map(get_value, node.slice.elts))
),
)
)


Expand All @@ -2228,16 +2236,16 @@ def infer_imports(module, modules_to_all=DEFAULT_MODULES_TO_ALL):
- sqlalchemy
- pydantic
:param module: Module, ClassDef, FunctionDef, AsyncFunctionDef, Assign
:type module: ```Union[ClassDef, FunctionDef, AsyncFunctionDef, Assign]```
:param module: Module, ClassDef, FunctionDef, AsyncFunctionDef, Assign, AnnAssign
:type module: ```Union[Module, ClassDef, FunctionDef, AsyncFunctionDef, Assign, AnnAssign]```
:param modules_to_all: Tuple of module_name to __all__ of module; (str) to FrozenSet[str]
:type modules_to_all: ```tuple[tuple[str, frozenset], ...]```
:return: List of imports
:rtype: ```Optional[Tuple[Union[Import, ImportFrom]]]```
:rtype: ```Optional[Tuple[Union[Import, ImportFrom], ...]]```
"""
if isinstance(module, (ClassDef, FunctionDef, AsyncFunctionDef, Assign)):
if isinstance(module, (ClassDef, FunctionDef, AsyncFunctionDef, Assign, AnnAssign)):
module: Module = Module(body=[module], type_ignores=[], stmt=None)
assert isinstance(module, Module), "Expected `Module` got `{type_name}`".format(
type_name=type(module).__name__
Expand All @@ -2252,7 +2260,13 @@ def node_to_importable_name(node):
:rtype: ```Optional[str]```
"""
if getattr(node, "type_comment", None) is not None:
return node.type_comment
return (
node.type_comment
if node.type_comment in simple_types
else get_value(
get_value(get_value(ast.parse(node.type_comment).body[0]))
)
)
elif getattr(node, "annotation", None) is not None:
node = node # type: Union[AnnAssign, arg]
return node.annotation # cast(node, Union[AnnAssign, arg])
Expand All @@ -2261,7 +2275,9 @@ def node_to_importable_name(node):
else:
return None

_symbol_to_import = partial(symbol_to_import, modules_to_all=modules_to_all)
_symbol_to_import: Callable[[str], Optional[TTuple[str, str]]] = partial(
symbol_to_import, modules_to_all=modules_to_all
)

# Lots of room for optimisation here; but its probably NP-hard:
imports = tuple(
Expand Down Expand Up @@ -2352,8 +2368,10 @@ def deduplicate_sorted_imports(module):
:return: Module but with duplicate import entries in first import block removed
:rtype: ```Module```
"""
assert isinstance(module, Module)
fst_import_idx = next(
assert isinstance(module, Module), "Expected `Module` got `{}`".format(
type(module).__name__
)
fst_import_idx: Optional[int] = next(
map(
itemgetter(0),
filter(
Expand All @@ -2365,7 +2383,7 @@ def deduplicate_sorted_imports(module):
)
if fst_import_idx is None:
return module
lst_import_idx = next(
lst_import_idx: Optional[int] = next(
iter(
deque(
map(
Expand All @@ -2380,7 +2398,7 @@ def deduplicate_sorted_imports(module):
),
None,
)
name_seen = set()
name_seen: MutableSet[str] = set()

module.body = (
module.body[:fst_import_idx]
Expand Down
121 changes: 115 additions & 6 deletions cdd/tests/test_shared/test_ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
arguments,
keyword,
)
from collections import deque
from copy import deepcopy
from itertools import repeat
from os import extsep, path
Expand Down Expand Up @@ -88,6 +89,7 @@
function_adder_ast,
function_adder_str,
)
from cdd.tests.mocks.pydantic import pydantic_class_cls_def
from cdd.tests.mocks.sqlalchemy import config_decl_base_ast
from cdd.tests.utils_for_tests import inspectable_compile, run_ast_test, unittest_main

Expand Down Expand Up @@ -432,7 +434,7 @@ def test_infer_imports_with_sqlalchemy(self) -> None:
"""
imports = infer_imports(
config_decl_base_ast
) # type: Optional[Tuple[Union[Import, ImportFrom]]]
) # type: Optional[TTuple[Union[Import, ImportFrom], ...]]
self.assertIsNotNone(imports)
self.assertEqual(len(imports), 1)
run_ast_test(
Expand All @@ -455,6 +457,57 @@ def test_infer_imports_with_sqlalchemy(self) -> None:
),
)

def test_infer_imports_with_simple_node_variants(self) -> None:
"""
Test that `infer_imports` with some simple variants
"""

def inner_test(imports):
"""
Run the actual test
:param imports: The imports to compare against
:type imports: ```TList[ImportFrom]```
"""
self.assertIsNotNone(imports)
self.assertEqual(len(imports), 1)
run_ast_test(
self,
imports[0],
ImportFrom(
module="typing" if PY_GTE_3_8 else "typing_extensions",
names=[
alias(
"Literal",
None,
identifier=None,
identifier_name=None,
)
],
level=0,
),
)

deque(
map(
inner_test,
map(
infer_imports,
(
pydantic_class_cls_def,
Assign(
targets=[Name("a", Load(), lineno=None, col_offset=None)],
value=set_value("cat"),
type_comment="Literal['cat']",
expr=None,
lineno=None,
),
),
),
),
maxlen=0,
)

def test_node_to_dict(self) -> None:
"""
Tests `node_to_dict`
Expand Down Expand Up @@ -642,6 +695,7 @@ def test_get_value(self) -> None:
)
self.assertIsNone(get_value(Name(None, None)))
self.assertEqual(get_value(get_value(ast.parse("-5").body[0])), -5)
self.assertEqual(get_value(Num(n=-5, constant_value=None, string=None)), -5)

def test_set_value(self) -> None:
"""Tests that `set_value` returns the right type for the right Python version"""
Expand Down Expand Up @@ -749,21 +803,76 @@ def test_find_ast_type_fails(self) -> None:

def test_get_types(self) -> None:
"""Test that `get_types` functions correctly"""
self.assertTupleEqual(tuple(get_types(None)), tuple())
self.assertTupleEqual(tuple(get_types("str")), ("str",))
self.assertTupleEqual(
tuple(
get_types(
Subscript(
value=Name(
id="Optional", ctx=Load(), lineno=None, col_offset=None
),
slice=Name(id="Any", ctx=Load(), lineno=None, col_offset=None),
ctx=Load(),
expr_context_ctx=None,
expr_slice=None,
expr_value=None,
lineno=None,
col_offset=None,
)
)
),
("Optional", "Any"),
)
self.assertTupleEqual(
tuple(get_types("str")),
("str",),
tuple(
get_types(
Subscript(
value=Name(
id="Literal", ctx=Load(), lineno=None, col_offset=None
),
slice=Tuple(
elts=list(map(set_value, ("foo", "bar"))),
ctx=Load(),
expr=None,
lineno=None,
col_offset=None,
),
ctx=Load(),
expr_context_ctx=None,
expr_slice=None,
expr_value=None,
lineno=None,
col_offset=None,
)
)
),
("Literal",),
)
self.assertTupleEqual(
tuple(
get_types(
Subscript(
value=Name(id="Optional", ctx=Load()),
slice=Name(id="Any", ctx=Load()),
value=Name(
id="Tuple", ctx=Load(), lineno=None, col_offset=None
),
slice=Tuple(
elts=list(map(set_value, ("int", "float"))),
ctx=Load(),
expr=None,
lineno=None,
col_offset=None,
),
ctx=Load(),
expr_context_ctx=None,
expr_slice=None,
expr_value=None,
lineno=None,
col_offset=None,
)
)
),
("Optional", "Any"),
("Tuple", "int", "float"),
)

def test_to_named_class_def(self) -> None:
Expand Down

0 comments on commit fc4eea2

Please sign in to comment.