Skip to content

Commit

Permalink
Curve cleanup (#73)
Browse files Browse the repository at this point in the history
Co-authored-by: Gerald Walter Irsiegler <gerald@Irsiegler-P14s>
  • Loading branch information
GeraldIr and Gerald Walter Irsiegler authored Oct 23, 2023
1 parent 3298375 commit 689ab92
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 66 deletions.
21 changes: 14 additions & 7 deletions openeo_pg_parser_networkx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
from openeo_pg_parser_networkx.process_registry import Process
from openeo_pg_parser_networkx.utils import (
ProcessGraphUnflattener,
format_nodes,
generate_function_from_nodes,
generate_curve_fit_function,
parse_nested_parameter,
)

Expand Down Expand Up @@ -256,17 +255,25 @@ def _walk_node(self):
)
self._parse_argument(unpacked_arg, arg_name, access_func=access_func)

# This is where we resolve the callbacks, e.g. sub-process graphs that are passed as arguments to other processes.
# We handle fit and predict curve functions separately, because we don't want those process graphs to be compiled into a normal pg callable.
for arg_name, arg in self._EVAL_ENV.callbacks_to_walk.items():
if "fitcurve" in self._EVAL_ENV.node_name and arg_name == "function":
if (
any(
substring in self._EVAL_ENV.node.process_id
for substring in ['fit_curve', 'predict_curve']
)
and arg_name == "function"
):
function_pg_data = self.pg_data["process_graph"][
self._EVAL_ENV.node_name
]["arguments"][arg_name]
pg = OpenEOProcessGraph(pg_data=function_pg_data)
formatted_nodes = format_nodes(pg=pg, vars=['x'])

self.G.nodes[self._EVAL_ENV.node_uid]["resolved_kwargs"][
arg_name
] = generate_function_from_nodes(formatted_nodes)
] = generate_curve_fit_function(
process_graph=OpenEOProcessGraph(pg_data=function_pg_data),
variables=['x'],
)
else:
self.G.nodes[self._EVAL_ENV.node_uid]["resolved_kwargs"][
arg_name
Expand Down
139 changes: 83 additions & 56 deletions openeo_pg_parser_networkx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _process_value(self, value) -> Any:
from openeo_pg_parser_networkx.pg_schema import ParameterReference, ResultReference


def format_nodes(pg, vars):
def _format_nodes(pg, vars):
nodes = []

for var in vars:
Expand All @@ -156,108 +156,130 @@ def format_nodes(pg, vars):
for node_id in pg:
node = [n for n in pg.nodes if n[0] == node_id][0]

formatted_node = (node[1]["process_id"], node[1]["node_name"])

# Parameters

parameter_names = list(node[1]["resolved_kwargs"].keys())

if parameter_names[0] == "data":
formatted_node += (node[1]["resolved_kwargs"]["index"],)
formatted_node += (None,)
nodes.append(formatted_node)
continue

if "x" in parameter_names:
if isinstance(node[1]["resolved_kwargs"]["x"], ParameterReference):
formatted_node += (node[1]["resolved_kwargs"]["x"].from_parameter,)
elif isinstance(node[1]["resolved_kwargs"]["x"], ResultReference):
formatted_node += (node[1]["resolved_kwargs"]["x"].from_node,)
# Special Cases
# Array Element (indexing)
if node[1]["process_id"] == "array_element":
data = node[1]["resolved_kwargs"]["data"]
if isinstance(data, ParameterReference):
data = data.from_parameter
else:
const_name = node[1]["node_name"] + "_x"
nodes.append(
(
"const",
const_name,
node[1]["resolved_kwargs"]["x"],
node[1]["node_name"] + "_const",
node[1]["resolved_kwargs"]["data"],
None,
)
)
formatted_node += (const_name,)
else:
formatted_node += (None,)
if "y" in parameter_names:
if isinstance(node[1]["resolved_kwargs"]["y"], ParameterReference):
formatted_node += (node[1]["resolved_kwargs"]["y"].from_parameter,)
elif isinstance(node[1]["resolved_kwargs"]["y"], ResultReference):
formatted_node += (node[1]["resolved_kwargs"]["y"].from_node,)
data = node[1]["node_name"] + "_const"
nodes.append(
(
"array_element",
node[1]["node_name"],
data,
int(node[1]["resolved_kwargs"]["index"]),
)
)
continue

# Unary and Binary Operations

formatted_node = (node[1]["process_id"], node[1]["node_name"])
parameter_names = list(node[1]["resolved_kwargs"].keys())

if len(parameter_names) == 1:
parameter_names.append(None)

for parameter in parameter_names:
if parameter is None:
formatted_node += (None,)
continue
parameter = node[1]["resolved_kwargs"][parameter]
if isinstance(parameter, ParameterReference):
formatted_node += (parameter.from_parameter,)
elif isinstance(parameter, ResultReference):
formatted_node += (parameter.from_node,)
else:
const_name = node[1]["node_name"] + "_y"
const_name = node[1]["node_name"] + "_const"
nodes.append(
(
"const",
const_name,
node[1]["resolved_kwargs"]["y"],
parameter,
None,
)
)
formatted_node += (const_name,)
else:
formatted_node += (None,)

nodes.append(formatted_node)
return nodes


# Fit_Curve function builder

BIN_OP_MAPPING = {
"multiply": ast.Mult(),
"divide": ast.Div(),
"subtract": ast.Sub(),
"add": ast.Add(),
"power": ast.Pow(),
"mod": ast.Mod(),
}


def generate_function_from_nodes(nodes):
def _generate_function_from_nodes(nodes: dict):
temp_results = {}
body = []

for node_type, node_name, operand1, operand2 in nodes:
# Value Operations
if node_type == "array_element":
value = ast.Subscript(
value=ast.Name(id="parameters", ctx=ast.Load()),
slice=ast.Index(value=ast.Num(n=int(operand1))),
value=ast.Name(id=operand1, ctx=ast.Load()),
slice=ast.Index(value=ast.Num(n=int(operand2))),
ctx=ast.Load(),
)
elif node_type == "const":
value = ast.Num(n=operand1)
if isinstance(operand1, list):
value = ast.List(
elts=[ast.Num(n=n) for n in operand1],
ctx=ast.Load(),
)
else:
value = ast.Num(n=operand1)

elif node_type == "variable":
value = ast.Name(id=operand1, ctx=ast.Load())
elif node_type == "multiply":
value = ast.BinOp(
left=temp_results[operand1], op=ast.Mult(), right=temp_results[operand2]
)
elif node_type == "divide":
value = ast.BinOp(
left=temp_results[operand1], op=ast.Div(), right=temp_results[operand2]
)
elif node_type == "subtract":
value = ast.BinOp(
left=temp_results[operand1], op=ast.Sub(), right=temp_results[operand2]
)
elif node_type == "add":

# Binary Operations
elif node_type in BIN_OP_MAPPING:
value = ast.BinOp(
left=temp_results[operand1], op=ast.Add(), right=temp_results[operand2]
left=temp_results[operand1],
op=BIN_OP_MAPPING[node_type],
right=temp_results[operand2],
)

elif node_type == "cos":
# Unary Numpy Functions
elif node_type in ["cos", "sin", "tan", "sqrt", "absolute"]:
value = ast.Call(
func=ast.Attribute(
value=ast.Name(id="np", ctx=ast.Load()), attr="cos", ctx=ast.Load()
value=ast.Name(id="np", ctx=ast.Load()),
attr=node_type,
ctx=ast.Load(),
),
args=[temp_results[operand1]],
keywords=[],
)
elif node_type == "sin":

# Binary Numpy Functions
elif node_type in ["log"]:
value = ast.Call(
func=ast.Attribute(
value=ast.Name(id="np", ctx=ast.Load()), attr="sin", ctx=ast.Load()
value=ast.Name(id="np", ctx=ast.Load()),
attr=node_type,
ctx=ast.Load(),
),
args=[temp_results[operand1]],
args=[temp_results[operand1], temp_results[operand2]],
keywords=[],
)

Expand Down Expand Up @@ -295,3 +317,8 @@ def generate_function_from_nodes(nodes):
exec(code_obj, namespace)

return namespace["compute"]


def generate_curve_fit_function(process_graph, variables=['x']):
nodes = _format_nodes(process_graph, variables)
return _generate_function_from_nodes(nodes)
123 changes: 123 additions & 0 deletions tests/data/graphs/all_math.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
{
"id": "allmath",
"process_graph": {
"array1": {
"process_id": "array_element",
"arguments": {
"data": [
1
],
"index": 0
}
},
"array2": {
"process_id": "array_element",
"arguments": {
"data": [
2
],
"index": 0
}
},
"add3": {
"process_id": "add",
"arguments": {
"x": {
"from_node": "array1"
},
"y": {
"from_node": "array2"
}
}
},
"subtract4": {
"process_id": "subtract",
"arguments": {
"x": {
"from_node": "add3"
},
"y": 4
}
},
"absolute5": {
"process_id": "absolute",
"arguments": {
"x": {
"from_node": "subtract4"
}
}
},
"divide6": {
"process_id": "divide",
"arguments": {
"x": {
"from_node": "absolute5"
},
"y": 2
}
},
"multiply7": {
"process_id": "multiply",
"arguments": {
"x": {
"from_node": "divide6"
},
"y": 4
}
},
"cos8": {
"process_id": "cos",
"arguments": {
"x": {
"from_node": "multiply7"
}
}
},
"sin9": {
"process_id": "sin",
"arguments": {
"x": {
"from_node": "multiply7"
}
}
},
"divide10": {
"process_id": "divide",
"arguments": {
"x": {
"from_node": "cos8"
},
"y": {
"from_node": "sin9"
}
}
},
"tan11": {
"process_id": "tan",
"arguments": {
"x": {
"from_node": "divide10"
}
}
},
"sqrt13": {
"process_id": "sqrt",
"arguments": {
"x": {
"from_node": "power14"
}
},
"result": true
},
"power14": {
"process_id": "power",
"arguments": {
"base": {
"from_node": "tan11"
},
"p": 2
}
}
},
"parameters": []
}
7 changes: 4 additions & 3 deletions tests/data/graphs/fit_curve.json
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@
"arrayelement3": {
"process_id": "array_element",
"arguments": {
"data": {
"from_parameter": "parameters"
},
"data": [
3,
2
],
"index": 2
}
},
Expand Down
Loading

0 comments on commit 689ab92

Please sign in to comment.