Skip to content

Commit

Permalink
Fixing several bugs where return variables cant be parsed correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
Masara committed Aug 18, 2024
1 parent 159f5a2 commit 484b195
Show file tree
Hide file tree
Showing 10 changed files with 400 additions and 112 deletions.
191 changes: 114 additions & 77 deletions src/safeds_stubgen/api_analyzer/_ast_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
WildcardImport,
)
from ._mypy_helpers import (
find_return_stmts_recursive,
find_stmts_recursive,
get_argument_kind,
get_classdef_definitions,
get_funcdef_definitions,
Expand All @@ -58,6 +58,7 @@ def __init__(
aliases: dict[str, set[str]],
type_source_preference: TypeSourcePreference,
type_source_warning: TypeSourceWarning,
is_test_run: bool = False,
) -> None:
self.docstring_parser: AbstractDocstringParser = docstring_parser
self.type_source_preference = type_source_preference
Expand All @@ -68,6 +69,7 @@ def __init__(
self.mypy_file: mp_nodes.MypyFile | None = None
# We gather type var types used as a parameter type in a function
self.type_var_types: set[sds_types.TypeVarType] = set()
self.is_test_run = is_test_run

def enter_moduledef(self, node: mp_nodes.MypyFile) -> None:
self.mypy_file = node
Expand Down Expand Up @@ -202,8 +204,14 @@ def enter_classdef(self, node: mp_nodes.ClassDef) -> None:
):
inherits_from_exception = True

if hasattr(superclass, "fullname") or hasattr(superclass, "name"):
superclass_qname = getattr(superclass, "fullname", "") or getattr(superclass, "name", "")
if hasattr(superclass, "fullname"):
superclass_qname = superclass.fullname

if not superclass_qname and hasattr(superclass, "name"):
superclass_qname = superclass.name
if hasattr(superclass, "expr") and isinstance(superclass.expr, mp_nodes.NameExpr):
superclass_qname = f"{superclass.expr.name}.{superclass_qname}"

superclass_name = superclass_qname.split(".")[-1]

# Check if the superclass name is an alias and find the real name
Expand Down Expand Up @@ -428,10 +436,10 @@ def enter_assignmentstmt(self, node: mp_nodes.AssignmentStmt) -> None:
if hasattr(lvalue, "items"):
for item in lvalue.items:
names.append(item.name)
else:
if not hasattr(lvalue, "name"): # pragma: no cover
raise AttributeError("Expected lvalue to have attribtue 'name'.")
names.append(getattr(lvalue, "name", ""))
elif hasattr(lvalue, "name"):
names.append(lvalue.name)
else: # pragma: no cover
raise AttributeError("Expected lvalue to have attribtue 'name'.")

for name in names:
assignments.append(
Expand Down Expand Up @@ -589,91 +597,122 @@ def _remove_assignments(func_defn: list, type_: AbstractType) -> AbstractType:
in the function itself (assignment). If this is not the case, we can assume that they are imported or from
outside the funciton.
"""
if not isinstance(type_, sds_types.NamedType | sds_types.TupleType):
return type_

found_types = type_.types if isinstance(type_, sds_types.TupleType) else [type_]
actual_types: list[AbstractType] = []
if isinstance(type_, sds_types.NamedType | sds_types.TupleType):
if isinstance(type_, sds_types.TupleType):
found_types = type_.types
else:
found_types = [type_]
assignment_stmts = find_stmts_recursive(stmt_type=mp_nodes.AssignmentStmt, stmts=func_defn)

for found_type in found_types:
if isinstance(found_type, sds_types.NamedType):
is_assignment = False
found_type_name = found_type.name
for found_type in found_types:
if not isinstance(found_type, sds_types.NamedType):
continue

Check warning on line 609 in src/safeds_stubgen/api_analyzer/_ast_visitor.py

View check run for this annotation

Codecov / codecov/patch

src/safeds_stubgen/api_analyzer/_ast_visitor.py#L609

Added line #L609 was not covered by tests

for defn in func_defn:
if not isinstance(defn, mp_nodes.AssignmentStmt):
continue
is_assignment = False
found_type_name = found_type.name

for lvalue in defn.lvalues:
if isinstance(lvalue, mp_nodes.TupleExpr):
name_expressions = lvalue.items
else:
name_expressions = [lvalue]
for stmt in assignment_stmts:
if not isinstance(stmt, mp_nodes.AssignmentStmt): # pragma: no cover
continue

for expr in name_expressions:
if isinstance(expr, mp_nodes.NameExpr) and found_type_name == expr.name:
is_assignment = True
break
if is_assignment:
break
for lvalue in stmt.lvalues:
name_expressions = lvalue.items if isinstance(lvalue, mp_nodes.TupleExpr) else [lvalue]

if is_assignment:
for expr in name_expressions:
if isinstance(expr, mp_nodes.NameExpr) and found_type_name == expr.name:
is_assignment = True
break

if is_assignment:
actual_types.append(sds_types.UnknownType())
else:
actual_types.append(found_type)
break

if len(actual_types) > 1:
type_ = sds_types.TupleType(types=actual_types)
elif len(actual_types) == 1:
type_ = actual_types[0]
if is_assignment:
break

if is_assignment:
actual_types.append(sds_types.UnknownType())
else:
type_ = sds_types.UnknownType()
return type_
actual_types.append(found_type)

if len(actual_types) > 1:
return sds_types.TupleType(types=actual_types)

Check warning on line 637 in src/safeds_stubgen/api_analyzer/_ast_visitor.py

View check run for this annotation

Codecov / codecov/patch

src/safeds_stubgen/api_analyzer/_ast_visitor.py#L637

Added line #L637 was not covered by tests
elif len(actual_types) == 1:
return actual_types[0]
return sds_types.UnknownType()

Check warning on line 640 in src/safeds_stubgen/api_analyzer/_ast_visitor.py

View check run for this annotation

Codecov / codecov/patch

src/safeds_stubgen/api_analyzer/_ast_visitor.py#L640

Added line #L640 was not covered by tests

def _infer_type_from_return_stmts(self, func_node: mp_nodes.FuncDef) -> sds_types.TupleType | None:
# To infer the type, we iterate through all return statements we find in the function
# To infer the possible result types, we iterate through all return statements we find in the function
func_defn = get_funcdef_definitions(func_node)
return_stmts = find_return_stmts_recursive(func_defn)
if return_stmts:
types = set()
for return_stmt in return_stmts:
if return_stmt.expr is None: # pragma: no cover
return_stmts = find_stmts_recursive(mp_nodes.ReturnStmt, func_defn)
if not return_stmts:
return None

types = []
for return_stmt in return_stmts:
if not isinstance(return_stmt, mp_nodes.ReturnStmt): # pragma: no cover
continue

if return_stmt.expr is not None and hasattr(return_stmt.expr, "node"):
if isinstance(return_stmt.expr.node, mp_nodes.FuncDef | mp_nodes.Decorator):
# In this case we have an inner function which the outer function returns.
continue
if (
isinstance(return_stmt.expr.node, mp_nodes.Var)
and hasattr(return_stmt.expr, "name")
and return_stmt.expr.name in func_node.arg_names
and return_stmt.expr.node.type is not None
):
# In this case the return value is a parameter of the function
type_ = self.mypy_type_to_abstract_type(return_stmt.expr.node.type)
types.append(type_)
continue

if not isinstance(return_stmt.expr, mp_nodes.CallExpr | mp_nodes.MemberExpr):
if not isinstance(return_stmt.expr, mp_nodes.CallExpr | mp_nodes.MemberExpr):
if isinstance(return_stmt.expr, mp_nodes.ConditionalExpr):
# If the return statement is a conditional expression we parse the "if" and "else" branches
if isinstance(return_stmt.expr, mp_nodes.ConditionalExpr):
for conditional_branch in [return_stmt.expr.if_expr, return_stmt.expr.else_expr]:
if conditional_branch is None: # pragma: no cover
continue

if not isinstance(conditional_branch, mp_nodes.CallExpr | mp_nodes.MemberExpr):
type_ = mypy_expression_to_sds_type(conditional_branch)
if isinstance(type_, sds_types.NamedType | sds_types.TupleType):
types.add(type_)
elif hasattr(return_stmt.expr, "node") and getattr(return_stmt.expr.node, "is_self", False):
# The result type is an instance of the parent class
expr_type = return_stmt.expr.node.type.type
types.add(sds_types.NamedType(name=expr_type.name, qname=expr_type.fullname))
else:
type_ = mypy_expression_to_sds_type(return_stmt.expr)
type_ = self._remove_assignments(func_defn, type_)
for cond_branch in [return_stmt.expr.if_expr, return_stmt.expr.else_expr]:
if cond_branch is None: # pragma: no cover
continue

if isinstance(type_, sds_types.NamedType | sds_types.TupleType):
types.add(type_)
if not isinstance(cond_branch, mp_nodes.CallExpr | mp_nodes.MemberExpr):
if (
hasattr(cond_branch, "node")
and isinstance(cond_branch.node, mp_nodes.Var)
and cond_branch.node.type is not None
):
# In this case the return value is a parameter of the function
type_ = self.mypy_type_to_abstract_type(cond_branch.node.type)
else:
type_ = mypy_expression_to_sds_type(cond_branch)
types.append(type_)
elif (
return_stmt.expr is not None
and hasattr(return_stmt.expr, "node")
and getattr(return_stmt.expr.node, "is_self", False)
):
# The result type is an instance of the parent class
expr_type = return_stmt.expr.node.type.type
types.append(sds_types.NamedType(name=expr_type.name, qname=expr_type.fullname))

Check warning on line 694 in src/safeds_stubgen/api_analyzer/_ast_visitor.py

View check run for this annotation

Codecov / codecov/patch

src/safeds_stubgen/api_analyzer/_ast_visitor.py#L693-L694

Added lines #L693 - L694 were not covered by tests
elif isinstance(return_stmt.expr, mp_nodes.TupleExpr):
all_types = []
for item in return_stmt.expr.items:
if hasattr(item, "node") and isinstance(item.node, mp_nodes.Var) and item.node.type is not None:
# In this case the return value is a parameter of the function
type_ = self.mypy_type_to_abstract_type(item.node.type)
else:
type_ = mypy_expression_to_sds_type(item)
type_ = self._remove_assignments(func_defn, type_)
all_types.append(type_)
types.append(sds_types.TupleType(types=all_types))
else:
# Lastly, we have a mypy expression object, which we have to parse
if return_stmt.expr is None: # pragma: no cover
continue

# We have to sort the list for the snapshot tests
return_stmt_types = list(types)
return_stmt_types.sort(
key=lambda x: (x.name if isinstance(x, sds_types.NamedType) else str(len(x.types))),
)
type_ = mypy_expression_to_sds_type(return_stmt.expr)
type_ = self._remove_assignments(func_defn, type_)
types.append(type_)

return sds_types.TupleType(types=return_stmt_types)
return None
return sds_types.TupleType(types=types)

@staticmethod
def _create_inferred_results(
Expand Down Expand Up @@ -705,12 +744,12 @@ def _create_inferred_results(
result_array: list[list[AbstractType]] = []
longest_inner_list = 1
for type_ in results.types:
if isinstance(type_, sds_types.NamedType):
if not isinstance(type_, sds_types.TupleType):
if result_array:
result_array[0].append(type_)
else:
result_array.append([type_])
elif isinstance(type_, sds_types.TupleType):
else:
for i, type__ in enumerate(type_.types):
if len(result_array) > i:
if type__ not in result_array[i]:
Expand All @@ -720,8 +759,6 @@ def _create_inferred_results(
longest_inner_list = len(result_array[i])
else:
result_array.append([type__])
else: # pragma: no cover
raise TypeError(f"Expected NamedType or TupleType, received {type(type_)}")

# If there are any arrays longer than others, these "others" are optional types and can be None
none_element = sds_types.NamedType(name="None", qname="builtins.None")
Expand Down
1 change: 1 addition & 0 deletions src/safeds_stubgen/api_analyzer/_get_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def get_api(
aliases=aliases,
type_source_preference=type_source_preference,
type_source_warning=type_source_warning,
is_test_run=is_test_run,
)
walker = ASTWalker(handler=callable_visitor)

Expand Down
29 changes: 16 additions & 13 deletions src/safeds_stubgen/api_analyzer/_mypy_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,29 @@ def get_argument_kind(arg: mp_nodes.Argument) -> ParameterAssignment:
raise ValueError("Could not find an appropriate parameter assignment.")


def find_return_stmts_recursive(stmts: list[mp_nodes.Statement] | list[mp_nodes.Block]) -> list[mp_nodes.ReturnStmt]:
return_stmts = []
def find_stmts_recursive(
stmt_type: type[mp_nodes.Statement],
stmts: list[mp_nodes.Statement] | list[mp_nodes.Block],
) -> list[mp_nodes.Statement]:
found_stmts = []
for stmt in stmts:
if isinstance(stmt, mp_nodes.IfStmt):
return_stmts += find_return_stmts_recursive(stmt.body)
if isinstance(stmt, stmt_type):
found_stmts.append(stmt)
elif isinstance(stmt, mp_nodes.IfStmt):
found_stmts += find_stmts_recursive(stmt_type, stmt.body)
if stmt.else_body:
return_stmts += find_return_stmts_recursive(stmt.else_body.body)
found_stmts += find_stmts_recursive(stmt_type, stmt.else_body.body)
elif isinstance(stmt, mp_nodes.Block):
return_stmts += find_return_stmts_recursive(stmt.body)
found_stmts += find_stmts_recursive(stmt_type, stmt.body)
elif isinstance(stmt, mp_nodes.TryStmt):
return_stmts += find_return_stmts_recursive([stmt.body])
return_stmts += find_return_stmts_recursive(stmt.handlers)
found_stmts += find_stmts_recursive(stmt_type, [stmt.body])
found_stmts += find_stmts_recursive(stmt_type, stmt.handlers)
elif isinstance(stmt, mp_nodes.MatchStmt):
return_stmts += find_return_stmts_recursive(stmt.bodies)
found_stmts += find_stmts_recursive(stmt_type, stmt.bodies)
elif isinstance(stmt, mp_nodes.WhileStmt | mp_nodes.WithStmt | mp_nodes.ForStmt):
return_stmts += find_return_stmts_recursive(stmt.body.body)
elif isinstance(stmt, mp_nodes.ReturnStmt):
return_stmts.append(stmt)
found_stmts += find_stmts_recursive(stmt_type, stmt.body.body)

return return_stmts
return found_stmts


def mypy_variance_parser(mypy_variance_type: Literal[0, 1, 2]) -> VarianceKind:
Expand Down
4 changes: 4 additions & 0 deletions src/safeds_stubgen/stubs_generator/_generate_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ def _create_outside_package_class(
) -> set[str]:
path_parts = class_path.split(".")

# There are cases where we could not correctly parse or find the origin of a variable, which is then put into
# the imports. But since these variables have no qname and only consist of a name we cannot create seperate files
# for them.
# E.g.: `x: numpy.some_class; ...; return x` would have the result type parsed as just "numpy"
if len(path_parts) == 1:
return created_module_paths

Check warning on line 133 in src/safeds_stubgen/stubs_generator/_generate_stubs.py

View check run for this annotation

Codecov / codecov/patch

src/safeds_stubgen/stubs_generator/_generate_stubs.py#L133

Added line #L133 was not covered by tests

Expand Down
5 changes: 5 additions & 0 deletions tests/data/various_modules_package/class_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Self, overload
from . import unknown_source

from tests.data.main_package.another_path.another_module import yetAnotherClass

Expand Down Expand Up @@ -102,3 +103,7 @@ def stale(self):
@stale.setter
def stale(self, val):
self._stale = val


class ClassWithImportedSuperclasses(unknown_source.UnknownClass):
pass
32 changes: 32 additions & 0 deletions tests/data/various_modules_package/function_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable, Optional, Literal, Any
from tests.data.main_package.another_path.another_module import AnotherClass
import griffe as gr


class FunctionModuleClassA:
Expand Down Expand Up @@ -227,3 +228,34 @@ def _f(x: int, y: int) -> int:
def ignore_assignment2(a: int, b: int):
Cxy = 3**2
return Cxy


def return_inner_function():
def return_me():
return 123

return return_me


def return_param1(a):
return a


def return_param2(a: int):
return a


def return_param3(a: int, b, c: bool):
return a if b else c


def return_param4(a: int, b, x):
if x == 0:
return a, b, a, b

return True


def return_var():
locs: list[int] | gr.Alias
return locs
Loading

0 comments on commit 484b195

Please sign in to comment.