Skip to content

Commit

Permalink
track if typing.TYPE_CHECKING to warn about non runtime bindings
Browse files Browse the repository at this point in the history
When importing or defining values in ``if typing.TYPE_CHECKING`` blocks
the bound names will not be available at runtime and may cause errors
when used in the following way::

  import typing

  if typing.TYPE_CHECKING:
    from module import Type  # some slow import or circular reference

  def method(value) -> Type:  # the import is needed by the type checker
    assert isinstance(value, Type)  # this is a runtime error

This change allows pyflakes to track what names are bound for runtime
use, and allows it to warn when a non runtime name is used in a runtime
context.
  • Loading branch information
terencehonles committed Mar 23, 2021
1 parent 40e6dc2 commit a795822
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 29 deletions.
116 changes: 87 additions & 29 deletions pyflakes/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,11 @@ class Binding(object):
the node that this binding was last used.
"""

def __init__(self, name, source):
def __init__(self, name, source, runtime=True):
self.name = name
self.source = source
self.used = False
self.runtime = runtime

def __str__(self):
return self.name
Expand Down Expand Up @@ -391,10 +392,10 @@ class Importation(Definition):
@type fullName: C{str}
"""

def __init__(self, name, source, full_name=None):
def __init__(self, name, source, full_name=None, runtime=True):
self.fullName = full_name or name
self.redefined = []
super(Importation, self).__init__(name, source)
super(Importation, self).__init__(name, source, runtime=runtime)

def redefines(self, other):
if isinstance(other, SubmoduleImportation):
Expand Down Expand Up @@ -439,11 +440,12 @@ class SubmoduleImportation(Importation):
name is also the same, to avoid false positives.
"""

def __init__(self, name, source):
def __init__(self, name, source, runtime=True):
# A dot should only appear in the name when it is a submodule import
assert '.' in name and (not source or isinstance(source, ast.Import))
package_name = name.split('.')[0]
super(SubmoduleImportation, self).__init__(package_name, source)
super(SubmoduleImportation, self).__init__(
package_name, source, runtime=runtime)
self.fullName = name

def redefines(self, other):
Expand All @@ -461,7 +463,8 @@ def source_statement(self):

class ImportationFrom(Importation):

def __init__(self, name, source, module, real_name=None):
def __init__(
self, name, source, module, real_name=None, runtime=True):
self.module = module
self.real_name = real_name or name

Expand All @@ -470,7 +473,8 @@ def __init__(self, name, source, module, real_name=None):
else:
full_name = module + '.' + self.real_name

super(ImportationFrom, self).__init__(name, source, full_name)
super(ImportationFrom, self).__init__(
name, source, full_name, runtime=runtime)

def __str__(self):
"""Return import full name with alias."""
Expand All @@ -492,8 +496,8 @@ def source_statement(self):
class StarImportation(Importation):
"""A binding created by a 'from x import *' statement."""

def __init__(self, name, source):
super(StarImportation, self).__init__('*', source)
def __init__(self, name, source, runtime=True):
super(StarImportation, self).__init__('*', source, runtime=runtime)
# Each star importation needs a unique name, and
# may not be the module name otherwise it will be deemed imported
self.name = name + '.*'
Expand Down Expand Up @@ -572,7 +576,7 @@ class ExportBinding(Binding):
C{__all__} will not have an unused import warning reported for them.
"""

def __init__(self, name, source, scope):
def __init__(self, name, source, scope, runtime=True):
if '__all__' in scope and isinstance(source, ast.AugAssign):
self.names = list(scope['__all__'].names)
else:
Expand Down Expand Up @@ -603,7 +607,7 @@ def _add_to_names(container):
# If not list concatenation
else:
break
super(ExportBinding, self).__init__(name, source)
super(ExportBinding, self).__init__(name, source, runtime=runtime)


class Scope(dict):
Expand Down Expand Up @@ -867,6 +871,7 @@ class Checker(object):
traceTree = False
_in_annotation = AnnotationState.NONE
_in_deferred = False
_in_type_check_guard = False

builtIns = set(builtin_vars).union(_MAGIC_GLOBALS)
_customBuiltIns = os.environ.get('PYFLAKES_BUILTINS')
Expand Down Expand Up @@ -1140,9 +1145,11 @@ def addBinding(self, node, value):
# then assume the rebound name is used as a global or within a loop
value.used = self.scope[value.name].used

# don't treat annotations as assignments if there is an existing value
# in scope
if value.name not in self.scope or not isinstance(value, Annotation):
# always allow the first assignment or if not already a runtime value,
# but do not shadow an existing assignment with an annotation or non
# runtime value.
if (not existing or not existing.runtime or (
not isinstance(value, Annotation) and value.runtime)):
self.scope[value.name] = value

def _unknown_handler(self, node):
Expand Down Expand Up @@ -1201,12 +1208,18 @@ def handleNodeLoad(self, node):
self.report(messages.InvalidPrintSyntax, node)

try:
scope[name].used = (self.scope, node)
n = scope[name]
if (not n.runtime and not (
self._in_type_check_guard
or self._in_annotation)):
self.report(messages.UndefinedName, node, name)
return

n.used = (self.scope, node)

# if the name of SubImportation is same as
# alias of other Importation and the alias
# is used, SubImportation also should be marked as used.
n = scope[name]
if isinstance(n, Importation) and n._has_alias():
try:
scope[n.fullName].used = (self.scope, node)
Expand Down Expand Up @@ -1269,18 +1282,20 @@ def handleNodeStore(self, node):
break

parent_stmt = self.getParent(node)
runtime = not self._in_type_check_guard
if isinstance(parent_stmt, ANNASSIGN_TYPES) and parent_stmt.value is None:
binding = Annotation(name, node)
binding = Annotation(name, node, runtime=runtime)
elif isinstance(parent_stmt, (FOR_TYPES, ast.comprehension)) or (
parent_stmt != node._pyflakes_parent and
not self.isLiteralTupleUnpacking(parent_stmt)):
binding = Binding(name, node)
binding = Binding(name, node, runtime=runtime)
elif name == '__all__' and isinstance(self.scope, ModuleScope):
binding = ExportBinding(name, node._pyflakes_parent, self.scope)
binding = ExportBinding(
name, node._pyflakes_parent, self.scope, runtime=runtime)
elif PY2 and isinstance(getattr(node, 'ctx', None), ast.Param):
binding = Argument(name, self.getScopeNode(node))
binding = Argument(name, self.getScopeNode(node), runtime=runtime)
else:
binding = Assignment(name, node)
binding = Assignment(name, node, runtime=runtime)
self.addBinding(node, binding)

def handleNodeDelete(self, node):
Expand Down Expand Up @@ -1969,7 +1984,40 @@ def DICT(self, node):
def IF(self, node):
if isinstance(node.test, ast.Tuple) and node.test.elts != []:
self.report(messages.IfTuple, node)
self.handleChildren(node)

self._handle_type_comments(node)
self.handleNode(node.test, node)

# check if the body/orelse should be handled specially because it is
# a if TYPE_CHECKING guard.
test = node.test
reverse = False
if isinstance(test, ast.UnaryOp) and isinstance(test.op, ast.Not):
test = test.operand
reverse = True

type_checking = _is_typing(test, 'TYPE_CHECKING', self.scopeStack)
orig = self._in_type_check_guard

# normalize body and orelse to a list
body, orelse = (
i if isinstance(i, list) else [i]
for i in (node.body, node.orelse))

# set the guard and handle the body
if type_checking and not reverse:
self._in_type_check_guard = True

for n in body:
self.handleNode(n, node)

# set the guard and handle the orelse
if type_checking:
self._in_type_check_guard = True if reverse else orig

for n in orelse:
self.handleNode(n, node)
self._in_type_check_guard = orig

IFEXP = IF

Expand Down Expand Up @@ -2092,7 +2140,10 @@ def FUNCTIONDEF(self, node):
for deco in node.decorator_list:
self.handleNode(deco, node)
self.LAMBDA(node)
self.addBinding(node, FunctionDefinition(node.name, node))
self.addBinding(
node,
FunctionDefinition(
node.name, node, runtime=not self._in_type_check_guard))
# doctest does not process doctest within a doctest,
# or in nested functions.
if (self.withDoctest and
Expand Down Expand Up @@ -2217,7 +2268,10 @@ def CLASSDEF(self, node):
for stmt in node.body:
self.handleNode(stmt, node)
self.popScope()
self.addBinding(node, ClassDefinition(node.name, node))
self.addBinding(
node,
ClassDefinition(
node.name, node, runtime=not self._in_type_check_guard))

def AUGASSIGN(self, node):
self.handleNodeLoad(node.target)
Expand Down Expand Up @@ -2250,12 +2304,15 @@ def TUPLE(self, node):
LIST = TUPLE

def IMPORT(self, node):
runtime = not self._in_type_check_guard
for alias in node.names:
if '.' in alias.name and not alias.asname:
importation = SubmoduleImportation(alias.name, node)
importation = SubmoduleImportation(
alias.name, node, runtime=runtime)
else:
name = alias.asname or alias.name
importation = Importation(name, node, alias.name)
importation = Importation(
name, node, alias.name, runtime=runtime)
self.addBinding(node, importation)

def IMPORTFROM(self, node):
Expand All @@ -2268,6 +2325,7 @@ def IMPORTFROM(self, node):

module = ('.' * node.level) + (node.module or '')

runtime = not self._in_type_check_guard
for alias in node.names:
name = alias.asname or alias.name
if node.module == '__future__':
Expand All @@ -2286,10 +2344,10 @@ def IMPORTFROM(self, node):

self.scope.importStarred = True
self.report(messages.ImportStarUsed, node, module)
importation = StarImportation(module, node)
importation = StarImportation(module, node, runtime=runtime)
else:
importation = ImportationFrom(name, node,
module, alias.name)
importation = ImportationFrom(
name, node, module, alias.name, runtime=runtime)
self.addBinding(node, importation)

def TRY(self, node):
Expand Down
68 changes: 68 additions & 0 deletions pyflakes/test/test_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,57 @@ def f(): # type: () -> T
pass
""")

@skipIf(version_info < (3,), 'new in Python 3')
def test_typing_guard_import(self):
# T is imported for runtime use
self.flakes("""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from t import T
def f(x) -> T:
from t import T
assert isinstance(x, T)
return x
""")
# T is defined at runtime in one side of the if/else block
self.flakes("""
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from t import T
else:
T = object
if not TYPE_CHECKING:
U = object
else:
from t import U
def f(x) -> Union[T, U]:
assert isinstance(x, (T, U))
return x
""")

@skipIf(version_info < (3,), 'new in Python 3')
def test_typing_guard_import_runtime_error(self):
# T and U are not bound for runtime use
self.flakes("""
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from t import T
class U:
pass
def f(x) -> Union[T, U]:
assert isinstance(x, (T, U))
return x
""", m.UndefinedName, m.UndefinedName)

def test_typing_guard_for_protocol(self):
self.flakes("""
from typing import TYPE_CHECKING
Expand All @@ -696,6 +747,23 @@ def f(): # type: () -> int
pass
""")

def test_typing_guard_with_elif_branch(self):
# This test will not raise an error even though Protocol is not
# defined outside TYPE_CHECKING because Pyflakes does not do case
# analysis.
self.flakes("""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Protocol
elif False:
Protocol = object
else:
pass
class C(Protocol):
def f(): # type: () -> int
pass
""")

def test_typednames_correct_forward_ref(self):
self.flakes("""
from typing import TypedDict, List, NamedTuple
Expand Down

0 comments on commit a795822

Please sign in to comment.