From c20d8360aef3541691e0434c51bbf2fd241655f5 Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Tue, 16 Feb 2021 15:51:25 -0800 Subject: [PATCH 1/3] Add failing test cases --- libcst/_nodes/tests/test_binary_op.py | 67 +++++++++++++++++++++++++++ libcst/tests/test_fuzz.py | 30 ++++++++++++ 2 files changed, 97 insertions(+) diff --git a/libcst/_nodes/tests/test_binary_op.py b/libcst/_nodes/tests/test_binary_op.py index 50f8ff79e..575e67e06 100644 --- a/libcst/_nodes/tests/test_binary_op.py +++ b/libcst/_nodes/tests/test_binary_op.py @@ -141,6 +141,36 @@ class BinaryOperationTest(CSTNodeTest): "parser": parse_expression, "expected_position": CodeRange((1, 2), (1, 13)), }, + # Make sure operators are left associative + { + "node": cst.BinaryOperation( + left=cst.BinaryOperation( + left=cst.Name("foo"), + operator=cst.Add(), + right=cst.Name("bar"), + ), + operator=cst.Add(), + right=cst.Name("baz"), + ), + "code": "foo + bar + baz", + "parser": parse_expression, + "expected_position": None, + }, + # Except for Power which is right associative + { + "node": cst.BinaryOperation( + left=cst.Name("foo"), + operator=cst.Power(), + right=cst.BinaryOperation( + left=cst.Name("bar"), + operator=cst.Power(), + right=cst.Name("baz"), + ), + ), + "code": "foo ** bar ** baz", + "parser": parse_expression, + "expected_position": None, + }, ) ) def test_valid(self, **kwargs: Any) -> None: @@ -174,3 +204,40 @@ def test_valid(self, **kwargs: Any) -> None: ) def test_invalid(self, **kwargs: Any) -> None: self.assert_invalid(**kwargs) + + + @data_provider( + ( + # Make sure operands are implicitly parenthesized + { + "node": cst.BinaryOperation( + left=cst.Name("foo"), + operator=cst.Multiply(), + right=cst.BinaryOperation( + left=cst.Name("bar"), + operator=cst.Multiply(), + right=cst.Name("baz"), + ), + ), + "code": "foo * (bar * baz)", + }, + # But only when necessary + { + "node": cst.BinaryOperation( + left=cst.Name("foo"), + operator=cst.Add(), + right=cst.BinaryOperation( + left=cst.Name("bar"), + operator=cst.Multiply(), + right=cst.Name("baz"), + ), + ), + "code": "foo + bar * baz", + }, + ) + ) + def test_implicit_parens(self, node: cst.CSTNode, code: str) -> None: + self.assertEqual( + cst.Module([]).code_for_node(node), + code + ) diff --git a/libcst/tests/test_fuzz.py b/libcst/tests/test_fuzz.py index 590449c7b..8c8676881 100644 --- a/libcst/tests/test_fuzz.py +++ b/libcst/tests/test_fuzz.py @@ -14,11 +14,25 @@ import os import unittest from datetime import timedelta +from typing import cast import hypothesis from hypothesmith import from_grammar import libcst +from libcst import CSTTransformer + + +class ParensRemover(CSTTransformer): + def on_leave(self, + original_node: libcst.CSTNode, + updated_node: libcst.CSTNode) -> libcst.CSTNode: + if isinstance(updated_node, libcst.BaseExpression): + return updated_node.with_changes( + lpar=(), + rpar=(), + ) + return updated_node # If in doubt, you should use these "unit test" settings. They tune the timeouts @@ -92,6 +106,7 @@ def test_parsing_compilable_expression_strings(self, source_code: str) -> None: self.verify_identical_asts( source_code, libcst.Module([]).code_for_node(tree), mode="eval" ) + self.verify_round_trip_without_parens(tree) except libcst.ParserSyntaxError: # Unlike statements, which allow us to strip trailing whitespace, # expressions require no whitespace or newlines. Its much more work @@ -180,6 +195,21 @@ def reject_unsupported_code(source_code: str) -> None: hypothesis.reject() + def verify_round_trip_without_parens(self, original_node: libcst.BaseExpression) -> None: + """ + Verifies that removing parens from an expression does not change the + code. E.g. `(1+2)*3` with the parens removed does not become `1+2*3`. + """ + # Technically could return RemoveFromParent but we know it wont + node = cast( + libcst.BaseExpression, + original_node.visit(ParensRemover()) + ) + new_code = libcst.Module([]).code_for_node(node) + new_node = libcst.parse_expression(new_code) + self.assertTrue(node.deep_equals(new_node)) + + if __name__ == "__main__": hypothesis.settings.load_profile("settings-for-fuzzing") unittest.main() From c9767667292a72a206bf92d577f9f2e39d31dfa4 Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Tue, 16 Feb 2021 15:53:33 -0800 Subject: [PATCH 2/3] Add _PrecedenceNode --- libcst/_nodes/expression.py | 131 ++++++++++++++++++++++++++++-------- libcst/_nodes/statement.py | 8 +-- 2 files changed, 107 insertions(+), 32 deletions(-) diff --git a/libcst/_nodes/expression.py b/libcst/_nodes/expression.py index 25b372d10..76d396710 100644 --- a/libcst/_nodes/expression.py +++ b/libcst/_nodes/expression.py @@ -9,7 +9,7 @@ from ast import literal_eval from contextlib import contextmanager from dataclasses import dataclass, field -from enum import Enum, auto +from enum import Enum, auto, IntEnum from tokenize import ( Floatnumber as FLOATNUMBER_RE, Imagnumber as IMAGNUMBER_RE, @@ -43,11 +43,36 @@ IsNot, Not, NotIn, + Power, ) from libcst._nodes.whitespace import BaseParenthesizableWhitespace, SimpleWhitespace from libcst._visitors import CSTVisitorT +class _Precedence(IntEnum): + NamedExpr = Yield = auto() + Lambda = auto() + IfExp = auto() + Or = auto() + And = auto() + Not = auto() + Comparison = auto() + BitOr = auto() + BitXor = auto() + BitAnd = auto() + LeftShift = RightShift = auto() + Add = Subtract = auto() + Multiply = MatrixMultiply = Divide = FloorDivide = Modulo = auto() + # The docs claim that Power has higher precedence than unary operators + # this is innacurate. Power binds more tightly on the right than unary + # operators and less tightly on the left. In other words Power is right + # associative with same precedence as the unary operators. + Plus = Minus = BitInvert = Power = auto() + Await = auto() + Attribute = Subscript = Call = auto() + Atom = auto() + + @add_slots @dataclass(frozen=True) class LeftSquareBracket(CSTNode): @@ -234,6 +259,12 @@ def _validate(self) -> None: if len(self.lpar) != len(self.rpar): raise CSTValidationError("Cannot have unbalanced parens.") + + @property + def _is_parenthesized(self) -> bool: + return len(self.lpar) > 0 and len(self.rpar) > 0 + + @contextmanager def _parenthesize(self, state: CodegenState) -> Generator[None, None, None]: for lpar in self.lpar: @@ -249,7 +280,33 @@ class ExpressionPosition(Enum): RIGHT = auto() -class BaseExpression(_BaseParenthesizedNode, ABC): +class _PrecedenceNode: + """ + A base class for nodes which may need to parenthesize their children + based on operator precedence. + """ + @property + def _precedence(self) -> _Precedence: + return getattr(_Precedence, self.__class__.__name__, _Precedence.Atom) + + + def _parenthesize_child(self, child: "BaseExpression", state: CodegenState, wrap_ties: bool=False) -> None: + """ + A helper function that parenthesizes child if has lower precedence than self. + """ + if not child._is_parenthesized: + child_p = child._precedence + self_p = self._precedence + + if (child_p <= self_p) if wrap_ties else (child_p < self_p): + LeftParen()._codegen(state) + child._codegen(state) + RightParen()._codegen(state) + return + child._codegen(state) + + +class BaseExpression(_BaseParenthesizedNode, _PrecedenceNode, ABC): """ An base class for all expressions. :class:`BaseExpression` contains no fields. """ @@ -263,8 +320,7 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: dictionaries and lists will override this to signifiy that they're always safe. """ - - return len(self.lpar) > 0 and len(self.rpar) > 0 + return self._is_parenthesized def _check_left_right_word_concatenation_safety( self, @@ -983,9 +1039,7 @@ def _validate(self) -> None: super(ConcatenatedString, self)._validate() # Strings that are concatenated cannot have parens. - if bool(self.left.lpar) or bool(self.left.rpar): - raise CSTValidationError("Cannot concatenate parenthesized strings.") - if bool(self.right.lpar) or bool(self.right.rpar): + if self.left._is_parenthesized or self.right._is_parenthesized: raise CSTValidationError("Cannot concatenate parenthesized strings.") # Cannot concatenate str and bytes @@ -1039,7 +1093,7 @@ def evaluated_value(self) -> Optional[str]: @add_slots @dataclass(frozen=True) -class ComparisonTarget(CSTNode): +class ComparisonTarget(CSTNode, _PrecedenceNode): """ A target for a :class:`Comparison`. Owns the comparison operator and the value to the right of the operator. @@ -1073,7 +1127,11 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "ComparisonTarget def _codegen_impl(self, state: CodegenState) -> None: self.operator._codegen(state) - self.comparator._codegen(state) + self._parenthesize_child(self.comparator, state) + + @property + def _precedence(self) -> _Precedence: + return _Precedence.Comparison @add_slots @@ -1160,7 +1218,7 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "Comparison": def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.left._codegen(state) + self._parenthesize_child(self.left, state) for comp in self.comparisons: comp._codegen(state) @@ -1224,7 +1282,11 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): self.operator._codegen(state) - self.expression._codegen(state) + self._parenthesize_child(self.expression, state) + + @property + def _precedence(self) -> _Precedence: + return getattr(_Precedence, self.operator.__class__.__name__) @add_slots @@ -1275,9 +1337,19 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.left._codegen(state) - self.operator._codegen(state) - self.right._codegen(state) + # Need to special case Power because it binds strangely + if isinstance(self.operator, Power): + self._parenthesize_child(self.left, state, wrap_ties=True) + self.operator._codegen(state) + self._parenthesize_child(self.right, state) + else: + self._parenthesize_child(self.left, state) + self.operator._codegen(state) + self._parenthesize_child(self.right, state, wrap_ties=True) + + @property + def _precedence(self) -> _Precedence: + return getattr(_Precedence, self.operator.__class__.__name__) @add_slots @@ -1347,10 +1419,13 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.left._codegen(state) + self._parenthesize_child(self.left, state) self.operator._codegen(state) - self.right._codegen(state) + self._parenthesize_child(self.right, state, wrap_ties=True) + @property + def _precedence(self) -> _Precedence: + return getattr(_Precedence, self.operator.__class__.__name__) @add_slots @dataclass(frozen=True) @@ -1404,7 +1479,7 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.value._codegen(state) + self._parenthesize_child(self.value, state) self.dot._codegen(state) self.attr._codegen(state) @@ -1578,7 +1653,7 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.value._codegen(state) + self._parenthesize_child(self.value, state) self.whitespace_after_value._codegen(state) self.lbracket._codegen(state) lastslice = len(self.slice) - 1 @@ -2101,7 +2176,7 @@ def _codegen_impl(self, state: CodegenState) -> None: whitespace_after_lambda._codegen(state) self.params._codegen(state) self.colon._codegen(state) - self.body._codegen(state) + self._parenthesize_child(self.body, state) @add_slots @@ -2346,7 +2421,7 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "Call": def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.func._codegen(state) + self._parenthesize_child(self.func, state) self.whitespace_after_func._codegen(state) state.add_token("(") self.whitespace_before_args._codegen(state) @@ -2397,7 +2472,7 @@ def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): state.add_token("await") self.whitespace_after_await._codegen(state) - self.expression._codegen(state) + self._parenthesize_child(self.expression, state) @add_slots @@ -2490,15 +2565,15 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "IfExp": def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.body._codegen(state) + self._parenthesize_child(self.body, state) self.whitespace_before_if._codegen(state) state.add_token("if") self.whitespace_after_if._codegen(state) - self.test._codegen(state) + self._parenthesize_child(self.test, state, wrap_ties=True) self.whitespace_before_else._codegen(state) state.add_token("else") self.whitespace_after_else._codegen(state) - self.orelse._codegen(state) + self._parenthesize_child(self.orelse, state, wrap_ties=True) @add_slots @@ -2904,7 +2979,7 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: isinstance(last_element.comma, Comma) or ( isinstance(last_element, StarredElement) - and len(last_element.rpar) > 0 + and last_element._is_parenthesized ) or last_element.value._safe_to_use_with_word_operator(position) ) @@ -2920,7 +2995,7 @@ def _validate(self) -> None: super(Tuple, self)._validate() if len(self.elements) == 0: - if len(self.lpar) == 0: # assumes len(lpar) == len(rpar), via superclass + if not self._is_parenthesized: raise CSTValidationError( "A zero-length tuple must be wrapped in parentheses." ) @@ -3696,8 +3771,8 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "NamedExpr": def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.target._codegen(state) + self._parenthesize_child(self.target, state) self.whitespace_before_walrus._codegen(state) state.add_token(":=") self.whitespace_after_walrus._codegen(state) - self.value._codegen(state) + self._parenthesize_child(self.value, state, wrap_ties=True) diff --git a/libcst/_nodes/statement.py b/libcst/_nodes/statement.py index 6a831b85b..f3c3c7cb8 100644 --- a/libcst/_nodes/statement.py +++ b/libcst/_nodes/statement.py @@ -1453,7 +1453,7 @@ class Decorator(CSTNode): def _validate(self) -> None: decorator = self.decorator - if len(decorator.lpar) > 0 or len(decorator.rpar) > 0: + if decorator._is_parenthesized: raise CSTValidationError( "Cannot have parens around decorator in a Decorator." ) @@ -1579,7 +1579,7 @@ class FunctionDef(BaseCompoundStatement): whitespace_before_colon: SimpleWhitespace = SimpleWhitespace.field("") def _validate(self) -> None: - if len(self.name.lpar) > 0 or len(self.name.rpar) > 0: + if self.name._is_parenthesized: raise CSTValidationError("Cannot have parens around Name in a FunctionDef.") if self.whitespace_after_def.empty: raise CSTValidationError( @@ -1709,7 +1709,7 @@ def _validate_whitespace(self) -> None: ) def _validate_parens(self) -> None: - if len(self.name.lpar) > 0 or len(self.name.rpar) > 0: + if self.name._is_parenthesized: raise CSTValidationError("Cannot have parens around Name in a ClassDef.") if isinstance(self.lpar, MaybeSentinel) and isinstance(self.rpar, RightParen): raise CSTValidationError( @@ -2283,7 +2283,7 @@ class NameItem(CSTNode): def _validate(self) -> None: # No parens around names here - if len(self.name.lpar) > 0 or len(self.name.rpar) > 0: + if self.name._is_parenthesized: raise CSTValidationError("Cannot have parens around names in NameItem.") def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "NameItem": From 2c6393d0deefda2fde5fea2fa39155badfa60564 Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Tue, 16 Feb 2021 16:11:15 -0800 Subject: [PATCH 3/3] autofix --- libcst/_nodes/expression.py | 13 +++++++------ libcst/_nodes/tests/test_binary_op.py | 8 ++------ libcst/tests/test_fuzz.py | 16 +++++++--------- 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/libcst/_nodes/expression.py b/libcst/_nodes/expression.py index 76d396710..bf11c8432 100644 --- a/libcst/_nodes/expression.py +++ b/libcst/_nodes/expression.py @@ -9,7 +9,7 @@ from ast import literal_eval from contextlib import contextmanager from dataclasses import dataclass, field -from enum import Enum, auto, IntEnum +from enum import Enum, IntEnum, auto from tokenize import ( Floatnumber as FLOATNUMBER_RE, Imagnumber as IMAGNUMBER_RE, @@ -259,12 +259,10 @@ def _validate(self) -> None: if len(self.lpar) != len(self.rpar): raise CSTValidationError("Cannot have unbalanced parens.") - @property def _is_parenthesized(self) -> bool: return len(self.lpar) > 0 and len(self.rpar) > 0 - @contextmanager def _parenthesize(self, state: CodegenState) -> Generator[None, None, None]: for lpar in self.lpar: @@ -285,12 +283,14 @@ class _PrecedenceNode: A base class for nodes which may need to parenthesize their children based on operator precedence. """ + @property def _precedence(self) -> _Precedence: return getattr(_Precedence, self.__class__.__name__, _Precedence.Atom) - - def _parenthesize_child(self, child: "BaseExpression", state: CodegenState, wrap_ties: bool=False) -> None: + def _parenthesize_child( + self, child: "BaseExpression", state: CodegenState, wrap_ties: bool = False + ) -> None: """ A helper function that parenthesizes child if has lower precedence than self. """ @@ -1039,7 +1039,7 @@ def _validate(self) -> None: super(ConcatenatedString, self)._validate() # Strings that are concatenated cannot have parens. - if self.left._is_parenthesized or self.right._is_parenthesized: + if self.left._is_parenthesized or self.right._is_parenthesized: raise CSTValidationError("Cannot concatenate parenthesized strings.") # Cannot concatenate str and bytes @@ -1427,6 +1427,7 @@ def _codegen_impl(self, state: CodegenState) -> None: def _precedence(self) -> _Precedence: return getattr(_Precedence, self.operator.__class__.__name__) + @add_slots @dataclass(frozen=True) class Attribute(BaseAssignTargetExpression, BaseDelTargetExpression): diff --git a/libcst/_nodes/tests/test_binary_op.py b/libcst/_nodes/tests/test_binary_op.py index 575e67e06..1632d8d53 100644 --- a/libcst/_nodes/tests/test_binary_op.py +++ b/libcst/_nodes/tests/test_binary_op.py @@ -205,10 +205,9 @@ def test_valid(self, **kwargs: Any) -> None: def test_invalid(self, **kwargs: Any) -> None: self.assert_invalid(**kwargs) - @data_provider( ( - # Make sure operands are implicitly parenthesized + # Make sure operands are implicitly parenthesized { "node": cst.BinaryOperation( left=cst.Name("foo"), @@ -237,7 +236,4 @@ def test_invalid(self, **kwargs: Any) -> None: ) ) def test_implicit_parens(self, node: cst.CSTNode, code: str) -> None: - self.assertEqual( - cst.Module([]).code_for_node(node), - code - ) + self.assertEqual(cst.Module([]).code_for_node(node), code) diff --git a/libcst/tests/test_fuzz.py b/libcst/tests/test_fuzz.py index 8c8676881..06fe306b1 100644 --- a/libcst/tests/test_fuzz.py +++ b/libcst/tests/test_fuzz.py @@ -24,9 +24,9 @@ class ParensRemover(CSTTransformer): - def on_leave(self, - original_node: libcst.CSTNode, - updated_node: libcst.CSTNode) -> libcst.CSTNode: + def on_leave( + self, original_node: libcst.CSTNode, updated_node: libcst.CSTNode + ) -> libcst.CSTNode: if isinstance(updated_node, libcst.BaseExpression): return updated_node.with_changes( lpar=(), @@ -194,17 +194,15 @@ def reject_unsupported_code(source_code: str) -> None: # round trip perfectly, reject this. hypothesis.reject() - - def verify_round_trip_without_parens(self, original_node: libcst.BaseExpression) -> None: + def verify_round_trip_without_parens( + self, original_node: libcst.BaseExpression + ) -> None: """ Verifies that removing parens from an expression does not change the code. E.g. `(1+2)*3` with the parens removed does not become `1+2*3`. """ # Technically could return RemoveFromParent but we know it wont - node = cast( - libcst.BaseExpression, - original_node.visit(ParensRemover()) - ) + node = cast(libcst.BaseExpression, original_node.visit(ParensRemover())) new_code = libcst.Module([]).code_for_node(node) new_node = libcst.parse_expression(new_code) self.assertTrue(node.deep_equals(new_node))