Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 341 #11

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 104 additions & 28 deletions libcst/_nodes/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ast import literal_eval
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum, auto
from enum import Enum, IntEnum, auto
from tokenize import (
Floatnumber as FLOATNUMBER_RE,
Imagnumber as IMAGNUMBER_RE,
Expand Down Expand Up @@ -43,11 +43,36 @@
IsNot,
Not,
NotIn,
Power,
)
from libcst._nodes.whitespace import BaseParenthesizableWhitespace, SimpleWhitespace
from libcst._visitors import CSTVisitorT


class _Precedence(IntEnum):
NamedExpr = Yield = auto()
Lambda = auto()
IfExp = auto()
Or = auto()
And = auto()
Not = auto()
Comparison = auto()
BitOr = auto()
BitXor = auto()
BitAnd = auto()
LeftShift = RightShift = auto()
Add = Subtract = auto()
Multiply = MatrixMultiply = Divide = FloorDivide = Modulo = auto()
# The docs claim that Power has higher precedence than unary operators
# this is innacurate. Power binds more tightly on the right than unary
# operators and less tightly on the left. In other words Power is right
# associative with same precedence as the unary operators.
Plus = Minus = BitInvert = Power = auto()
Await = auto()
Attribute = Subscript = Call = auto()
Atom = auto()


@add_slots
@dataclass(frozen=True)
class LeftSquareBracket(CSTNode):
Expand Down Expand Up @@ -234,6 +259,10 @@ def _validate(self) -> None:
if len(self.lpar) != len(self.rpar):
raise CSTValidationError("Cannot have unbalanced parens.")

@property
def _is_parenthesized(self) -> bool:
return len(self.lpar) > 0 and len(self.rpar) > 0

@contextmanager
def _parenthesize(self, state: CodegenState) -> Generator[None, None, None]:
for lpar in self.lpar:
Expand All @@ -249,7 +278,35 @@ class ExpressionPosition(Enum):
RIGHT = auto()


class BaseExpression(_BaseParenthesizedNode, ABC):
class _PrecedenceNode:
"""
A base class for nodes which may need to parenthesize their children
based on operator precedence.
"""

@property
def _precedence(self) -> _Precedence:
return getattr(_Precedence, self.__class__.__name__, _Precedence.Atom)

def _parenthesize_child(
self, child: "BaseExpression", state: CodegenState, wrap_ties: bool = False
) -> None:
"""
A helper function that parenthesizes child if has lower precedence than self.
"""
if not child._is_parenthesized:
child_p = child._precedence
self_p = self._precedence

if (child_p <= self_p) if wrap_ties else (child_p < self_p):
LeftParen()._codegen(state)
child._codegen(state)
RightParen()._codegen(state)
return
child._codegen(state)


class BaseExpression(_BaseParenthesizedNode, _PrecedenceNode, ABC):
"""
An base class for all expressions. :class:`BaseExpression` contains no fields.
"""
Expand All @@ -263,8 +320,7 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
dictionaries and lists will override this to signifiy that they're always
safe.
"""

return len(self.lpar) > 0 and len(self.rpar) > 0
return self._is_parenthesized

def _check_left_right_word_concatenation_safety(
self,
Expand Down Expand Up @@ -983,9 +1039,7 @@ def _validate(self) -> None:
super(ConcatenatedString, self)._validate()

# Strings that are concatenated cannot have parens.
if bool(self.left.lpar) or bool(self.left.rpar):
raise CSTValidationError("Cannot concatenate parenthesized strings.")
if bool(self.right.lpar) or bool(self.right.rpar):
if self.left._is_parenthesized or self.right._is_parenthesized:
raise CSTValidationError("Cannot concatenate parenthesized strings.")

# Cannot concatenate str and bytes
Expand Down Expand Up @@ -1039,7 +1093,7 @@ def evaluated_value(self) -> Optional[str]:

@add_slots
@dataclass(frozen=True)
class ComparisonTarget(CSTNode):
class ComparisonTarget(CSTNode, _PrecedenceNode):
"""
A target for a :class:`Comparison`. Owns the comparison operator and the value to
the right of the operator.
Expand Down Expand Up @@ -1073,7 +1127,11 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "ComparisonTarget

def _codegen_impl(self, state: CodegenState) -> None:
self.operator._codegen(state)
self.comparator._codegen(state)
self._parenthesize_child(self.comparator, state)

@property
def _precedence(self) -> _Precedence:
return _Precedence.Comparison


@add_slots
Expand Down Expand Up @@ -1160,7 +1218,7 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "Comparison":

def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.left._codegen(state)
self._parenthesize_child(self.left, state)
for comp in self.comparisons:
comp._codegen(state)

Expand Down Expand Up @@ -1224,7 +1282,11 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.operator._codegen(state)
self.expression._codegen(state)
self._parenthesize_child(self.expression, state)

@property
def _precedence(self) -> _Precedence:
return getattr(_Precedence, self.operator.__class__.__name__)


@add_slots
Expand Down Expand Up @@ -1275,9 +1337,19 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:

def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.left._codegen(state)
self.operator._codegen(state)
self.right._codegen(state)
# Need to special case Power because it binds strangely
if isinstance(self.operator, Power):
self._parenthesize_child(self.left, state, wrap_ties=True)
self.operator._codegen(state)
self._parenthesize_child(self.right, state)
else:
self._parenthesize_child(self.left, state)
self.operator._codegen(state)
self._parenthesize_child(self.right, state, wrap_ties=True)

@property
def _precedence(self) -> _Precedence:
return getattr(_Precedence, self.operator.__class__.__name__)


@add_slots
Expand Down Expand Up @@ -1347,9 +1419,13 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:

def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.left._codegen(state)
self._parenthesize_child(self.left, state)
self.operator._codegen(state)
self.right._codegen(state)
self._parenthesize_child(self.right, state, wrap_ties=True)

@property
def _precedence(self) -> _Precedence:
return getattr(_Precedence, self.operator.__class__.__name__)


@add_slots
Expand Down Expand Up @@ -1404,7 +1480,7 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:

def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.value._codegen(state)
self._parenthesize_child(self.value, state)
self.dot._codegen(state)
self.attr._codegen(state)

Expand Down Expand Up @@ -1578,7 +1654,7 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:

def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.value._codegen(state)
self._parenthesize_child(self.value, state)
self.whitespace_after_value._codegen(state)
self.lbracket._codegen(state)
lastslice = len(self.slice) - 1
Expand Down Expand Up @@ -2101,7 +2177,7 @@ def _codegen_impl(self, state: CodegenState) -> None:
whitespace_after_lambda._codegen(state)
self.params._codegen(state)
self.colon._codegen(state)
self.body._codegen(state)
self._parenthesize_child(self.body, state)


@add_slots
Expand Down Expand Up @@ -2346,7 +2422,7 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "Call":

def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.func._codegen(state)
self._parenthesize_child(self.func, state)
self.whitespace_after_func._codegen(state)
state.add_token("(")
self.whitespace_before_args._codegen(state)
Expand Down Expand Up @@ -2397,7 +2473,7 @@ def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
state.add_token("await")
self.whitespace_after_await._codegen(state)
self.expression._codegen(state)
self._parenthesize_child(self.expression, state)


@add_slots
Expand Down Expand Up @@ -2490,15 +2566,15 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "IfExp":

def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.body._codegen(state)
self._parenthesize_child(self.body, state)
self.whitespace_before_if._codegen(state)
state.add_token("if")
self.whitespace_after_if._codegen(state)
self.test._codegen(state)
self._parenthesize_child(self.test, state, wrap_ties=True)
self.whitespace_before_else._codegen(state)
state.add_token("else")
self.whitespace_after_else._codegen(state)
self.orelse._codegen(state)
self._parenthesize_child(self.orelse, state, wrap_ties=True)


@add_slots
Expand Down Expand Up @@ -2904,7 +2980,7 @@ def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
isinstance(last_element.comma, Comma)
or (
isinstance(last_element, StarredElement)
and len(last_element.rpar) > 0
and last_element._is_parenthesized
)
or last_element.value._safe_to_use_with_word_operator(position)
)
Expand All @@ -2920,7 +2996,7 @@ def _validate(self) -> None:
super(Tuple, self)._validate()

if len(self.elements) == 0:
if len(self.lpar) == 0: # assumes len(lpar) == len(rpar), via superclass
if not self._is_parenthesized:
raise CSTValidationError(
"A zero-length tuple must be wrapped in parentheses."
)
Expand Down Expand Up @@ -3696,8 +3772,8 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "NamedExpr":

def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.target._codegen(state)
self._parenthesize_child(self.target, state)
self.whitespace_before_walrus._codegen(state)
state.add_token(":=")
self.whitespace_after_walrus._codegen(state)
self.value._codegen(state)
self._parenthesize_child(self.value, state, wrap_ties=True)
8 changes: 4 additions & 4 deletions libcst/_nodes/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -1453,7 +1453,7 @@ class Decorator(CSTNode):

def _validate(self) -> None:
decorator = self.decorator
if len(decorator.lpar) > 0 or len(decorator.rpar) > 0:
if decorator._is_parenthesized:
raise CSTValidationError(
"Cannot have parens around decorator in a Decorator."
)
Expand Down Expand Up @@ -1579,7 +1579,7 @@ class FunctionDef(BaseCompoundStatement):
whitespace_before_colon: SimpleWhitespace = SimpleWhitespace.field("")

def _validate(self) -> None:
if len(self.name.lpar) > 0 or len(self.name.rpar) > 0:
if self.name._is_parenthesized:
raise CSTValidationError("Cannot have parens around Name in a FunctionDef.")
if self.whitespace_after_def.empty:
raise CSTValidationError(
Expand Down Expand Up @@ -1709,7 +1709,7 @@ def _validate_whitespace(self) -> None:
)

def _validate_parens(self) -> None:
if len(self.name.lpar) > 0 or len(self.name.rpar) > 0:
if self.name._is_parenthesized:
raise CSTValidationError("Cannot have parens around Name in a ClassDef.")
if isinstance(self.lpar, MaybeSentinel) and isinstance(self.rpar, RightParen):
raise CSTValidationError(
Expand Down Expand Up @@ -2283,7 +2283,7 @@ class NameItem(CSTNode):

def _validate(self) -> None:
# No parens around names here
if len(self.name.lpar) > 0 or len(self.name.rpar) > 0:
if self.name._is_parenthesized:
raise CSTValidationError("Cannot have parens around names in NameItem.")

def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "NameItem":
Expand Down
Loading