diff --git a/libcst/_nodes/expression.py b/libcst/_nodes/expression.py index 25b372d10..bf11c8432 100644 --- a/libcst/_nodes/expression.py +++ b/libcst/_nodes/expression.py @@ -9,7 +9,7 @@ from ast import literal_eval from contextlib import contextmanager from dataclasses import dataclass, field -from enum import Enum, auto +from enum import Enum, IntEnum, auto from tokenize import ( Floatnumber as FLOATNUMBER_RE, Imagnumber as IMAGNUMBER_RE, @@ -43,11 +43,36 @@ IsNot, Not, NotIn, + Power, ) from libcst._nodes.whitespace import BaseParenthesizableWhitespace, SimpleWhitespace from libcst._visitors import CSTVisitorT +class _Precedence(IntEnum): + NamedExpr = Yield = auto() + Lambda = auto() + IfExp = auto() + Or = auto() + And = auto() + Not = auto() + Comparison = auto() + BitOr = auto() + BitXor = auto() + BitAnd = auto() + LeftShift = RightShift = auto() + Add = Subtract = auto() + Multiply = MatrixMultiply = Divide = FloorDivide = Modulo = auto() + # The docs claim that Power has higher precedence than unary operators, but + # this is inaccurate. Power binds more tightly than a unary operator on its + # left and less tightly than one on its right. In other words Power is right + # associative with the same precedence as the unary operators. 
+ Plus = Minus = BitInvert = Power = auto() + Await = auto() + Attribute = Subscript = Call = auto() + Atom = auto() + + @add_slots @dataclass(frozen=True) class LeftSquareBracket(CSTNode): @@ -234,6 +259,10 @@ def _validate(self) -> None: if len(self.lpar) != len(self.rpar): raise CSTValidationError("Cannot have unbalanced parens.") + @property + def _is_parenthesized(self) -> bool: + return len(self.lpar) > 0 and len(self.rpar) > 0 + @contextmanager def _parenthesize(self, state: CodegenState) -> Generator[None, None, None]: for lpar in self.lpar: @@ -249,7 +278,35 @@ class ExpressionPosition(Enum): RIGHT = auto() -class BaseExpression(_BaseParenthesizedNode, ABC): +class _PrecedenceNode: + """ + A base class for nodes which may need to parenthesize their children + based on operator precedence. + """ + + @property + def _precedence(self) -> _Precedence: + return getattr(_Precedence, self.__class__.__name__, _Precedence.Atom) + + def _parenthesize_child( + self, child: "BaseExpression", state: CodegenState, wrap_ties: bool = False + ) -> None: + """ + A helper function that parenthesizes child if it has lower precedence than self. + """ + if not child._is_parenthesized: + child_p = child._precedence + self_p = self._precedence + + if (child_p <= self_p) if wrap_ties else (child_p < self_p): + LeftParen()._codegen(state) + child._codegen(state) + RightParen()._codegen(state) + return + child._codegen(state) + + +class BaseExpression(_BaseParenthesizedNode, _PrecedenceNode, ABC): """ An base class for all expressions. :class:`BaseExpression` contains no fields. """ @@ -263,8 +320,7 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: dictionaries and lists will override this to signifiy that they're always safe. 
""" - - return len(self.lpar) > 0 and len(self.rpar) > 0 + return self._is_parenthesized def _check_left_right_word_concatenation_safety( self, @@ -983,9 +1039,7 @@ def _validate(self) -> None: super(ConcatenatedString, self)._validate() # Strings that are concatenated cannot have parens. - if bool(self.left.lpar) or bool(self.left.rpar): - raise CSTValidationError("Cannot concatenate parenthesized strings.") - if bool(self.right.lpar) or bool(self.right.rpar): + if self.left._is_parenthesized or self.right._is_parenthesized: raise CSTValidationError("Cannot concatenate parenthesized strings.") # Cannot concatenate str and bytes @@ -1039,7 +1093,7 @@ def evaluated_value(self) -> Optional[str]: @add_slots @dataclass(frozen=True) -class ComparisonTarget(CSTNode): +class ComparisonTarget(CSTNode, _PrecedenceNode): """ A target for a :class:`Comparison`. Owns the comparison operator and the value to the right of the operator. @@ -1073,7 +1127,11 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "ComparisonTarget def _codegen_impl(self, state: CodegenState) -> None: self.operator._codegen(state) - self.comparator._codegen(state) + self._parenthesize_child(self.comparator, state) + + @property + def _precedence(self) -> _Precedence: + return _Precedence.Comparison @add_slots @@ -1160,7 +1218,7 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "Comparison": def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.left._codegen(state) + self._parenthesize_child(self.left, state) for comp in self.comparisons: comp._codegen(state) @@ -1224,7 +1282,11 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): self.operator._codegen(state) - self.expression._codegen(state) + self._parenthesize_child(self.expression, state) + + @property + def _precedence(self) -> _Precedence: + return 
getattr(_Precedence, self.operator.__class__.__name__) @add_slots @@ -1275,9 +1337,19 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.left._codegen(state) - self.operator._codegen(state) - self.right._codegen(state) + # Need to special case Power because it binds strangely + if isinstance(self.operator, Power): + self._parenthesize_child(self.left, state, wrap_ties=True) + self.operator._codegen(state) + self._parenthesize_child(self.right, state) + else: + self._parenthesize_child(self.left, state) + self.operator._codegen(state) + self._parenthesize_child(self.right, state, wrap_ties=True) + + @property + def _precedence(self) -> _Precedence: + return getattr(_Precedence, self.operator.__class__.__name__) @add_slots @@ -1347,9 +1419,13 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.left._codegen(state) + self._parenthesize_child(self.left, state) self.operator._codegen(state) - self.right._codegen(state) + self._parenthesize_child(self.right, state, wrap_ties=True) + + @property + def _precedence(self) -> _Precedence: + return getattr(_Precedence, self.operator.__class__.__name__) @add_slots @@ -1404,7 +1480,7 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.value._codegen(state) + self._parenthesize_child(self.value, state) self.dot._codegen(state) self.attr._codegen(state) @@ -1578,7 +1654,7 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.value._codegen(state) + self._parenthesize_child(self.value, state) self.whitespace_after_value._codegen(state) 
self.lbracket._codegen(state) lastslice = len(self.slice) - 1 @@ -2101,7 +2177,7 @@ def _codegen_impl(self, state: CodegenState) -> None: whitespace_after_lambda._codegen(state) self.params._codegen(state) self.colon._codegen(state) - self.body._codegen(state) + self._parenthesize_child(self.body, state) @add_slots @@ -2346,7 +2422,7 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "Call": def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.func._codegen(state) + self._parenthesize_child(self.func, state) self.whitespace_after_func._codegen(state) state.add_token("(") self.whitespace_before_args._codegen(state) @@ -2397,7 +2473,7 @@ def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): state.add_token("await") self.whitespace_after_await._codegen(state) - self.expression._codegen(state) + self._parenthesize_child(self.expression, state) @add_slots @@ -2490,15 +2566,15 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "IfExp": def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.body._codegen(state) + self._parenthesize_child(self.body, state) self.whitespace_before_if._codegen(state) state.add_token("if") self.whitespace_after_if._codegen(state) - self.test._codegen(state) + self._parenthesize_child(self.test, state, wrap_ties=True) self.whitespace_before_else._codegen(state) state.add_token("else") self.whitespace_after_else._codegen(state) - self.orelse._codegen(state) + self._parenthesize_child(self.orelse, state, wrap_ties=True) @add_slots @@ -2904,7 +2980,7 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: isinstance(last_element.comma, Comma) or ( isinstance(last_element, StarredElement) - and len(last_element.rpar) > 0 + and last_element._is_parenthesized ) or last_element.value._safe_to_use_with_word_operator(position) ) @@ -2920,7 +2996,7 @@ def _validate(self) -> None: 
super(Tuple, self)._validate() if len(self.elements) == 0: - if len(self.lpar) == 0: # assumes len(lpar) == len(rpar), via superclass + if not self._is_parenthesized: raise CSTValidationError( "A zero-length tuple must be wrapped in parentheses." ) @@ -3696,8 +3772,8 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "NamedExpr": def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): - self.target._codegen(state) + self._parenthesize_child(self.target, state) self.whitespace_before_walrus._codegen(state) state.add_token(":=") self.whitespace_after_walrus._codegen(state) - self.value._codegen(state) + self._parenthesize_child(self.value, state, wrap_ties=True) diff --git a/libcst/_nodes/statement.py b/libcst/_nodes/statement.py index 6a831b85b..f3c3c7cb8 100644 --- a/libcst/_nodes/statement.py +++ b/libcst/_nodes/statement.py @@ -1453,7 +1453,7 @@ class Decorator(CSTNode): def _validate(self) -> None: decorator = self.decorator - if len(decorator.lpar) > 0 or len(decorator.rpar) > 0: + if decorator._is_parenthesized: raise CSTValidationError( "Cannot have parens around decorator in a Decorator." 
) @@ -1579,7 +1579,7 @@ class FunctionDef(BaseCompoundStatement): whitespace_before_colon: SimpleWhitespace = SimpleWhitespace.field("") def _validate(self) -> None: - if len(self.name.lpar) > 0 or len(self.name.rpar) > 0: + if self.name._is_parenthesized: raise CSTValidationError("Cannot have parens around Name in a FunctionDef.") if self.whitespace_after_def.empty: raise CSTValidationError( @@ -1709,7 +1709,7 @@ def _validate_whitespace(self) -> None: ) def _validate_parens(self) -> None: - if len(self.name.lpar) > 0 or len(self.name.rpar) > 0: + if self.name._is_parenthesized: raise CSTValidationError("Cannot have parens around Name in a ClassDef.") if isinstance(self.lpar, MaybeSentinel) and isinstance(self.rpar, RightParen): raise CSTValidationError( @@ -2283,7 +2283,7 @@ class NameItem(CSTNode): def _validate(self) -> None: # No parens around names here - if len(self.name.lpar) > 0 or len(self.name.rpar) > 0: + if self.name._is_parenthesized: raise CSTValidationError("Cannot have parens around names in NameItem.") def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "NameItem": diff --git a/libcst/_nodes/tests/test_binary_op.py b/libcst/_nodes/tests/test_binary_op.py index 50f8ff79e..1632d8d53 100644 --- a/libcst/_nodes/tests/test_binary_op.py +++ b/libcst/_nodes/tests/test_binary_op.py @@ -141,6 +141,36 @@ class BinaryOperationTest(CSTNodeTest): "parser": parse_expression, "expected_position": CodeRange((1, 2), (1, 13)), }, + # Make sure operators are left associative + { + "node": cst.BinaryOperation( + left=cst.BinaryOperation( + left=cst.Name("foo"), + operator=cst.Add(), + right=cst.Name("bar"), + ), + operator=cst.Add(), + right=cst.Name("baz"), + ), + "code": "foo + bar + baz", + "parser": parse_expression, + "expected_position": None, + }, + # Except for Power which is right associative + { + "node": cst.BinaryOperation( + left=cst.Name("foo"), + operator=cst.Power(), + right=cst.BinaryOperation( + left=cst.Name("bar"), + 
operator=cst.Power(), + right=cst.Name("baz"), + ), + ), + "code": "foo ** bar ** baz", + "parser": parse_expression, + "expected_position": None, + }, ) ) def test_valid(self, **kwargs: Any) -> None: @@ -174,3 +204,36 @@ def test_valid(self, **kwargs: Any) -> None: ) def test_invalid(self, **kwargs: Any) -> None: self.assert_invalid(**kwargs) + + @data_provider( + ( + # Make sure operands are implicitly parenthesized + { + "node": cst.BinaryOperation( + left=cst.Name("foo"), + operator=cst.Multiply(), + right=cst.BinaryOperation( + left=cst.Name("bar"), + operator=cst.Multiply(), + right=cst.Name("baz"), + ), + ), + "code": "foo * (bar * baz)", + }, + # But only when necessary + { + "node": cst.BinaryOperation( + left=cst.Name("foo"), + operator=cst.Add(), + right=cst.BinaryOperation( + left=cst.Name("bar"), + operator=cst.Multiply(), + right=cst.Name("baz"), + ), + ), + "code": "foo + bar * baz", + }, + ) + ) + def test_implicit_parens(self, node: cst.CSTNode, code: str) -> None: + self.assertEqual(cst.Module([]).code_for_node(node), code) diff --git a/libcst/tests/test_fuzz.py b/libcst/tests/test_fuzz.py index 590449c7b..06fe306b1 100644 --- a/libcst/tests/test_fuzz.py +++ b/libcst/tests/test_fuzz.py @@ -14,11 +14,25 @@ import os import unittest from datetime import timedelta +from typing import cast import hypothesis from hypothesmith import from_grammar import libcst +from libcst import CSTTransformer + + +class ParensRemover(CSTTransformer): + def on_leave( + self, original_node: libcst.CSTNode, updated_node: libcst.CSTNode + ) -> libcst.CSTNode: + if isinstance(updated_node, libcst.BaseExpression): + return updated_node.with_changes( + lpar=(), + rpar=(), + ) + return updated_node # If in doubt, you should use these "unit test" settings. 
They tune the timeouts @@ -92,6 +106,7 @@ def test_parsing_compilable_expression_strings(self, source_code: str) -> None: self.verify_identical_asts( source_code, libcst.Module([]).code_for_node(tree), mode="eval" ) + self.verify_round_trip_without_parens(tree) except libcst.ParserSyntaxError: # Unlike statements, which allow us to strip trailing whitespace, # expressions require no whitespace or newlines. Its much more work @@ -179,6 +194,19 @@ def reject_unsupported_code(source_code: str) -> None: # round trip perfectly, reject this. hypothesis.reject() + def verify_round_trip_without_parens( + self, original_node: libcst.BaseExpression + ) -> None: + """ + Verifies that removing parens from an expression does not change the + code. E.g. `(1+2)*3` with the parens removed does not become `1+2*3`. + """ + # Technically visit() could return RemoveFromParent, but we know it won't + node = cast(libcst.BaseExpression, original_node.visit(ParensRemover())) + new_code = libcst.Module([]).code_for_node(node) + new_node = libcst.parse_expression(new_code) + self.assertTrue(node.deep_equals(new_node)) + if __name__ == "__main__": hypothesis.settings.load_profile("settings-for-fuzzing")