From 58506b34c28a50f84dc306fff168d1451334bf34 Mon Sep 17 00:00:00 2001 From: "Terence D. Honles" Date: Tue, 23 Mar 2021 13:34:22 -0700 Subject: [PATCH] track `if typing.TYPE_CHECKING` to warn about non runtime bindings When importing or defining values in ``if typing.TYPE_CHECKING`` blocks the bound names will not be available at runtime and may cause errors when used in the following way:: import typing if typing.TYPE_CHECKING: from module import Type # some slow import or circular reference def method(value) -> Type: # the import is needed by the type checker assert isinstance(value, Type) # this is a runtime error This change allows pyflakes to track what names are bound for runtime use, and allows it to warn when a non runtime name is used in a runtime context. --- pyflakes/checker.py | 115 ++++++++++++++++++------- pyflakes/test/test_type_annotations.py | 68 +++++++++++++++ 2 files changed, 153 insertions(+), 30 deletions(-) diff --git a/pyflakes/checker.py b/pyflakes/checker.py index e654afa3..5c538070 100644 --- a/pyflakes/checker.py +++ b/pyflakes/checker.py @@ -239,10 +239,11 @@ class Binding: the node that this binding was last used. """ - def __init__(self, name, source): + def __init__(self, name, source, runtime=True): self.name = name self.source = source self.used = False + self.runtime = runtime def __str__(self): return self.name @@ -273,8 +274,8 @@ def redefines(self, other): class Builtin(Definition): """A definition created for all Python builtins.""" - def __init__(self, name): - super().__init__(name, None) + def __init__(self, name, runtime=True): + super().__init__(name, None, runtime=runtime) def __repr__(self): return '<{} object {!r} at 0x{:x}>'.format( @@ -318,10 +319,10 @@ class Importation(Definition): @type fullName: C{str} """ - def __init__(self, name, source, full_name=None): + def __init__(self, name, source, full_name=None, runtime=True): self.fullName = full_name or name self.redefined = [] - super().__init__(name, source) + super().__init__(name, source, runtime=runtime) def redefines(self, other): if isinstance(other, SubmoduleImportation): @@ -366,11 +367,11 @@ class SubmoduleImportation(Importation): name is also the same, to avoid false positives. """ - def __init__(self, name, source): + def __init__(self, name, source, runtime=True): # A dot should only appear in the name when it is a submodule import assert '.' in name and (not source or isinstance(source, ast.Import)) package_name = name.split('.')[0] - super().__init__(package_name, source) + super().__init__(package_name, source, runtime=runtime) self.fullName = name def redefines(self, other): @@ -388,7 +389,8 @@ def source_statement(self): class ImportationFrom(Importation): - def __init__(self, name, source, module, real_name=None): + def __init__( + self, name, source, module, real_name=None, runtime=True): self.module = module self.real_name = real_name or name @@ -397,7 +399,7 @@ def __init__(self, name, source, module, real_name=None): else: full_name = module + '.' + self.real_name - super().__init__(name, source, full_name) + super().__init__(name, source, full_name, runtime=runtime) def __str__(self): """Return import full name with alias.""" @@ -417,8 +419,8 @@ def source_statement(self): class StarImportation(Importation): """A binding created by a 'from x import *' statement.""" - def __init__(self, name, source): - super().__init__('*', source) + def __init__(self, name, source, runtime=True): + super().__init__('*', source, runtime=runtime) # Each star importation needs a unique name, and # may not be the module name otherwise it will be deemed imported self.name = name + '.*' @@ -507,7 +509,7 @@ class ExportBinding(Binding): C{__all__} will not have an unused import warning reported for them. """ - def __init__(self, name, source, scope): + def __init__(self, name, source, scope, runtime=True): if '__all__' in scope and isinstance(source, ast.AugAssign): self.names = list(scope['__all__'].names) else: @@ -538,7 +540,7 @@ def _add_to_names(container): # If not list concatenation else: break - super().__init__(name, source) + super().__init__(name, source, runtime=runtime) class Scope(dict): @@ -741,6 +743,7 @@ class Checker: nodeDepth = 0 offset = None _in_annotation = AnnotationState.NONE + _in_type_check_guard = False builtIns = set(builtin_vars).union(_MAGIC_GLOBALS) _customBuiltIns = os.environ.get('PYFLAKES_BUILTINS') @@ -1009,9 +1012,11 @@ def addBinding(self, node, value): # then assume the rebound name is used as a global or within a loop value.used = self.scope[value.name].used - # don't treat annotations as assignments if there is an existing value - # in scope - if value.name not in self.scope or not isinstance(value, Annotation): + # always allow the first assignment or if not already a runtime value, + # but do not shadow an existing assignment with an annotation or non + # runtime value. + if (not existing or not existing.runtime or ( + not isinstance(value, Annotation) and value.runtime)): cur_scope_pos = -1 # As per PEP 572, use scope in which outermost generator is defined while ( @@ -1077,12 +1082,18 @@ def handleNodeLoad(self, node, parent): self.report(messages.InvalidPrintSyntax, node) try: - scope[name].used = (self.scope, node) + n = scope[name] + if (not n.runtime and not ( + self._in_type_check_guard + or self._in_annotation)): + self.report(messages.UndefinedName, node, name) + return + + n.used = (self.scope, node) # if the name of SubImportation is same as # alias of other Importation and the alias # is used, SubImportation also should be marked as used. - n = scope[name] if isinstance(n, Importation) and n._has_alias(): try: scope[n.fullName].used = (self.scope, node) @@ -1145,12 +1156,13 @@ def handleNodeStore(self, node): break parent_stmt = self.getParent(node) + runtime = not self._in_type_check_guard if isinstance(parent_stmt, ast.AnnAssign) and parent_stmt.value is None: binding = Annotation(name, node) elif isinstance(parent_stmt, (FOR_TYPES, ast.comprehension)) or ( parent_stmt != node._pyflakes_parent and not self.isLiteralTupleUnpacking(parent_stmt)): - binding = Binding(name, node) + binding = Binding(name, node, runtime=runtime) elif ( name == '__all__' and isinstance(self.scope, ModuleScope) and @@ -1159,11 +1171,12 @@ def handleNodeStore(self, node): (ast.Assign, ast.AugAssign, ast.AnnAssign) ) ): - binding = ExportBinding(name, node._pyflakes_parent, self.scope) + binding = ExportBinding( + name, node._pyflakes_parent, self.scope, runtime=runtime) elif isinstance(parent_stmt, ast.NamedExpr): - binding = NamedExprAssignment(name, node) + binding = NamedExprAssignment(name, node, runtime=runtime) else: - binding = Assignment(name, node) + binding = Assignment(name, node, runtime=runtime) self.addBinding(node, binding) def handleNodeDelete(self, node): @@ -1791,7 +1804,39 @@ def DICT(self, node): def IF(self, node): if isinstance(node.test, ast.Tuple) and node.test.elts != []: self.report(messages.IfTuple, node) - self.handleChildren(node) + + self.handleNode(node.test, node) + + # check if the body/orelse should be handled specially because it is + # a if TYPE_CHECKING guard. + test = node.test + reverse = False + if isinstance(test, ast.UnaryOp) and isinstance(test.op, ast.Not): + test = test.operand + reverse = True + + type_checking = _is_typing(test, 'TYPE_CHECKING', self.scopeStack) + orig = self._in_type_check_guard + + # normalize body and orelse to a list + body, orelse = ( + i if isinstance(i, list) else [i] + for i in (node.body, node.orelse)) + + # set the guard and handle the body + if type_checking and not reverse: + self._in_type_check_guard = True + + for n in body: + self.handleNode(n, node) + + # set the guard and handle the orelse + if type_checking: + self._in_type_check_guard = True if reverse else orig + + for n in orelse: + self.handleNode(n, node) + self._in_type_check_guard = orig IFEXP = IF @@ -1903,7 +1948,10 @@ def FUNCTIONDEF(self, node): for deco in node.decorator_list: self.handleNode(deco, node) self.LAMBDA(node) - self.addBinding(node, FunctionDefinition(node.name, node)) + self.addBinding( + node, + FunctionDefinition( + node.name, node, runtime=not self._in_type_check_guard)) # doctest does not process doctest within a doctest, # or in nested functions. if (self.withDoctest and @@ -1982,7 +2030,10 @@ def CLASSDEF(self, node): self.deferFunction(lambda: self.handleDoctests(node)) for stmt in node.body: self.handleNode(stmt, node) - self.addBinding(node, ClassDefinition(node.name, node)) + self.addBinding( + node, + ClassDefinition( + node.name, node, runtime=not self._in_type_check_guard)) def AUGASSIGN(self, node): self.handleNodeLoad(node.target, node) @@ -2015,12 +2066,15 @@ def TUPLE(self, node): LIST = TUPLE def IMPORT(self, node): + runtime = not self._in_type_check_guard for alias in node.names: if '.' in alias.name and not alias.asname: - importation = SubmoduleImportation(alias.name, node) + importation = SubmoduleImportation( + alias.name, node, runtime=runtime) else: name = alias.asname or alias.name - importation = Importation(name, node, alias.name) + importation = Importation( + name, node, alias.name, runtime=runtime) self.addBinding(node, importation) def IMPORTFROM(self, node): @@ -2032,6 +2086,7 @@ def IMPORTFROM(self, node): module = ('.' * node.level) + (node.module or '') + runtime = not self._in_type_check_guard for alias in node.names: name = alias.asname or alias.name if node.module == '__future__': @@ -2049,10 +2104,10 @@ def IMPORTFROM(self, node): self.scope.importStarred = True self.report(messages.ImportStarUsed, node, module) - importation = StarImportation(module, node) + importation = StarImportation(module, node, runtime=runtime) else: - importation = ImportationFrom(name, node, - module, alias.name) + importation = ImportationFrom( + name, node, module, alias.name, runtime=runtime) self.addBinding(node, importation) def TRY(self, node): diff --git a/pyflakes/test/test_type_annotations.py b/pyflakes/test/test_type_annotations.py index 396d676f..28a8a99e 100644 --- a/pyflakes/test/test_type_annotations.py +++ b/pyflakes/test/test_type_annotations.py @@ -645,6 +645,57 @@ def f() -> T: pass """) + @skipIf(version_info < (3,), 'new in Python 3') + def test_typing_guard_import(self): + # T is imported for runtime use + self.flakes(""" + from typing import TYPE_CHECKING + + if TYPE_CHECKING: + from t import T + + def f(x) -> T: + from t import T + + assert isinstance(x, T) + return x + """) + # T is defined at runtime in one side of the if/else block + self.flakes(""" + from typing import TYPE_CHECKING, Union + + if TYPE_CHECKING: + from t import T + else: + T = object + + if not TYPE_CHECKING: + U = object + else: + from t import U + + def f(x) -> Union[T, U]: + assert isinstance(x, (T, U)) + return x + """) + + @skipIf(version_info < (3,), 'new in Python 3') + def test_typing_guard_import_runtime_error(self): + # T and U are not bound for runtime use + self.flakes(""" + from typing import TYPE_CHECKING, Union + + if TYPE_CHECKING: + from t import T + + class U: + pass + + def f(x) -> Union[T, U]: + assert isinstance(x, (T, U)) + return x + """, m.UndefinedName, m.UndefinedName) + def test_typing_guard_for_protocol(self): self.flakes(""" from typing import TYPE_CHECKING @@ -659,6 +710,23 @@ def f() -> int: pass """) + def test_typing_guard_with_elif_branch(self): + # This test will not raise an error even though Protocol is not + # defined outside TYPE_CHECKING because Pyflakes does not do case + # analysis. + self.flakes(""" + from typing import TYPE_CHECKING + if TYPE_CHECKING: + from typing import Protocol + elif False: + Protocol = object + else: + pass + class C(Protocol): + def f(): # type: () -> int + pass + """) + def test_typednames_correct_forward_ref(self): self.flakes(""" from typing import TypedDict, List, NamedTuple