Skip to content

Commit

Permalink
Add example visit method to SpacesVisitor
Browse files Browse the repository at this point in the history
  • Loading branch information
knutwannheden committed Nov 28, 2024
1 parent 6cfab68 commit 5402378
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 5 deletions.
7 changes: 6 additions & 1 deletion rewrite/rewrite/python/format/auto_format.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from rewrite import Recipe, Tree, Cursor
from rewrite.java import JavaSourceFile
from rewrite.java import JavaSourceFile, MethodDeclaration, J, Space
from rewrite.python import PythonVisitor, SpacesStyle, IntelliJ
from rewrite.visitor import P, T

Expand All @@ -27,3 +27,8 @@ class SpacesVisitor(PythonVisitor):
def __init__(self, style: SpacesStyle, stop_after: Tree = None):
self._style = style
self._stop_after = stop_after

def visit_method_declaration(self, method_declaration: MethodDeclaration, p: P) -> J:
return method_declaration.padding.with_parameters(
method_declaration.padding.parameters.with_before(Space.SINGLE_SPACE if self._style.beforeParentheses.method_parentheses else Space.EMPTY)
)
14 changes: 12 additions & 2 deletions rewrite/rewrite/python/parser.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import ast
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional

from rewrite import Parser, ParserInput, ExecutionContext, SourceFile, ParseError
from rewrite import Parser, ParserInput, ExecutionContext, SourceFile, ParseError, NamedStyles, Markers, Tree, random_id
from rewrite.parser import require_print_equals_input, ParserBuilder
from ._parser_visitor import ParserVisitor
from .tree import CompilationUnit

logging.basicConfig(level=logging.ERROR)


@dataclass(frozen=True)
class PythonParser(Parser):
_styles: Optional[Iterable[NamedStyles]]

def parse_inputs(self, sources: Iterable[ParserInput], relative_to: Optional[Path],
ctx: ExecutionContext) -> Iterable[SourceFile]:
accepted = (source for source in sources if self.accept(source.path))
Expand All @@ -20,6 +24,7 @@ def parse_inputs(self, sources: Iterable[ParserInput], relative_to: Optional[Pat
source_str = source.source().read()
tree = ast.parse(source_str, source.path)
cu = ParserVisitor(source_str).visit(tree).with_source_path(source.path)
cu = cu.with_markers(Markers.build(random_id(), self._styles)) if self._styles else cu
cu = require_print_equals_input(self, cu, source, relative_to, ctx)
except Exception as e:
logging.error(f"An error was encountered while parsing {source.path}: {str(e)}", exc_info=True)
Expand All @@ -37,6 +42,11 @@ class PythonParserBuilder(ParserBuilder):
def __init__(self):
self._source_file_type = type(CompilationUnit)
self._dsl_name = 'python'
self._styles = None

def styles(self, styles: Iterable[NamedStyles]):
self._styles = styles
return self

def build(self) -> Parser:
return PythonParser()
return PythonParser(self._styles)
2 changes: 1 addition & 1 deletion rewrite/rewrite/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def first_enclosing(self, type: Type[P]) -> P:
return None

def fork(self) -> Cursor:
return Cursor(self.parent.fork(), self.value)
return Cursor(self.parent.fork(), self.value) if self.parent else self


class TreeVisitor(Protocol[T, P]):
Expand Down
21 changes: 20 additions & 1 deletion rewrite/tests/python/all/format/demo_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from rewrite.java import Space, P
from rewrite.python import PythonVisitor
from rewrite.python import PythonVisitor, AutoFormat
from rewrite.test import rewrite_run, python, from_visitor


Expand All @@ -18,6 +18,25 @@ def getter(self, row):
recipe=from_visitor(NoSpaces())
)


def test_spaces_before_method_parentheses():
rewrite_run(
# language=python
python(
"""
class Foo:
def getter (self, row):
pass
""",
"""
class Foo:
def getter(self, row):
pass
"""
),
recipe=AutoFormat()
)

class NoSpaces(PythonVisitor):
def visit_space(self, space: Optional[Space], loc: Optional[Space.Location], p: P) -> Optional[Space]:
return Space.EMPTY if space else None

0 comments on commit 5402378

Please sign in to comment.