Skip to content

Commit

Permalink
Improve qol
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Dec 9, 2024
1 parent 1b03440 commit 8e5d25a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/spox/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,9 @@ def init(*args: Var) -> type[Function]:
def alt_fun(*args: Var) -> Iterable[Union[Var, Optional[Var], Sequence[Var]]]:
cls = init(*args)
return [
Var(var_info) # type: ignore
Var(var_info)
for var_info in cls(cls.Attributes(), cls.Inputs(*unwrap_vars(args)))
.outputs.get_fields()
.outputs.get_var_infos()
.values()
]

Expand Down
9 changes: 6 additions & 3 deletions src/spox/_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

from __future__ import annotations

from collections.abc import Iterable
from typing import TYPE_CHECKING, Callable

import numpy as np
import onnx
import onnx.reference
import onnx.shape_inference
Expand Down Expand Up @@ -121,16 +123,17 @@ def out_value_info(
elif not isinstance(prop, PropValue) or prop.value is None:
continue
elif isinstance(prop.type, Sequence):
assert isinstance(prop.value, Iterable)
initializers.extend(
[
from_array(elem.value, f"{name}_{i}")
for i, elem in enumerate(prop.value) # type: ignore
for i, elem in enumerate(prop.value)
if elem is not None
]
)
else:
initializers.append(from_array(prop.value, name)) # type: ignore
continue
assert isinstance(prop.value, np.ndarray)
initializers.append(from_array(prop.value, name))

# Graph and model
graph = onnx.helper.make_graph(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_value_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_empty_optional_has_no_element():

@pytest.mark.parametrize("min", [None, 2])
def test_optional_clip(min):
min_var = min if min is None else op.const(min)
min_var = None if min is None else op.const(min)
assert_equal_value(
op.clip(op.const([1, 2, 3]), min=min_var, max=op.const(3)),
np.clip([1, 2, 3], a_min=min, a_max=3),
Expand Down

0 comments on commit 8e5d25a

Please sign in to comment.