Skip to content

Commit

Permalink
Final fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Dec 9, 2024
1 parent b9cb099 commit 756f274
Show file tree
Hide file tree
Showing 18 changed files with 608 additions and 321 deletions.
14 changes: 8 additions & 6 deletions src/spox/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class BaseVars:
def __init__(self, vars: dict[str, Union[Var, Optional[Var], Sequence[Var]]]):
self.vars = vars

def _unpack_to_any(self):
def _unpack_to_any(self) -> tuple[Union[Var, Optional[Var], Sequence[Var]], ...]:
"""Unpack the stored fields into a tuple of appropriate length, typed as Any."""
return tuple(self.vars.values())

Expand All @@ -60,7 +60,7 @@ def flatten_vars(self) -> dict[str, Var]:
"""Return a flat mapping by name of all the VarInfos in this object."""
return {key: var for key, var in self._flatten() if var is not None}

def __getattr__(self, attr: str):
def __getattr__(self, attr: str) -> Union[Var, Optional[Var], Sequence[Var]]:
"""Retrieves the attribute if present in the stored variables."""
try:
return self.vars[attr]
Expand All @@ -69,7 +69,9 @@ def __getattr__(self, attr: str):
f"{self.__class__.__name__!r} object has no attribute {attr!r}"
)

def __setattr__(self, attr: str, value: Union[Var, Sequence[Var]]) -> None:
def __setattr__(
self, attr: str, value: Union[Var, Optional[Var], Sequence[Var]]
) -> None:
"""Sets the attribute to a value if the attribute is present in the stored variables."""
if attr == "vars":
super().__setattr__(attr, value)
Expand Down Expand Up @@ -193,7 +195,7 @@ def _propagate_vars(self, prop_values: Optional[PropDict] = None) -> BaseVars:
if prop_values is None:
prop_values = {}

def _create_var(key, var_info):
def _create_var(key: str, var_info: _VarInfo) -> Var:
ret = Var(var_info, None)

if var_info.type is None or key not in prop_values:
Expand All @@ -212,10 +214,10 @@ def _create_var(key, var_info):

return ret

ret_dict = {}
ret_dict: dict[str, Union[Var, Optional[Var], Sequence[Var]]] = {}

for key, var_info in self.__dict__.items():
if var_info is None or isinstance(var_info, _VarInfo):
if isinstance(var_info, _VarInfo):
ret_dict[key] = _create_var(key, var_info)
else:
ret_dict[key] = [
Expand Down
11 changes: 8 additions & 3 deletions src/spox/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

if TYPE_CHECKING:
from . import _graph
from ._value_prop import PropDict

DEFAULT_FUNCTION_DOMAIN = "spox.default"

Expand Down Expand Up @@ -51,7 +52,9 @@ class Function(_InternalNode):
func_outputs: BaseOutputs
func_graph: _graph.Graph

def constructor(self, attrs: dict[str, _attributes.Attr], inputs: BaseVars):
def constructor(
self, attrs: dict[str, _attributes.Attr], inputs: BaseVars
) -> BaseOutputs:
"""
Abstract method for functions.
Expand All @@ -64,7 +67,7 @@ def constructor(self, attrs: dict[str, _attributes.Attr], inputs: BaseVars):
f"Function {type(self).__name__} does not implement a constructor."
)

def infer_output_types(self, input_prop_values) -> dict[str, Type]:
def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:
from . import _graph

func_args_var = _graph.arguments_dict(
Expand Down Expand Up @@ -164,7 +167,9 @@ class Attributes(BaseAttributes):
Outputs = _FuncOutputs
op_type = OpType(name, domain, version)

def constructor(self, attrs: dict[str, _attributes.Attr], inputs: BaseVars):
def constructor(
self, attrs: dict[str, _attributes.Attr], inputs: BaseVars
) -> BaseOutputs:
return self.Outputs(*unwrap_vars(fun(*inputs.flatten_vars().values())))

return _Func
Expand Down
4 changes: 2 additions & 2 deletions src/spox/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def arguments_dict(**kwargs: Optional[Union[Type, np.ndarray]]) -> dict[str, Var
)
else:
raise TypeError(f"Cannot construct argument from {type(info)}.")
return result
return result # type: ignore


def arguments(**kwargs: Optional[Union[Type, np.ndarray]]) -> tuple[Var, ...]:
Expand Down Expand Up @@ -126,7 +126,7 @@ def initializer(arr: np.ndarray) -> Var:
BaseInputs(),
)
.get_output_vars()
.arg
.arg # type: ignore
)


Expand Down
2 changes: 1 addition & 1 deletion src/spox/_internal_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def intros(*args: Var) -> Sequence[Var]:
return (
_Introduce(None, _Introduce.Inputs(unwrap_vars(args)), out_variadic=len(args))
.get_output_vars()
.outputs
.outputs # type: ignore
)


Expand Down
17 changes: 12 additions & 5 deletions src/spox/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import typing
import warnings
from abc import ABC
from collections.abc import Iterable, Sequence
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass
from typing import Any, ClassVar, Optional, Union

Expand All @@ -20,7 +20,14 @@
from ._attributes import AttrGraph
from ._debug import STORE_TRACEBACK
from ._exceptions import InferenceWarning
from ._fields import BaseAttributes, BaseInputs, BaseOutputs, BaseVars, VarFieldKind
from ._fields import (
BaseAttributes,
BaseInputs,
BaseOutputs,
BaseVarInfos,
BaseVars,
VarFieldKind,
)
from ._type_system import Type
from ._value_prop import PropDict
from ._var import _VarInfo
Expand Down Expand Up @@ -183,7 +190,7 @@ def min_output(self) -> int:
def signature(self) -> str:
"""Get a signature of this Node, including its inputs and attributes (but not outputs)."""

def fmt_input(key, var):
def fmt_input(key: str, var: _VarInfo) -> str:
return f"{key}: {var.type}"

sign = ", ".join(
Expand Down Expand Up @@ -228,7 +235,7 @@ def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:

def inference(
self, input_prop_values: Optional[PropDict] = None, infer_types: bool = True
):
) -> None:
if input_prop_values is None:
input_prop_values = {}
# Type inference routine - call infer_output_types if required
Expand Down Expand Up @@ -300,7 +307,7 @@ def _check_concrete_type(self, value_type: Optional[Type]) -> Optional[str]:
return f"{type(e).__name__}: {str(e)}"
return None

def _list_types(self, source):
def _list_types(self, source: BaseVarInfos) -> Iterator[tuple[str, Optional[Type]]]:
return ((key, var.type) for key, var in source.get_var_infos().items())

def _init_output_vars(self) -> BaseOutputs:
Expand Down
2 changes: 1 addition & 1 deletion src/spox/_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def argument(typ: Type) -> Var:
_internal_op.Argument.Attributes(type=AttrType(typ, "dummy"), default=None)
)
.get_output_vars()
.arg
.arg # type: ignore
)


Expand Down
4 changes: 2 additions & 2 deletions src/spox/_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def to_singleton_onnx_model(
*,
dummy_outputs: bool = True,
with_dummy_subgraphs: bool = True,
input_prop_values: PropDict = {},
input_prop_values: PropDict,
) -> tuple[onnx.ModelProto, Scope]:
"""
Build a singleton model consisting of just this StandardNode. Used for type inference.
Expand Down Expand Up @@ -220,7 +220,7 @@ def propagate_values_onnx(
def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:
return self.infer_output_types_onnx(input_prop_values)

def propagate_values(self, input_prop_values) -> dict[str, PropValueType]:
def propagate_values(self, input_prop_values: PropDict) -> dict[str, PropValueType]:
if _value_prop._VALUE_PROP_BACKEND != _value_prop.ValuePropBackend.NONE:
return self.propagate_values_onnx(input_prop_values)
return {}
Expand Down
56 changes: 28 additions & 28 deletions src/spox/_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
self._op = op
self._name = None

def _rename(self, name: Optional[str]):
def _rename(self, name: Optional[str]) -> None:
"""Mutates the internal state of the VarInfo, overriding its name as given."""
self._name = name

Expand Down Expand Up @@ -107,7 +107,7 @@ def __copy__(self) -> _VarInfo:
# during the build process
return self

def __deepcopy__(self, _) -> _VarInfo:
def __deepcopy__(self, _: Any) -> _VarInfo:
raise ValueError("'VarInfo' objects cannot be deepcopied.")


Expand Down Expand Up @@ -208,84 +208,84 @@ def unwrap_optional(self) -> _type_system.Optional:
return self.unwrap_type().unwrap_optional()

@property
def _op(self):
def _op(self) -> Node:
return self._var_info._op

@property
def _name(self):
def _name(self) -> Optional[str]:
return self._var_info._name

def _rename(self, name: Optional[str]):
def _rename(self, name: Optional[str]) -> None:
self._var_info._rename(name)

@property
def _which_output(self):
def _which_output(self) -> Optional[str]:
return self._var_info._which_output

@property
def type(self):
def type(self) -> Optional[_type_system.Type]:
return self._var_info.type

def __copy__(self) -> Var:
# Simply return `self` to ensure that "copies" are still equal
# during the build process
return self

def __deepcopy__(self, _) -> Var:
def __deepcopy__(self, _: Any) -> Var:
raise ValueError("'Var' objects cannot be deepcopied.")

def __add__(self, other) -> Var:
def __add__(self, other: Var) -> Var:
return Var._operator_dispatcher.add(self, other)

def __sub__(self, other) -> Var:
def __sub__(self, other: Var) -> Var:
return Var._operator_dispatcher.sub(self, other)

def __mul__(self, other) -> Var:
def __mul__(self, other: Var) -> Var:
return Var._operator_dispatcher.mul(self, other)

def __truediv__(self, other) -> Var:
def __truediv__(self, other: Var) -> Var:
return Var._operator_dispatcher.truediv(self, other)

def __floordiv__(self, other) -> Var:
def __floordiv__(self, other: Var) -> Var:
return Var._operator_dispatcher.floordiv(self, other)

def __neg__(self) -> Var:
return Var._operator_dispatcher.neg(self)

def __and__(self, other) -> Var:
def __and__(self, other: Var) -> Var:
return Var._operator_dispatcher.and_(self, other)

def __or__(self, other) -> Var:
def __or__(self, other: Var) -> Var:
return Var._operator_dispatcher.or_(self, other)

def __xor__(self, other) -> Var:
def __xor__(self, other: Var) -> Var:
return Var._operator_dispatcher.xor(self, other)

def __invert__(self) -> Var:
return Var._operator_dispatcher.not_(self)

def __radd__(self, other) -> Var:
def __radd__(self, other: Var) -> Var:
return Var._operator_dispatcher.add(other, self)

def __rsub__(self, other) -> Var:
def __rsub__(self, other: Var) -> Var:
return Var._operator_dispatcher.sub(other, self)

def __rmul__(self, other) -> Var:
def __rmul__(self, other: Var) -> Var:
return Var._operator_dispatcher.mul(other, self)

def __rtruediv__(self, other) -> Var:
def __rtruediv__(self, other: Var) -> Var:
return Var._operator_dispatcher.truediv(other, self)

def __rfloordiv__(self, other) -> Var:
def __rfloordiv__(self, other: Var) -> Var:
return Var._operator_dispatcher.floordiv(other, self)

def __rand__(self, other) -> Var:
def __rand__(self, other: Var) -> Var:
return Var._operator_dispatcher.and_(other, self)

def __ror__(self, other) -> Var:
def __ror__(self, other: Var) -> Var:
return Var._operator_dispatcher.or_(other, self)

def __rxor__(self, other) -> Var:
def __rxor__(self, other: Var) -> Var:
return Var._operator_dispatcher.xor(other, self)


Expand All @@ -306,10 +306,10 @@ def wrap_vars(var_info: dict[T, _VarInfo]) -> dict[T, Var]: ... # type: ignore[


@overload
def wrap_vars(var_info: Union[Sequence[_VarInfo], Iterable[_VarInfo]]) -> list[Var]: ...
def wrap_vars(var_info: Iterable[_VarInfo]) -> list[Var]: ...


def wrap_vars(var_info):
def wrap_vars(var_info): # type: ignore
if var_info is None:
return None
elif isinstance(var_info, _VarInfo):
Expand All @@ -335,10 +335,10 @@ def unwrap_vars(var: dict[T, Var]) -> dict[T, _VarInfo]: ... # type: ignore[ove


@overload
def unwrap_vars(var: Union[Iterable[Var]]) -> list[_VarInfo]: ...
def unwrap_vars(var: Iterable[Var]) -> list[_VarInfo]: ...


def unwrap_vars(var):
def unwrap_vars(var): # type: ignore
if var is None:
return None
elif isinstance(var, Var):
Expand Down
Loading

0 comments on commit 756f274

Please sign in to comment.