Skip to content

Commit

Permalink
Partially fix backend vstack/hstack
Browse files Browse the repository at this point in the history
  • Loading branch information
pablormier committed May 21, 2024
1 parent ae0f2cc commit 5d4e58c
Show file tree
Hide file tree
Showing 8 changed files with 351 additions and 131 deletions.
56 changes: 4 additions & 52 deletions corneto/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,68 +28,20 @@ def _wrapped_func(*args, **kwargs):


def _delegate(func):
"""A decorator that wraps a function to provide extended functionality
when applied within a class. This decorator modifies the behavior
of the function `func` to handle expression objects and delegate
calls to their underlying representations, while maintaining a set of
symbols associated with the expression objects.
The primary use of this decorator is to allow mathematical and
operational transformations on proxy objects (like expressions in a
symbolic or algebraic framework) that abstract underlying complex
behaviors (like algebraic expressions handled by a computational backend
such as PICOS or CVXPY).
Parameters:
func (Callable): The function to be wrapped. This function should be
a method of a class that handles expressions. It is
expected to operate on instances of the class and
potentially other similar objects.
Returns:
Callable: A wrapper function `_wrapper_func` that takes the same arguments as `func`.
This function intercepts calls to `func`, updates and manages symbols,
and delegates operations to the underlying computational backend if possible.
Decorators:
@wraps(func): This decorator is used to preserve the name, docstring, and other
attributes of the original function `func`.
Usage:
To use this decorator, apply it to methods in a class that represents expressions,
where such methods need to interact with the underlying computational or symbolic
representation of those expressions. The decorator handles conversion and delegation
logic, facilitating the interaction with more complex backends transparently.
Example:
```python
class Expression:
def _create(self, expr, symbols):
# Implementation details...
pass
@_delegate
def __add__(self, other):
# Additional functionality can be inserted here.
pass
```
"""

@wraps(func)
def _wrapper_func(self, *args, **kwargs):
symbols = set()
# if hasattr(self, '_proxy_symbols'):
# symbols.update(self._proxy_symbols)
# if getattr(self, 'is_symbol', lambda: False)():
# symbols.add(self)
if len(args) > 0:
# Function is providing 'other' expression
if hasattr(args[0], "_expr"):
args = list(args)
symbols.update(args[0]._proxy_symbols)
if getattr(args[0], "is_symbol", lambda: False)():
symbols.add(args[0])
args[0] = args[0]._expr # symbol is lost
# Extract the original backend symbol
args[0] = args[0]._expr
# Attach the list of original symbols to the backend expression
setattr(args[0], "_proxy_symbols", symbols)
if hasattr(self._expr, func.__name__):
# Check if its callable
f = getattr(self._expr, func.__name__)
Expand Down
71 changes: 67 additions & 4 deletions corneto/backend/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,59 @@ def _vstack(self, other: "CExpression") -> Any:
def vstack(self, other: "CExpression") -> "CExpression":
return self._vstack(other)

@abc.abstractmethod
def _reshape(self, shape: Tuple[int, ...]) -> "CExpression":
pass

@_delegate
def reshape(self, shape: Union[int, Tuple[int, ...]]) -> "CExpression":
this_shape = self.shape
num_elements = 1
for dim in this_shape:
num_elements *= dim

# Convert single int shape to tuple
if isinstance(shape, int):
shape = (shape,)

# Validate the input shape
if shape.count(-1) > 1:
raise ValueError("Only one dimension can be -1")
if any(dim < -1 for dim in shape):
raise ValueError("Invalid shape: dimensions must be positive or -1")

# Handle the case where shape is (-1,) or -1 to flatten the array
if shape == (-1,):
return self._reshape((num_elements,))

# General case: if -1 is present, calculate the corresponding dimension
if -1 in shape:
new_shape = []
unknown_index = shape.index(-1)
known_size = 1

for i, dim in enumerate(shape):
if i != unknown_index:
known_size *= dim
new_shape.append(dim)

# Check that total elements match
if num_elements % known_size != 0:
raise ValueError("The total size of the new array must be unchanged")

new_shape[unknown_index] = num_elements // known_size
return self._reshape(tuple(new_shape))

# Check total size is ok
new_num_elements = 1
for dim in shape:
new_num_elements *= dim

if new_num_elements != num_elements:
raise ValueError("The total size of the new array must be unchanged")

return self._reshape(shape)

@abc.abstractmethod
def _norm(self, p: int = 2) -> Any:
pass
Expand All @@ -149,6 +202,11 @@ def _max(self, axis: Optional[int] = None) -> Any:
def max(self, axis: Optional[int] = None) -> "CExpression":
return self._max(axis=axis)

# These delegated methods are invoked directly in the backend
# and wrapped thanks to the _delegate decorator. If a new
# backend has a different behavior, provide an abstract method
# as in the previous cases.

@_delegate
def __getitem__(self, item) -> "CExpression": # type: ignore
pass
Expand Down Expand Up @@ -283,6 +341,7 @@ def __init__(
ub_r: Optional[np.ndarray] = None
self._provided_lb = lb
self._provided_ub = ub
setattr(expr, "_csymbol_shape", shape)

if shape is None:
shape = () # type: ignore
Expand Down Expand Up @@ -1214,7 +1273,9 @@ def Xor(self, x: CExpression, y: CExpression, varname="_xor"):
[xor >= x - y, xor >= y - x, xor <= x + y, xor <= 2 - x - y]
)

def linear_or(self, x: CExpression, axis: Optional[int] = None, varname="or"):
def linear_or(
self, x: CExpression, axis: Optional[int] = None, varname="or"
) -> ProblemDef:
# Check if the variable has a vartype and is binary
if hasattr(x, "_vartype") and x._vartype != VarType.BINARY:
raise ValueError(f"Variable x has type {x._vartype} instead of BINARY")
Expand All @@ -1233,7 +1294,9 @@ def linear_or(self, x: CExpression, axis: Optional[int] = None, varname="or"):
Or = self.Variable(varname, Z.shape, 0, 1, vartype=VarType.BINARY)
return self.Problem([Or >= Z_norm, Or <= Z])

def linear_and(self, x: CExpression, axis: Optional[int] = None, varname="and"):
def linear_and(
self, x: CExpression, axis: Optional[int] = None, varname="and"
) -> ProblemDef:
# Check if the variable is binary, otherwise throw an error
if hasattr(x, "_vartype") and x._vartype != VarType.BINARY:
raise ValueError(f"Variable x has type {x._vartype} instead of BINARY")
Expand All @@ -1251,7 +1314,7 @@ def linear_and(self, x: CExpression, axis: Optional[int] = None, varname="and"):
And = self.Variable(varname, Z.shape, 0, 1, vartype=VarType.BINARY)
return self.Problem([And <= Z_norm, And >= Z - N + 1])

def vstack(self, arg_list: Iterable[CSymbol]):
def vstack(self, arg_list: Iterable[CExpression]) -> CExpression:
v = None
for a in arg_list:
if v is None:
Expand All @@ -1260,7 +1323,7 @@ def vstack(self, arg_list: Iterable[CSymbol]):
v = v.vstack(a)
return v

def hstack(self, arg_list: Iterable[CSymbol]):
def hstack(self, arg_list: Iterable[CExpression]) -> CExpression:
h = None
for a in arg_list:
if h is None:
Expand Down
27 changes: 22 additions & 5 deletions corneto/backend/_cvxpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,28 @@ def _sum(self, axis: Optional[int] = None) -> Any:
def _max(self, axis: Optional[int] = None) -> Any:
return cp.max(self._expr, axis=axis)

def _hstack(self, other: CExpression) -> Any:
return cp.hstack([self._expr, other])

def _vstack(self, other: CExpression) -> Any:
return cp.vstack([self._expr, other])
def _hstack(self, other: Any) -> Any:
a = self._expr
b = other
# If vector, for hstack assume is a column vector
# if len(a.shape) == 1:
# a = cp.reshape(a, (a.shape[0], 1))
# if len(b.shape) == 1:
# b = cp.reshape(b, (b.shape[0], 1))
return cp.hstack([a, b])

def _vstack(self, other: Any) -> Any:
a = self._expr
b = other
# If vector, for vstack assume is a row vector
if len(a.shape) == 1:
a = cp.reshape(a, (1, a.shape[0]))
if len(b.shape) == 1:
b = cp.reshape(b, (1, b.shape[0]))
return cp.vstack([a, b])

def _reshape(self, shape: Tuple[int, ...]) -> Any:
return cp.reshape(self._expr, shape)

@property
def value(self) -> np.ndarray:
Expand Down
62 changes: 57 additions & 5 deletions corneto/backend/_picos_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,35 @@
pc = None


def _get_shape(a: Any):
if hasattr(a, "_csymbol_shape"):
return a._csymbol_shape
if hasattr(a, "_expr"):
return (
a._expr._csymbol_shape
if hasattr(a._expr, "_csymbol_shape")
else a._expr.shape
)
if hasattr(a, "shape"):
return a.shape
return ()


def _infer_shape(c: Any):
shape = _get_shape(c)
if len(shape) == 2 and (shape[0] == 1 or shape[1] == 1):
# This is the problematic case, as any transformation
# using PICOS will result in a 2D array.
if hasattr(c, "_proxy_symbols"):
for s in c._proxy_symbols:
# If there is some original symbol used in the expression
# with ndim 1, we assume the original shape was 1D.
s_shape = _get_shape(s)
if len(s_shape) == 1:
return (shape[0],)
return shape


class PicosExpression(CExpression):
def __init__(self, expr: Any, symbols: Optional[Set["CSymbol"]] = None) -> None:
super().__init__(expr, symbols)
Expand All @@ -40,11 +69,34 @@ def _sum(self, axis: Optional[int] = None) -> Any:
def _max(self, axis: Optional[int] = None) -> Any:
raise NotImplementedError()

def _hstack(self, other: CExpression) -> Any:
return self._expr & other

def _vstack(self, other: CExpression) -> Any:
return self._expr // other
def _hstack(self, other: Any) -> Any:
a = self._expr
b = other
# PICOS assumes vectors are column vectors.
# For hstack, if 1 dim, we assume is a row vector.
a_shape = _infer_shape(self)
b_shape = _infer_shape(other)
if len(a_shape) == 1 and a.shape[1] == 1:
a = a.T
if len(b_shape) == 1 and b.shape[1] == 1:
b = b.T
return a & b

def _vstack(self, other: Any) -> Any:
a = self._expr
b = other
# PICOS assumes vectors are column vectors.
# We need to keep track of the original dim.
a_shape = _infer_shape(self)
b_shape = _infer_shape(other)
if len(a_shape) == 1:
a = a.reshaped((1, a_shape[0]))
if len(b_shape) == 1:
b = b.reshaped((1, b_shape[0]))
return a // b

def _reshape(self, shape: Tuple[int, ...]) -> Any:
return self._expr.reshaped(shape)

@property
def value(self) -> np.ndarray:
Expand Down
Loading

0 comments on commit 5d4e58c

Please sign in to comment.