From 3115185039996de3cf85f8d1a0a287f18acff8f6 Mon Sep 17 00:00:00 2001 From: Dan D'Avella Date: Tue, 23 Jul 2024 14:51:26 -0400 Subject: [PATCH] Enable customizable call modifier logic (#736) --- .../codemods/import_modifier_codemod.py | 7 +++---- .../codemods/imported_call_modifier.py | 19 ++++--------------- src/codemodder/codemods/utils_mixin.py | 2 +- 3 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/codemodder/codemods/import_modifier_codemod.py b/src/codemodder/codemods/import_modifier_codemod.py index a2ffe248..b43f1082 100644 --- a/src/codemodder/codemods/import_modifier_codemod.py +++ b/src/codemodder/codemods/import_modifier_codemod.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import Callable, Mapping +from typing import Mapping import libcst as cst from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor @@ -42,7 +42,7 @@ def update_simple_name(self, true_name, original_node, updated_node, new_args): class ImportModifierCodemod(LibcstResultTransformer, metaclass=ABCMeta): - result_filter: Callable[[cst.CSTNode], bool] | None = None + call_modifier: type[MappingImportedCallModifier] = MappingImportedCallModifier @property def dependency(self) -> Dependency | None: @@ -54,13 +54,12 @@ def mapping(self) -> Mapping[str, str]: pass def transform_module_impl(self, tree: cst.Module) -> cst.Module: - visitor = MappingImportedCallModifier( + visitor = self.call_modifier( self.context, self.file_context, self.mapping, self.change_description, self.results, - self.result_filter, ) result_tree = visitor.transform_module(tree) self.file_context.codemod_changes.extend(visitor.changes_in_file) diff --git a/src/codemodder/codemods/imported_call_modifier.py b/src/codemodder/codemods/imported_call_modifier.py index 28dc66cf..53925155 100644 --- a/src/codemodder/codemods/imported_call_modifier.py +++ b/src/codemodder/codemods/imported_call_modifier.py @@ -1,11 +1,10 @@ import abc -from typing import Callable, Generic, Mapping, Sequence, Set, TypeVar, Union +from typing import Generic, Mapping, Sequence, Set, TypeVar, Union import libcst as cst from libcst import matchers from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand -from libcst.metadata import PositionProvider -from typing_extensions import override +from libcst.metadata import ParentNodeProvider, PositionProvider from codemodder.codemods.base_visitor import UtilsMixin from codemodder.codemods.utils_mixin import NameResolutionMixin @@ -24,7 +23,7 @@ class ImportedCallModifier( UtilsMixin, metaclass=abc.ABCMeta, ): - METADATA_DEPENDENCIES = (PositionProvider,) + METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider) def __init__( self, @@ -33,7 +32,6 @@ def __init__( matching_functions: FunctionMatchType, change_description: str, results: list[Result] | None = None, - result_filter: Callable[[cst.CSTNode], bool] | None = None, ): VisitorBasedCodemodCommand.__init__(self, codemod_context) self.line_exclude = file_context.line_exclude @@ -43,15 +41,6 @@ def __init__( self.changes_in_file: list[Change] = [] self.results = results self.file_context = file_context - self.result_filter = result_filter - - @override - def filter_by_result(self, node: cst.CSTNode) -> bool: - return ( - self.result_filter(node) - if self.result_filter - else super().filter_by_result(node) - ) def updated_args(self, original_args: Sequence[cst.Arg]): return original_args @@ -82,7 +71,7 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): if self.filter_by_path_includes_or_excludes(pos_to_match): true_name = self.find_base_name(original_node.func) if ( - self._is_direct_call_from_imported_module(original_node) + self.is_direct_call_from_imported_module(original_node) and true_name and true_name in self.matching_functions ): diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 5d253c76..a6bce24c 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -100,7 +100,7 @@ def base_name_for_import(self, import_node, import_alias): # it is a from import return _get_name(import_node) + "." + get_full_name_for_node(import_alias.name) - def _is_direct_call_from_imported_module( + def is_direct_call_from_imported_module( self, call: cst.Call ) -> Optional[tuple[Union[cst.Import, cst.ImportFrom], cst.ImportAlias]]: for nodo in iterate_left_expressions(call):