Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Traitlet injection fails with chained functions (e.g. beam resource hints) #102

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 38 additions & 20 deletions pangeo_forge_runner/recipe_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,20 +114,46 @@
keywords=[],
)

def inject_keywords(self, node: Call) -> Call:
"""Inject keywords into calls."""
for name, params in self.callable_args_injections.items():
if hasattr(node.func, "id") and name == node.func.id:
# this is a non-chained call, so append to top-level `.keywords`
node.keywords += [
keyword(
arg=k,
value=self._make_injected_get(
"_CALLABLE_ARGS_INJECTIONS", name, k
),
)
for k in params
]

elif hasattr(node.func, "value") and name == node.func.value.func.id:
# this is a *chained* call, so append to `.func.value.keywords`
node.func.value.keywords += [

Check warning on line 134 in pangeo_forge_runner/recipe_rewriter.py

View check run for this annotation

Codecov / codecov/patch

pangeo_forge_runner/recipe_rewriter.py#L134

Added line #L134 was not covered by tests
keyword(
arg=k,
value=self._make_injected_get(
"_CALLABLE_ARGS_INJECTIONS", name, k
),
)
for k in params
]
return node

def visit_Call(self, node: Call) -> Call:
"""
Rewrite calls that return a FilePattern if we need to prune them
"""
if isinstance(node.func, Attribute):
# FIXME: Support it being imported as from apache_beam import Create too
if "apache_beam" not in self._import_aliases.values():
if (
# FIXME: Support it being imported as from apache_beam import Create too
# if beam hasn't been imported, don't rewrite anything
return node

# Only rewrite parameters to apache_beam.Create, regardless
# of how it is imported as
if node.func.attr == "Create" and (
self._import_aliases.get(node.func.value.id) == "apache_beam"
"apache_beam" in self._import_aliases.values()
# Rewrite parameters to apache_beam.Create, regardless of how it is imported
and node.func.attr == "Create"
and self._import_aliases.get(node.func.value.id) == "apache_beam"
):
# If there is a single argument pased to beam.Create, and it is <something>.items()
# This is the heurestic we use for figuring out that we are in fact operating on a FilePattern object
Expand All @@ -137,19 +163,11 @@
and node.args[0].func.attr == "items"
):
return fix_missing_locations(self.transform_prune(node))
elif node.func.attr == "with_resource_hints":
return fix_missing_locations(self.inject_keywords(node))

Check warning on line 167 in pangeo_forge_runner/recipe_rewriter.py

View check run for this annotation

Codecov / codecov/patch

pangeo_forge_runner/recipe_rewriter.py#L167

Added line #L167 was not covered by tests

elif isinstance(node.func, Name):
# FIXME: Support importing in other ways
for name, params in self.callable_args_injections.items():
if name == node.func.id:
node.keywords += [
keyword(
arg=k,
value=self._make_injected_get(
"_CALLABLE_ARGS_INJECTIONS", name, k
),
)
for k in params
]
return fix_missing_locations(node)
return fix_missing_locations(self.inject_keywords(node))

return node
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def some_callable(some_argument):
pass


some_callable().with_resource_hints()
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Parameters to be passed to RecipeRewriter constructor
params = dict(
prune=False, callable_args_injections={"some_callable": {"some_argument": 42}}
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
def some_callable(some_argument):
pass


some_callable(
some_argument=_CALLABLE_ARGS_INJECTIONS.get("some_callable", {}).get( # noqa
"some_argument"
)
).with_resource_hints()
Loading