diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 332d0500f..ba8ec407f 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -180,9 +180,7 @@ def prox( classes. """ norm_v = norm(v) - if norm_v == 0: - return 0 * v - return snp.maximum(1 - lam / norm_v, 0) * v + return snp.where(norm_v == 0, 0 * v, snp.maximum(1 - lam / norm_v, 0) * v) class L21Norm(Functional): @@ -283,16 +281,48 @@ def __init__(self, beta: float = 1.0): def __call__(self, x: Union[Array, BlockArray]) -> float: return snp.sum(snp.abs(x)) - self.beta * norm(x) + @staticmethod + def _prox_vamx_ge_thresh(v, va, vs, alpha, beta): + u = snp.zeros(v.shape, dtype=v.dtype) + idx = va.ravel().argmax() + u = ( + u.ravel().at[idx].set((va.ravel()[idx] + (beta - 1.0) * alpha) * vs.ravel()[idx]) + ).reshape(v.shape) + return u + + @staticmethod + def _prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta): + return snp.where( + vamx < (1.0 - beta) * alpha, + snp.zeros(v.shape, dtype=v.dtype), + L1MinusL2Norm._prox_vamx_ge_thresh(v, va, vs, alpha, beta), + ) + + @staticmethod + def _prox_vamx_gt_alpha(v, va, vs, alpha, beta): + u = snp.maximum(va - alpha, 0.0) * vs + l2u = norm(u) + u *= (l2u + alpha * beta) / l2u + return u + + @staticmethod + def _prox_vamx_gt_0(v, va, vs, vamx, alpha, beta): + return snp.where( + vamx > alpha, + L1MinusL2Norm._prox_vamx_gt_alpha(v, va, vs, alpha, beta), + L1MinusL2Norm._prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta), + ) + def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: - r"""Proximal operator of difference of :math:`\ell_1` and :math:`\ell_2` norms + r"""Proximal operator of difference of :math:`\ell_1` and :math:`\ell_2` norms. Evaluate the proximal operator of the difference of :math:`\ell_1` - and :math:`\ell_2` norms, i.e. :math:`\alpha \left( \| \mb{x} \|_1 - - \beta \| \mb{x} \|_2 \right)` :cite:`lou-2018-fast`. Note that this - is not a proximal operator according to the strict definition since - the loss function is non-convex. + and :math:`\ell_2` norms, i.e. :math:`\alpha \left( \| \mb{x} + \|_1 - \beta \| \mb{x} \|_2 \right)` :cite:`lou-2018-fast`. Note + that this is not a proximal operator according to the strict + definition since the loss function is non-convex. Args: v: Input array :math:`\mb{v}`. @@ -308,23 +338,12 @@ def prox( vs = snp.exp(1j * snp.angle(v)) else: vs = snp.sign(v) - if vamx > 0.0: - if vamx > alpha: - u = snp.maximum(va - alpha, 0.0) * vs - l2u = norm(u) - u *= (l2u + alpha * beta) / l2u - else: - u = snp.zeros(v.shape, dtype=v.dtype) - if vamx >= (1.0 - beta) * alpha: - idx = va.ravel().argmax() - u = ( - u.ravel() - .at[idx] - .set((va.ravel()[idx] + (beta - 1.0) * alpha) * vs.ravel()[idx]) - ).reshape(v.shape) - else: - u = snp.zeros(v.shape, dtype=v.dtype) - return u + + return snp.where( + vamx > 0.0, + L1MinusL2Norm._prox_vamx_gt_0(v, va, vs, vamx, alpha, beta), + snp.zeros(v.shape, dtype=v.dtype), + ) class HuberNorm(Functional): diff --git a/scico/test/functional/test_misc.py b/scico/test/functional/test_misc.py index 2905facd1..1edcd0174 100644 --- a/scico/test/functional/test_misc.py +++ b/scico/test/functional/test_misc.py @@ -13,7 +13,11 @@ class TestCheckAttrs: # and set to True/False in the Functional subclasses. # Generate a list of all functionals in scico.functionals that we will check - ignore = [functional.Functional, functional.ScaledFunctional, functional.SeparableFunctional] + ignore = [ + functional.Functional, + functional.ScaledFunctional, + functional.SeparableFunctional, + ] to_check = [] for name, cls in functional.__dict__.items(): if isinstance(cls, type): @@ -30,6 +34,44 @@ def test_has_prox(self, cls): assert isinstance(cls.has_prox, bool) +class TestJit: + # Test whether functionals can be jitted. + + # Generate a list of all functionals in scico.functionals that we will check + ignore = [ + functional.Functional, + functional.ScaledFunctional, + functional.SeparableFunctional, + functional.BM3D, + functional.BM4D, + ] + to_check = [] + for name, cls in functional.__dict__.items(): + if isinstance(cls, type): + if issubclass(cls, functional.Functional): + if cls not in ignore: + to_check.append(cls) + + @pytest.mark.parametrize("cls", to_check) + def test_jit(self, cls): + # Only test functionals that have no required __init__ parameters. + try: + f = cls() + except TypeError: + pass + else: + v = snp.arange(4.0) + # Only test functionals that can take 1D input. + try: + u0 = f.prox(v) + except ValueError: + pass + else: + fprox = jax.jit(f.prox) + u1 = fprox(v) + assert np.allclose(u0, u1) + + def test_functional_sum(): x = np.random.randn(4, 4) f0 = functional.L1Norm()