From 266bc4fd2c4372d4dcd19f008fe19d03bbaae061 Mon Sep 17 00:00:00 2001 From: jbusecke Date: Tue, 5 Sep 2023 18:34:22 -0400 Subject: [PATCH 1/4] Add failing test for injection test with chained function --- .../callable-args-injection-chained/original.py | 5 +++++ .../callable-args-injection-chained/params.py | 4 ++++ .../callable-args-injection-chained/rewritten.py | 9 +++++++++ 3 files changed, 18 insertions(+) create mode 100644 tests/rewriter-tests/callable-args-injection-chained/original.py create mode 100644 tests/rewriter-tests/callable-args-injection-chained/params.py create mode 100644 tests/rewriter-tests/callable-args-injection-chained/rewritten.py diff --git a/tests/rewriter-tests/callable-args-injection-chained/original.py b/tests/rewriter-tests/callable-args-injection-chained/original.py new file mode 100644 index 00000000..ef1fbc13 --- /dev/null +++ b/tests/rewriter-tests/callable-args-injection-chained/original.py @@ -0,0 +1,5 @@ +def some_callable(some_argument): + pass + + +some_callable().some_func() diff --git a/tests/rewriter-tests/callable-args-injection-chained/params.py b/tests/rewriter-tests/callable-args-injection-chained/params.py new file mode 100644 index 00000000..8b50eacc --- /dev/null +++ b/tests/rewriter-tests/callable-args-injection-chained/params.py @@ -0,0 +1,4 @@ +# Parameters to be passed to RecipeRewriter constructor +params = dict( + prune=False, callable_args_injections={"some_callable": {"some_argument": 42}} +) diff --git a/tests/rewriter-tests/callable-args-injection-chained/rewritten.py b/tests/rewriter-tests/callable-args-injection-chained/rewritten.py new file mode 100644 index 00000000..3ab131e3 --- /dev/null +++ b/tests/rewriter-tests/callable-args-injection-chained/rewritten.py @@ -0,0 +1,9 @@ +def some_callable(some_argument): + pass + + +some_callable( + some_argument=_CALLABLE_ARGS_INJECTIONS.get("some_callable", {}).get( # noqa + "some_argument" + ) +).some_func() From d661e824e001f2ca68edfd4898b168979a9929d1 Mon Sep 17 00:00:00 2001 From: Charles Stern <62192187+cisaacstern@users.noreply.github.com> Date: Tue, 5 Sep 2023 17:34:31 -0700 Subject: [PATCH 2/4] get tests to pass --- pangeo_forge_runner/recipe_rewriter.py | 60 ++++++++++++------- .../original.py | 2 +- .../rewritten.py | 2 +- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/pangeo_forge_runner/recipe_rewriter.py b/pangeo_forge_runner/recipe_rewriter.py index aa3b3817..ce61a5ef 100644 --- a/pangeo_forge_runner/recipe_rewriter.py +++ b/pangeo_forge_runner/recipe_rewriter.py @@ -20,7 +20,9 @@ fix_missing_locations, keyword, ) -from typing import Optional +from typing import Optional, TypeVar + +N = TypeVar("N") class RecipeRewriter(NodeTransformer): @@ -114,20 +116,44 @@ def _make_injected_get( keywords=[], ) + def inject_keywords(self, node: N) -> N: + """ """ + for name, params in self.callable_args_injections.items(): + if hasattr(node.func, "id") and name == node.func.id: + node.keywords += [ + keyword( + arg=k, + value=self._make_injected_get( + "_CALLABLE_ARGS_INJECTIONS", name, k + ), + ) + for k in params + ] + + elif hasattr(node.func, "value") and name == node.func.value.func.id: + node.func.value.keywords += [ + keyword( + arg=k, + value=self._make_injected_get( + "_CALLABLE_ARGS_INJECTIONS", name, k + ), + ) + for k in params + ] + return node + def visit_Call(self, node: Call) -> Call: """ Rewrite calls that return a FilePattern if we need to prune them """ if isinstance(node.func, Attribute): - # FIXME: Support it being imported as from apache_beam import Create too - if "apache_beam" not in self._import_aliases.values(): + if ( + # FIXME: Support it being imported as from apache_beam import Create too # if beam hasn't been imported, don't rewrite anything - return node - - # Only rewrite parameters to apache_beam.Create, regardless - # of how it is imported as - if node.func.attr == "Create" and ( - self._import_aliases.get(node.func.value.id) == "apache_beam" + "apache_beam" in self._import_aliases.values() + # Rewrite parameters to apache_beam.Create, regardless of how it is imported + and node.func.attr == "Create" + and self._import_aliases.get(node.func.value.id) == "apache_beam" ): # If there is a single argument pased to beam.Create, and it is .items() # This is the heurestic we use for figuring out that we are in fact operating on a FilePattern object @@ -137,19 +163,11 @@ def visit_Call(self, node: Call) -> Call: and node.args[0].func.attr == "items" ): return fix_missing_locations(self.transform_prune(node)) + elif node.func.attr == "with_resource_hints": + return fix_missing_locations(self.inject_keywords(node)) + elif isinstance(node.func, Name): # FIXME: Support importing in other ways - for name, params in self.callable_args_injections.items(): - if name == node.func.id: - node.keywords += [ - keyword( - arg=k, - value=self._make_injected_get( - "_CALLABLE_ARGS_INJECTIONS", name, k - ), - ) - for k in params - ] - return fix_missing_locations(node) + return fix_missing_locations(self.inject_keywords(node)) return node diff --git a/tests/rewriter-tests/callable-args-injection-chained/original.py b/tests/rewriter-tests/callable-args-injection-chained/original.py index ef1fbc13..d6902534 100644 --- a/tests/rewriter-tests/callable-args-injection-chained/original.py +++ b/tests/rewriter-tests/callable-args-injection-chained/original.py @@ -2,4 +2,4 @@ def some_callable(some_argument): pass -some_callable().some_func() +some_callable().with_resource_hints() diff --git a/tests/rewriter-tests/callable-args-injection-chained/rewritten.py b/tests/rewriter-tests/callable-args-injection-chained/rewritten.py index 3ab131e3..95db9ad6 100644 --- a/tests/rewriter-tests/callable-args-injection-chained/rewritten.py +++ b/tests/rewriter-tests/callable-args-injection-chained/rewritten.py @@ -6,4 +6,4 @@ def some_callable(some_argument): some_argument=_CALLABLE_ARGS_INJECTIONS.get("some_callable", {}).get( # noqa "some_argument" ) -).some_func() +).with_resource_hints() From 0e6707dc640c05876214e4df59c41a926387ca79 Mon Sep 17 00:00:00 2001 From: Charles Stern <62192187+cisaacstern@users.noreply.github.com> Date: Tue, 5 Sep 2023 17:36:38 -0700 Subject: [PATCH 3/4] fix typing --- pangeo_forge_runner/recipe_rewriter.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pangeo_forge_runner/recipe_rewriter.py b/pangeo_forge_runner/recipe_rewriter.py index ce61a5ef..bab6dad6 100644 --- a/pangeo_forge_runner/recipe_rewriter.py +++ b/pangeo_forge_runner/recipe_rewriter.py @@ -20,9 +20,7 @@ fix_missing_locations, keyword, ) -from typing import Optional, TypeVar - -N = TypeVar("N") +from typing import Optional class RecipeRewriter(NodeTransformer): @@ -116,7 +114,7 @@ def _make_injected_get( keywords=[], ) - def inject_keywords(self, node: N) -> N: + def inject_keywords(self, node: Call) -> Call: """ """ for name, params in self.callable_args_injections.items(): if hasattr(node.func, "id") and name == node.func.id: From 4ea3f880f63d8da3a1ff8546c5b73571164a1872 Mon Sep 17 00:00:00 2001 From: Charles Stern <62192187+cisaacstern@users.noreply.github.com> Date: Tue, 5 Sep 2023 17:40:46 -0700 Subject: [PATCH 4/4] add comments --- pangeo_forge_runner/recipe_rewriter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pangeo_forge_runner/recipe_rewriter.py b/pangeo_forge_runner/recipe_rewriter.py index bab6dad6..5f6eb924 100644 --- a/pangeo_forge_runner/recipe_rewriter.py +++ b/pangeo_forge_runner/recipe_rewriter.py @@ -115,9 +115,10 @@ def _make_injected_get( ) def inject_keywords(self, node: Call) -> Call: - """ """ + """Inject keywords into calls.""" for name, params in self.callable_args_injections.items(): if hasattr(node.func, "id") and name == node.func.id: + # this is a non-chained call, so append to top-level `.keywords` node.keywords += [ keyword( arg=k, @@ -129,6 +130,7 @@ def inject_keywords(self, node: Call) -> Call: ] elif hasattr(node.func, "value") and name == node.func.value.func.id: + # this is a *chained* call, so append to `.func.value.keywords` node.func.value.keywords += [ keyword( arg=k,