Skip to content

Commit

Permalink
Resolve jit issues due to Python conditionals
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Nov 8, 2023
1 parent 08a5896 commit 3c8a9d5
Showing 1 changed file with 44 additions and 25 deletions.
69 changes: 44 additions & 25 deletions scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,7 @@ def prox(
classes.
"""
norm_v = norm(v)
if norm_v == 0:
return 0 * v
return snp.maximum(1 - lam / norm_v, 0) * v
return snp.where(norm_v == 0, 0 * v, snp.maximum(1 - lam / norm_v, 0) * v)


class L21Norm(Functional):
Expand Down Expand Up @@ -283,16 +281,48 @@ def __init__(self, beta: float = 1.0):
def __call__(self, x: Union[Array, BlockArray]) -> float:
    r"""Evaluate the functional :math:`\| x \|_1 - \beta \| x \|_2` at ``x``."""
    l1 = snp.sum(snp.abs(x))
    l2 = norm(x)
    return l1 - self.beta * l2

@staticmethod
def _prox_vamx_ge_thresh(v, va, vs, alpha, beta):
    """Prox branch whose result has a single non-zero entry at the argmax of ``|v|``.

    Selected by the caller when ``max(|v|) >= (1 - beta) * alpha``
    (and ``max(|v|) <= alpha``) — presumably per :cite:`lou-2018-fast`;
    verify against the cited reference.
    """
    flat_va = va.ravel()
    idx = flat_va.argmax()
    # Value of the single surviving component, carrying the sign/phase of v.
    val = (flat_va[idx] + (beta - 1.0) * alpha) * vs.ravel()[idx]
    out = snp.zeros(v.shape, dtype=v.dtype).ravel().at[idx].set(val)
    return out.reshape(v.shape)

@staticmethod
def _prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta):
    """Prox branch for ``max(|v|) <= alpha``: all-zero or single-entry result.

    Uses :func:`snp.where` instead of a Python ``if`` so both outcomes
    are traceable under ``jax.jit``.
    """
    zeros = snp.zeros(v.shape, dtype=v.dtype)
    single = L1MinusL2Norm._prox_vamx_ge_thresh(v, va, vs, alpha, beta)
    return snp.where(vamx < (1.0 - beta) * alpha, zeros, single)

@staticmethod
def _prox_vamx_gt_alpha(v, va, vs, alpha, beta):
    """Prox branch for ``max(|v|) > alpha``: soft-threshold then rescale.

    NOTE(review): if every ``|v| <= alpha`` the soft-thresholded array is
    zero and ``scale`` divides by zero; the caller only selects this branch
    when ``max(|v|) > alpha``, which guarantees a non-zero norm, but under
    ``snp.where`` both branches are evaluated — confirm NaNs in the
    unselected branch are acceptable (e.g. for gradients).
    """
    soft = snp.maximum(va - alpha, 0.0) * vs
    l2u = norm(soft)
    scale = (l2u + alpha * beta) / l2u
    return soft * scale

@staticmethod
def _prox_vamx_gt_0(v, va, vs, vamx, alpha, beta):
    """Dispatch the prox for ``max(|v|) > 0`` on whether it exceeds ``alpha``.

    Branch selection is done with :func:`snp.where` (not a Python ``if``)
    so the computation remains traceable under ``jax.jit``.
    """
    above = L1MinusL2Norm._prox_vamx_gt_alpha(v, va, vs, alpha, beta)
    below = L1MinusL2Norm._prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta)
    return snp.where(vamx > alpha, above, below)

def prox(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
r"""Proximal operator of difference of :math:`\ell_1` and :math:`\ell_2` norms
r"""Proximal operator of difference of :math:`\ell_1` and :math:`\ell_2` norms.
Evaluate the proximal operator of the difference of :math:`\ell_1`
and :math:`\ell_2` norms, i.e. :math:`\alpha \left( \| \mb{x} \|_1 -
\beta \| \mb{x} \|_2 \right)` :cite:`lou-2018-fast`. Note that this
is not a proximal operator according to the strict definition since
the loss function is non-convex.
and :math:`\ell_2` norms, i.e. :math:`\alpha \left( \| \mb{x}
\|_1 - \beta \| \mb{x} \|_2 \right)` :cite:`lou-2018-fast`. Note
that this is not a proximal operator according to the strict
definition since the loss function is non-convex.
Args:
v: Input array :math:`\mb{v}`.
Expand All @@ -308,23 +338,12 @@ def prox(
vs = snp.exp(1j * snp.angle(v))
else:
vs = snp.sign(v)
if vamx > 0.0:
if vamx > alpha:
u = snp.maximum(va - alpha, 0.0) * vs
l2u = norm(u)
u *= (l2u + alpha * beta) / l2u
else:
u = snp.zeros(v.shape, dtype=v.dtype)
if vamx >= (1.0 - beta) * alpha:
idx = va.ravel().argmax()
u = (
u.ravel()
.at[idx]
.set((va.ravel()[idx] + (beta - 1.0) * alpha) * vs.ravel()[idx])
).reshape(v.shape)
else:
u = snp.zeros(v.shape, dtype=v.dtype)
return u

return snp.where(
vamx > 0.0,
L1MinusL2Norm._prox_vamx_gt_0(v, va, vs, vamx, alpha, beta),
snp.zeros(v.shape, dtype=v.dtype),
)


class HuberNorm(Functional):
Expand Down

0 comments on commit 3c8a9d5

Please sign in to comment.