Skip to content

Commit

Permalink
track if typing.TYPE_CHECKING to warn about non runtime bindings
Browse files Browse the repository at this point in the history
When importing or defining values in ``if typing.TYPE_CHECKING`` blocks
the bound names will not be available at runtime and may cause errors
when used in the following way::

  import typing

  if typing.TYPE_CHECKING:
    from module import Type  # some slow import or circular reference

  def method(value) -> Type:  # the import is needed by the type checker
    assert isinstance(value, Type)  # this is a runtime error

This change allows pyflakes to track what names are bound for runtime
use, and allows it to warn when a non runtime name is used in a runtime
context.
  • Loading branch information
terencehonles committed Mar 31, 2023
1 parent e19886e commit 58506b3
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 30 deletions.
115 changes: 85 additions & 30 deletions pyflakes/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,11 @@ class Binding:
the node that this binding was last used.
"""

def __init__(self, name, source):
def __init__(self, name, source, runtime=True):
self.name = name
self.source = source
self.used = False
self.runtime = runtime

def __str__(self):
return self.name
Expand Down Expand Up @@ -273,8 +274,8 @@ def redefines(self, other):
class Builtin(Definition):
"""A definition created for all Python builtins."""

def __init__(self, name):
super().__init__(name, None)
def __init__(self, name, runtime=True):
super().__init__(name, None, runtime=runtime)

def __repr__(self):
return '<{} object {!r} at 0x{:x}>'.format(
Expand Down Expand Up @@ -318,10 +319,10 @@ class Importation(Definition):
@type fullName: C{str}
"""

def __init__(self, name, source, full_name=None):
def __init__(self, name, source, full_name=None, runtime=True):
self.fullName = full_name or name
self.redefined = []
super().__init__(name, source)
super().__init__(name, source, runtime=runtime)

def redefines(self, other):
if isinstance(other, SubmoduleImportation):
Expand Down Expand Up @@ -366,11 +367,11 @@ class SubmoduleImportation(Importation):
name is also the same, to avoid false positives.
"""

def __init__(self, name, source):
def __init__(self, name, source, runtime=True):
# A dot should only appear in the name when it is a submodule import
assert '.' in name and (not source or isinstance(source, ast.Import))
package_name = name.split('.')[0]
super().__init__(package_name, source)
super().__init__(package_name, source, runtime=runtime)
self.fullName = name

def redefines(self, other):
Expand All @@ -388,7 +389,8 @@ def source_statement(self):

class ImportationFrom(Importation):

def __init__(self, name, source, module, real_name=None):
def __init__(
self, name, source, module, real_name=None, runtime=True):
self.module = module
self.real_name = real_name or name

Expand All @@ -397,7 +399,7 @@ def __init__(self, name, source, module, real_name=None):
else:
full_name = module + '.' + self.real_name

super().__init__(name, source, full_name)
super().__init__(name, source, full_name, runtime=runtime)

def __str__(self):
"""Return import full name with alias."""
Expand All @@ -417,8 +419,8 @@ def source_statement(self):
class StarImportation(Importation):
"""A binding created by a 'from x import *' statement."""

def __init__(self, name, source):
super().__init__('*', source)
def __init__(self, name, source, runtime=True):
super().__init__('*', source, runtime=runtime)
# Each star importation needs a unique name, and
# may not be the module name otherwise it will be deemed imported
self.name = name + '.*'
Expand Down Expand Up @@ -507,7 +509,7 @@ class ExportBinding(Binding):
C{__all__} will not have an unused import warning reported for them.
"""

def __init__(self, name, source, scope):
def __init__(self, name, source, scope, runtime=True):
if '__all__' in scope and isinstance(source, ast.AugAssign):
self.names = list(scope['__all__'].names)
else:
Expand Down Expand Up @@ -538,7 +540,7 @@ def _add_to_names(container):
# If not list concatenation
else:
break
super().__init__(name, source)
super().__init__(name, source, runtime=runtime)


class Scope(dict):
Expand Down Expand Up @@ -741,6 +743,7 @@ class Checker:
nodeDepth = 0
offset = None
_in_annotation = AnnotationState.NONE
_in_type_check_guard = False

builtIns = set(builtin_vars).union(_MAGIC_GLOBALS)
_customBuiltIns = os.environ.get('PYFLAKES_BUILTINS')
Expand Down Expand Up @@ -1009,9 +1012,11 @@ def addBinding(self, node, value):
# then assume the rebound name is used as a global or within a loop
value.used = self.scope[value.name].used

# don't treat annotations as assignments if there is an existing value
# in scope
if value.name not in self.scope or not isinstance(value, Annotation):
# always allow the first assignment or if not already a runtime value,
# but do not shadow an existing assignment with an annotation or non
# runtime value.
if (not existing or not existing.runtime or (
not isinstance(value, Annotation) and value.runtime)):
cur_scope_pos = -1
# As per PEP 572, use scope in which outermost generator is defined
while (
Expand Down Expand Up @@ -1077,12 +1082,18 @@ def handleNodeLoad(self, node, parent):
self.report(messages.InvalidPrintSyntax, node)

try:
scope[name].used = (self.scope, node)
n = scope[name]
if (not n.runtime and not (
self._in_type_check_guard
or self._in_annotation)):
self.report(messages.UndefinedName, node, name)
return

n.used = (self.scope, node)

# if the name of SubImportation is same as
# alias of other Importation and the alias
# is used, SubImportation also should be marked as used.
n = scope[name]
if isinstance(n, Importation) and n._has_alias():
try:
scope[n.fullName].used = (self.scope, node)
Expand Down Expand Up @@ -1145,12 +1156,13 @@ def handleNodeStore(self, node):
break

parent_stmt = self.getParent(node)
runtime = not self._in_type_check_guard
if isinstance(parent_stmt, ast.AnnAssign) and parent_stmt.value is None:
binding = Annotation(name, node)
elif isinstance(parent_stmt, (FOR_TYPES, ast.comprehension)) or (
parent_stmt != node._pyflakes_parent and
not self.isLiteralTupleUnpacking(parent_stmt)):
binding = Binding(name, node)
binding = Binding(name, node, runtime=runtime)
elif (
name == '__all__' and
isinstance(self.scope, ModuleScope) and
Expand All @@ -1159,11 +1171,12 @@ def handleNodeStore(self, node):
(ast.Assign, ast.AugAssign, ast.AnnAssign)
)
):
binding = ExportBinding(name, node._pyflakes_parent, self.scope)
binding = ExportBinding(
name, node._pyflakes_parent, self.scope, runtime=runtime)
elif isinstance(parent_stmt, ast.NamedExpr):
binding = NamedExprAssignment(name, node)
binding = NamedExprAssignment(name, node, runtime=runtime)
else:
binding = Assignment(name, node)
binding = Assignment(name, node, runtime=runtime)
self.addBinding(node, binding)

def handleNodeDelete(self, node):
Expand Down Expand Up @@ -1791,7 +1804,39 @@ def DICT(self, node):
def IF(self, node):
if isinstance(node.test, ast.Tuple) and node.test.elts != []:
self.report(messages.IfTuple, node)
self.handleChildren(node)

self.handleNode(node.test, node)

# check if the body/orelse should be handled specially because it is
# a if TYPE_CHECKING guard.
test = node.test
reverse = False
if isinstance(test, ast.UnaryOp) and isinstance(test.op, ast.Not):
test = test.operand
reverse = True

type_checking = _is_typing(test, 'TYPE_CHECKING', self.scopeStack)
orig = self._in_type_check_guard

# normalize body and orelse to a list
body, orelse = (
i if isinstance(i, list) else [i]
for i in (node.body, node.orelse))

# set the guard and handle the body
if type_checking and not reverse:
self._in_type_check_guard = True

for n in body:
self.handleNode(n, node)

# set the guard and handle the orelse
if type_checking:
self._in_type_check_guard = True if reverse else orig

for n in orelse:
self.handleNode(n, node)
self._in_type_check_guard = orig

IFEXP = IF

Expand Down Expand Up @@ -1903,7 +1948,10 @@ def FUNCTIONDEF(self, node):
for deco in node.decorator_list:
self.handleNode(deco, node)
self.LAMBDA(node)
self.addBinding(node, FunctionDefinition(node.name, node))
self.addBinding(
node,
FunctionDefinition(
node.name, node, runtime=not self._in_type_check_guard))
# doctest does not process doctest within a doctest,
# or in nested functions.
if (self.withDoctest and
Expand Down Expand Up @@ -1982,7 +2030,10 @@ def CLASSDEF(self, node):
self.deferFunction(lambda: self.handleDoctests(node))
for stmt in node.body:
self.handleNode(stmt, node)
self.addBinding(node, ClassDefinition(node.name, node))
self.addBinding(
node,
ClassDefinition(
node.name, node, runtime=not self._in_type_check_guard))

def AUGASSIGN(self, node):
self.handleNodeLoad(node.target, node)
Expand Down Expand Up @@ -2015,12 +2066,15 @@ def TUPLE(self, node):
LIST = TUPLE

def IMPORT(self, node):
runtime = not self._in_type_check_guard
for alias in node.names:
if '.' in alias.name and not alias.asname:
importation = SubmoduleImportation(alias.name, node)
importation = SubmoduleImportation(
alias.name, node, runtime=runtime)
else:
name = alias.asname or alias.name
importation = Importation(name, node, alias.name)
importation = Importation(
name, node, alias.name, runtime=runtime)
self.addBinding(node, importation)

def IMPORTFROM(self, node):
Expand All @@ -2032,6 +2086,7 @@ def IMPORTFROM(self, node):

module = ('.' * node.level) + (node.module or '')

runtime = not self._in_type_check_guard
for alias in node.names:
name = alias.asname or alias.name
if node.module == '__future__':
Expand All @@ -2049,10 +2104,10 @@ def IMPORTFROM(self, node):

self.scope.importStarred = True
self.report(messages.ImportStarUsed, node, module)
importation = StarImportation(module, node)
importation = StarImportation(module, node, runtime=runtime)
else:
importation = ImportationFrom(name, node,
module, alias.name)
importation = ImportationFrom(
name, node, module, alias.name, runtime=runtime)
self.addBinding(node, importation)

def TRY(self, node):
Expand Down
68 changes: 68 additions & 0 deletions pyflakes/test/test_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,57 @@ def f() -> T:
pass
""")

@skipIf(version_info < (3,), 'new in Python 3')
def test_typing_guard_import(self):
# T is imported for runtime use
self.flakes("""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from t import T
def f(x) -> T:
from t import T
assert isinstance(x, T)
return x
""")
# T is defined at runtime in one side of the if/else block
self.flakes("""
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from t import T
else:
T = object
if not TYPE_CHECKING:
U = object
else:
from t import U
def f(x) -> Union[T, U]:
assert isinstance(x, (T, U))
return x
""")

@skipIf(version_info < (3,), 'new in Python 3')
def test_typing_guard_import_runtime_error(self):
# T and U are not bound for runtime use
self.flakes("""
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from t import T
class U:
pass
def f(x) -> Union[T, U]:
assert isinstance(x, (T, U))
return x
""", m.UndefinedName, m.UndefinedName)

def test_typing_guard_for_protocol(self):
self.flakes("""
from typing import TYPE_CHECKING
Expand All @@ -659,6 +710,23 @@ def f() -> int:
pass
""")

def test_typing_guard_with_elif_branch(self):
# This test will not raise an error even though Protocol is not
# defined outside TYPE_CHECKING because Pyflakes does not do case
# analysis.
self.flakes("""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Protocol
elif False:
Protocol = object
else:
pass
class C(Protocol):
def f(): # type: () -> int
pass
""")

def test_typednames_correct_forward_ref(self):
self.flakes("""
from typing import TypedDict, List, NamedTuple
Expand Down

0 comments on commit 58506b3

Please sign in to comment.