Skip to content

Commit

Permalink
Enable customizable call modifier logic (#736)
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella authored Jul 23, 2024
1 parent 3a20d20 commit 3115185
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 20 deletions.
7 changes: 3 additions & 4 deletions src/codemodder/codemods/import_modifier_codemod.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABCMeta, abstractmethod
from typing import Callable, Mapping
from typing import Mapping

import libcst as cst
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
Expand Down Expand Up @@ -42,7 +42,7 @@ def update_simple_name(self, true_name, original_node, updated_node, new_args):


class ImportModifierCodemod(LibcstResultTransformer, metaclass=ABCMeta):
result_filter: Callable[[cst.CSTNode], bool] | None = None
call_modifier: type[MappingImportedCallModifier] = MappingImportedCallModifier

@property
def dependency(self) -> Dependency | None:
Expand All @@ -54,13 +54,12 @@ def mapping(self) -> Mapping[str, str]:
pass

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
visitor = MappingImportedCallModifier(
visitor = self.call_modifier(
self.context,
self.file_context,
self.mapping,
self.change_description,
self.results,
self.result_filter,
)
result_tree = visitor.transform_module(tree)
self.file_context.codemod_changes.extend(visitor.changes_in_file)
Expand Down
19 changes: 4 additions & 15 deletions src/codemodder/codemods/imported_call_modifier.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import abc
from typing import Callable, Generic, Mapping, Sequence, Set, TypeVar, Union
from typing import Generic, Mapping, Sequence, Set, TypeVar, Union

import libcst as cst
from libcst import matchers
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.metadata import PositionProvider
from typing_extensions import override
from libcst.metadata import ParentNodeProvider, PositionProvider

from codemodder.codemods.base_visitor import UtilsMixin
from codemodder.codemods.utils_mixin import NameResolutionMixin
Expand All @@ -24,7 +23,7 @@ class ImportedCallModifier(
UtilsMixin,
metaclass=abc.ABCMeta,
):
METADATA_DEPENDENCIES = (PositionProvider,)
METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider)

def __init__(
self,
Expand All @@ -33,7 +32,6 @@ def __init__(
matching_functions: FunctionMatchType,
change_description: str,
results: list[Result] | None = None,
result_filter: Callable[[cst.CSTNode], bool] | None = None,
):
VisitorBasedCodemodCommand.__init__(self, codemod_context)
self.line_exclude = file_context.line_exclude
Expand All @@ -43,15 +41,6 @@ def __init__(
self.changes_in_file: list[Change] = []
self.results = results
self.file_context = file_context
self.result_filter = result_filter

@override
def filter_by_result(self, node: cst.CSTNode) -> bool:
return (
self.result_filter(node)
if self.result_filter
else super().filter_by_result(node)
)

def updated_args(self, original_args: Sequence[cst.Arg]):
return original_args
Expand Down Expand Up @@ -82,7 +71,7 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
if self.filter_by_path_includes_or_excludes(pos_to_match):
true_name = self.find_base_name(original_node.func)
if (
self._is_direct_call_from_imported_module(original_node)
self.is_direct_call_from_imported_module(original_node)
and true_name
and true_name in self.matching_functions
):
Expand Down
2 changes: 1 addition & 1 deletion src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def base_name_for_import(self, import_node, import_alias):
# it is a from import
return _get_name(import_node) + "." + get_full_name_for_node(import_alias.name)

def _is_direct_call_from_imported_module(
def is_direct_call_from_imported_module(
self, call: cst.Call
) -> Optional[tuple[Union[cst.Import, cst.ImportFrom], cst.ImportAlias]]:
for nodo in iterate_left_expressions(call):
Expand Down

0 comments on commit 3115185

Please sign in to comment.