Skip to content

Commit

Permalink
add graph functions to edit an existing onnx model
Browse files Browse the repository at this point in the history
  • Loading branch information
MainRo committed Oct 24, 2024
1 parent 338765b commit 68ec987
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 33 deletions.
2 changes: 1 addition & 1 deletion ebm2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,5 +222,5 @@ def to_onnx(model, dtype, name="ebm",
else:
raise NotImplementedError("{} models are not supported".format(type(model)))

model = graph.compile(g, target_opset, name=name)
model = graph.to_onnx(g, target_opset, name=name)
return model
108 changes: 100 additions & 8 deletions ebm2onnx/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import NamedTuple, Callable, List
from typing import NamedTuple, Callable, Optional, List, Dict

import onnx
from ebm2onnx import __version__
from .utils import get_latest_opset_version


class Graph(NamedTuple):
Expand All @@ -14,18 +16,17 @@ class Graph(NamedTuple):

def create_name_generator() -> Callable[[str], str]:
    """Returns a closure that generates globally unique names.

    Each call to the returned function maps a base ``name`` to
    ``"<name>_<counter>"``, where the counter starts at 0 for each base
    name and increments on every subsequent call with the same base name.

    Returns:
        A callable taking a base name and returning a unique name.
    """
    state: Dict[str, int] = {}

    def _generate_unique_name(name: str) -> str:
        """Generates a new globally unique name in the graph."""
        # get(-1) + 1 starts the counter at 0 on first use, then increments;
        # the original bound an unused local `i` before incrementing.
        count = state.get(name, -1) + 1
        state[name] = count
        return "{}_{}".format(name, count)

    return _generate_unique_name


Expand All @@ -41,12 +42,56 @@ def pipe(*args):
pass


def create_graph():
def create_graph() -> Graph:
    """Builds a fresh, empty graph.

    Returns:
        A new Graph whose name generator starts with no reserved names.
    """
    return Graph(generate_name=create_name_generator())

def compile(graph, target_opset, name="ebm"):

def from_onnx(model) -> Graph:
    """Creates a graph object from an onnx model.

    Creating a graph from an existing model allows for editing it.

    Args:
        model: An ONNX model

    Returns:
        A Graph object holding copies of the model's input, output, node
        and initializer lists, with a fresh name generator.
    """
    onnx_graph = model.graph
    return Graph(
        generate_name=create_name_generator(),
        inputs=list(onnx_graph.input),
        outputs=list(onnx_graph.output),
        nodes=list(onnx_graph.node),
        initializers=list(onnx_graph.initializer),
    )


def to_onnx(
graph: Graph,
target_opset: Optional[int | Dict[str, int]] = None,
name: Optional[str] = "ebm",
) -> Graph:
"""Converts a graph to an onnx model.
If target_opset is an int, then it corresponds to the default domain
'ai.onnx'. Using a dict allows to set opset versions for other domains
like 'ai.onnx.ml'.
Args:
graph: The graph object
target_opset: the target opset to use when converting to onnx, can be an int or a dict
name: [Optional] The name of the generated ONNX model
Returns:
An ONNX model.
"""
#outputs = graph.transients

graph = onnx.helper.make_graph(
Expand All @@ -57,9 +102,35 @@ def compile(graph, target_opset, name="ebm"):
initializer=graph.initializers,
)
model = onnx.helper.make_model(graph, producer_name='ebm2onnx')
model.opset_import[0].version = target_opset

#producer_name = "interpretml/ebm2onnx"
#producer_version = __version__

#domain
#model_version
#doc_string

#metadata_props

# set opset versions
if target_opset is not None:
    if type(target_opset) is int:
        # A bare int targets the default 'ai.onnx' domain.
        model.opset_import[0].version = target_opset
    elif type(target_opset) is dict:
        # A dict maps domain names to opset versions; replace the
        # default entry entirely.
        del model.opset_import[:]

        for k, v in target_opset.items():
            opset = model.opset_import.add()
            opset.domain = k
            # fix: was `opset.model = v` — OperatorSetIdProto has no
            # `model` field; the version must go into `version`.
            opset.version = v
    else:
        raise ValueError(f"ebm2onnx.graph.to_onnx: invalid type for target_opset: {type(target_opset)}.")
else:
    model.opset_import[0].version = get_latest_opset_version()

return model


def create_input(graph, name, type, shape):
input = onnx.helper.make_tensor_value_info(name , type, shape)
return Graph(
Expand All @@ -85,6 +156,27 @@ def create_initializer(graph, name, type, shape, value):
)


def create_transient_by_name(g, name, type, shape):
    """Creates a graph containing a single transient tensor value info.

    Args:
        g: The source graph; only its name generator is carried over.
        name: Name of the transient tensor.
        type: ONNX element type of the tensor.
        shape: Shape of the tensor.

    Returns:
        A Graph whose only content is the new transient.
    """
    # local renamed from `input`, which shadowed the builtin
    value_info = onnx.helper.make_tensor_value_info(name, type, shape)
    return Graph(
        generate_name=g.generate_name,
        transients=[value_info],
    )


def add_transient_by_name(g, name, type=onnx.TensorProto.UNDEFINED, shape=[]):
tname = [
o
for n in g.nodes
for o in n.output
if o == name
][0]
t = onnx.helper.make_tensor_value_info(tname, type, shape)
return g._replace(
transients=extend(g.transients, [t])
)


def strip_to_transients(graph):
""" Returns only the transients of a graph
"""
Expand Down
48 changes: 33 additions & 15 deletions ebm2onnx/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def _add(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(add_result_name, g.transients[0].type.tensor_type.elem_type, []),
onnx.helper.make_tensor_value_info(add_result_name, g.transients[0].type.tensor_type.elem_type, []),
],
)

Expand All @@ -33,7 +33,7 @@ def _argmax(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(argmax_result_name, onnx.TensorProto.INT64, []),
onnx.helper.make_tensor_value_info(argmax_result_name, onnx.TensorProto.INT64, []),
],
)

Expand All @@ -50,7 +50,7 @@ def _cast(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(cast_result_name, to, []),
onnx.helper.make_tensor_value_info(cast_result_name, to, []),
],
)

Expand All @@ -69,7 +69,7 @@ def _concat(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(concat_result_name, onnx.TensorProto.UNDEFINED, []),
onnx.helper.make_tensor_value_info(concat_result_name, onnx.TensorProto.UNDEFINED, []),
],
)

Expand All @@ -86,7 +86,7 @@ def _expand(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(expand_result_name, onnx.TensorProto.UNDEFINED, []),
onnx.helper.make_tensor_value_info(expand_result_name, onnx.TensorProto.UNDEFINED, []),
],
)

Expand All @@ -104,7 +104,7 @@ def _flatten(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(flatten_result_name, onnx.TensorProto.UNDEFINED, []),
onnx.helper.make_tensor_value_info(flatten_result_name, onnx.TensorProto.UNDEFINED, []),
],
)

Expand All @@ -126,7 +126,7 @@ def _gather_elements(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(gather_elements_result_name, onnx.TensorProto.UNDEFINED, []),
onnx.helper.make_tensor_value_info(gather_elements_result_name, onnx.TensorProto.UNDEFINED, []),
],
)

Expand All @@ -148,13 +148,30 @@ def _gather_nd(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(gather_nd_result_name, onnx.TensorProto.UNDEFINED, []),
onnx.helper.make_tensor_value_info(gather_nd_result_name, onnx.TensorProto.UNDEFINED, []),
],
)

return _gather_nd


def greater_or_equal():
    """Returns an operator that appends a GreaterOrEqual node to a graph.

    The node compares the graph's first two transients element-wise and
    yields a boolean transient as the new graph output.
    """
    def _greater_or_equal(g):
        result_name = g.generate_name('greater_or_equal_result')
        lhs = g.transients[0].name
        rhs = g.transients[1].name
        node = onnx.helper.make_node(
            "GreaterOrEqual", [lhs, rhs], [result_name],
        )
        return g._replace(
            nodes=graph.extend(g.nodes, [node]),
            transients=[
                onnx.helper.make_tensor_value_info(result_name, onnx.TensorProto.BOOL, []),
            ],
        )

    return _greater_or_equal


def identity(name, suffix=True):
def _identity(g):
identity_name = g.generate_name(name) if suffix else name
Expand All @@ -165,7 +182,7 @@ def _identity(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(identity_name, onnx.TensorProto.UNDEFINED, []),
onnx.helper.make_tensor_value_info(identity_name, onnx.TensorProto.UNDEFINED, []),
],
)

Expand All @@ -182,7 +199,7 @@ def _less(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(less_result_name, onnx.TensorProto.BOOL, []),
onnx.helper.make_tensor_value_info(less_result_name, onnx.TensorProto.BOOL, []),
],
)

Expand All @@ -199,12 +216,13 @@ def _less_or_equal(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(less_or_equal_result_name, onnx.TensorProto.BOOL, []),
onnx.helper.make_tensor_value_info(less_or_equal_result_name, onnx.TensorProto.BOOL, []),
],
)

return _less_or_equal


def mul():
def _mul(g):
mul_result_name = g.generate_name('mul_result')
Expand All @@ -215,7 +233,7 @@ def _mul(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(mul_result_name, onnx.TensorProto.UNDEFINED, []),
onnx.helper.make_tensor_value_info(mul_result_name, onnx.TensorProto.UNDEFINED, []),
],
)

Expand All @@ -239,7 +257,7 @@ def _reduce_sum(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(reduce_sum_result_name, onnx.TensorProto.UNDEFINED, []),
onnx.helper.make_tensor_value_info(reduce_sum_result_name, onnx.TensorProto.UNDEFINED, []),
],
)

Expand All @@ -261,7 +279,7 @@ def _reshape(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(reshape_result_name, onnx.TensorProto.UNDEFINED, []),
onnx.helper.make_tensor_value_info(reshape_result_name, onnx.TensorProto.UNDEFINED, []),
],
)

Expand All @@ -283,7 +301,7 @@ def _softmax(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(softmax_result_name, onnx.TensorProto.UNDEFINED, []),
onnx.helper.make_tensor_value_info(softmax_result_name, onnx.TensorProto.UNDEFINED, []),
],
)

Expand Down
2 changes: 1 addition & 1 deletion ebm2onnx/operators_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _category_mapper(g):
return g._replace(
nodes=graph.extend(g.nodes, nodes),
transients=[
onnx.helper.make_tensor_value_info(category_mapper_result_name, onnx.TensorProto.UNDEFINED, []),
onnx.helper.make_tensor_value_info(category_mapper_result_name, onnx.TensorProto.UNDEFINED, []),
],
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_get_bin_index_on_categorical_value():
})(i)
g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.INT64, [None, 1])

result = infer_model(graph.compile(g, target_opset=13),
result = infer_model(graph.to_onnx(g, target_opset=13),
input={
'i': [["biz"], ["foo"], ["bar"], ["nan"], ["okif"]],
}
Expand Down
12 changes: 12 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ def test_merge():
]


def test_create_transient_by_name():
    """The created graph holds exactly the requested transient."""
    g = graph.create_graph()

    result = graph.create_transient_by_name(g, "foo", onnx.TensorProto.FLOAT, [4])

    expected = onnx.helper.make_tensor_value_info('foo', onnx.TensorProto.FLOAT, [4])
    assert len(result.transients) == 1
    assert result.transients == [expected]


def test_strip_to_transients():
g = graph.create_graph()

Expand Down
29 changes: 29 additions & 0 deletions tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,35 @@ def test_flatten():
)


def test_greater_or_equal():
    """GreaterOrEqual broadcasts the input column against the initializer
    row and returns element-wise booleans; NaN compares False."""
    g = graph.create_graph()

    a = graph.create_initializer(g, "a", onnx.TensorProto.FLOAT, [4], [0.1, 2.3, 3.55, 9.6])
    b = graph.create_input(g, "b", onnx.TensorProto.FLOAT, [None, 1])

    l = ops.greater_or_equal()(graph.merge(b, a))
    l = graph.add_output(l, l.transients[0].name, onnx.TensorProto.BOOL, [None, 4])

    assert_model_result(l,
        input={
            'b': [
                [0.5],
                [1.2],
                [11],
                [4.2],
                # fix: np.NaN was removed in NumPy 2.0; np.nan is the
                # canonical (and backward-compatible) spelling.
                [np.nan],
            ]
        },
        expected_result=[[
            [True, False, False, False],
            [True, False, False, False],
            [True, True, True, True],
            [True, True, True, False],
            [False, False, False, False],
        ]]
    )


def test_less():
g = graph.create_graph()

Expand Down
Loading

0 comments on commit 68ec987

Please sign in to comment.