From c8efe90beeb7e73311190c27c8fc9eb746d8d506 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 1 Nov 2023 21:47:43 -0600 Subject: [PATCH 1/8] Implementation supporting arbitrary dimensional inputs --- scico/functional/__init__.py | 4 +- scico/functional/_norm.py | 107 ----------------------------------- 2 files changed, 2 insertions(+), 109 deletions(-) diff --git a/scico/functional/__init__.py b/scico/functional/__init__.py index 377e29b1a..48509cd40 100644 --- a/scico/functional/__init__.py +++ b/scico/functional/__init__.py @@ -20,14 +20,15 @@ L21Norm, NuclearNorm, L1MinusL2Norm, - TV2DNorm, ) +from ._tvnorm import AnisotropicTVNorm from ._indicator import NonNegativeIndicator, L2BallIndicator from ._denoiser import BM3D, BM4D, DnCNN from ._dist import SetDistance, SquaredSetDistance __all__ = [ + "AnisotropicTVNorm", "Functional", "ScaledFunctional", "SeparableFunctional", @@ -47,7 +48,6 @@ "BM3D", "BM4D", "DnCNN", - "TV2DNorm", ] # Imported items in __all__ appear to originate in top-level functional module diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 933062bb0..5d0a6579e 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -479,110 +479,3 @@ def prox( svdU, svdS, svdV = snp.linalg.svd(v, full_matrices=False) svdS = snp.maximum(0, svdS - lam) return svdU @ snp.diag(svdS) @ svdV - - -class TV2DNorm(Functional): - r"""The anisotropic total variation (TV) norm. - - For a :math:`M \times N` matrix, :math:`\mb{A}`, by default, - - .. math:: - \norm{\mb{A}}_{\text{TV}} = \sum_{n=1}^N \sum_{m=1}^M - \abs{\nabla{A}_{m,n}} \;. - - The proximal operator of this norm is currently only defined for 2 - dimensional data. - - For `BlockArray` inputs, the TV norm follows the reduction rules - described in :class:`BlockArray`. - """ - - has_eval = True - has_prox = True - - def __init__(self, tau: float = 1.0): - r""" - Args: - tau: Parameter :math:`\tau` in the norm definition. - """ - self.tau = tau - - def __call__(self, x: Union[Array, BlockArray]) -> float: - r"""Return the TV norm of an array.""" - y = 0 - gradOp = FiniteDifference(x.shape, input_dtype=x.dtype, circular=True) - grads = gradOp @ x - for g in grads: - y += snp.abs(g) - return self.tau * snp.sum(y) - - def prox( - self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs - ) -> Union[Array, BlockArray]: - r"""Proximal operator of the TV norm. - - Approximate the proximal operator of the anisotropic TV norm via - the method described in :cite:`kamilov-2016-parallel`. - - Args: - v: Input array :math:`\mb{v}`. - lam: Proximal parameter :math:`\lam`. - kwargs: Additional arguments that may be used by derived - classes. - """ - D = 2 - K = 2 * D - thresh = snp.sqrt(2) * K * self.tau * lam - - y = snp.zeros_like(v) - for ax in range(2): - y = y.at[:].add( - self.iht2( - self.ht2_shrink(v, axis=ax, shift=False, thresh=thresh), axis=ax, shift=False - ) - ) - y = y.at[:].add( - self.iht2( - self.ht2_shrink(v, axis=ax, shift=True, thresh=thresh), axis=ax, shift=True - ) - ) - y = y.at[:].divide(K) - return y - - def ht2_shrink(self, x, axis, shift, thresh): - r"""Forward Discrete Haar Wavelet transform in 2D""" - w = snp.zeros_like(x) - C = 1 / snp.sqrt(2) - if shift: - x = snp.roll(x, -1, axis=axis) - - m = x.shape[axis] // 2 - if not axis: - w = w.at[:m, :].set(C * (x[1::2, :] + x[::2, :])) - w = w.at[m:, :].set(self.shrink(C * (x[1::2, :] - x[::2, :]), thresh)) - else: - w = w.at[:, :m].set(C * (x[:, 1::2] + x[:, ::2])) - w = w.at[:, m:].set(self.shrink(C * (x[:, 1::2] - x[:, ::2]), thresh)) - return w - - def iht2(self, w, axis, shift): - r"""Inverse Discrete Haar Wavelet transform in 2D""" - y = snp.zeros_like(w) - C = 1 / snp.sqrt(2) - m = w.shape[axis] // 2 - if not axis: - y = y.at[::2, :].set(C * (w[:m, :] - w[m:, :])) - y = y.at[1::2, :].set(C * (w[:m, :] + w[m:, :])) - else: - y = y.at[:, ::2].set(C * (w[:, :m] - w[:, m:])) - y = y.at[:, 1::2].set(C * (w[:, :m] + w[:, m:])) - - if shift: - y = snp.roll(y, 1, axis) - return y - - def shrink(self, x, tau): - r"""Wavelet shrinkage operator""" - threshed = snp.maximum(snp.abs(x) - tau, 0) - threshed = threshed.at[:].multiply(snp.sign(x)) - return threshed From ec8686ef7b8413127ca1e36efc8d48ef10f22fcd Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 2 Nov 2023 21:39:03 -0600 Subject: [PATCH 2/8] Add a test --- scico/test/functional/test_tvnorm.py | 43 ++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 scico/test/functional/test_tvnorm.py diff --git a/scico/test/functional/test_tvnorm.py b/scico/test/functional/test_tvnorm.py new file mode 100644 index 000000000..740601353 --- /dev/null +++ b/scico/test/functional/test_tvnorm.py @@ -0,0 +1,43 @@ +import numpy as np + +import scico.random +from scico import functional, linop, loss, metric +from scico.optimize.admm import ADMM, LinearSubproblemSolver +from scico.optimize.pgm import AcceleratedPGM + + +def test_tvnorm(): + + N = 128 + g = np.linspace(0, 2 * np.pi, N) + x_gt = np.sin(2 * g) + x_gt[x_gt > 0.5] = 0.5 + x_gt[x_gt < -0.5] = -0.5 + σ = 0.02 + noise, key = scico.random.randn(x_gt.shape, seed=0) + y = x_gt + σ * noise + + λ = 5e-2 + + f = loss.SquaredL2Loss(y=y) + g = λ * functional.L1Norm() + C = linop.FiniteDifference(input_shape=x_gt.shape, circular=True) + solver = ADMM( + f=f, + g_list=[g], + C_list=[C], + rho_list=[1e1], + x0=y, + maxiter=50, + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}), + itstat_options={"display": True, "period": 10}, + ) + x_tvdn = solver.solve() + + h = λ * functional.AnisotropicTVNorm() + solver = AcceleratedPGM( + f=f, g=h, L0=2e2, x0=y, maxiter=50, itstat_options={"display": True, "period": 10} + ) + x_approx = solver.solve() + + assert metric.snr(x_tvdn, x_approx) > 45 From 4f2f189bcedd664b7e4569ca447d119d006e6169 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 3 Nov 2023 10:54:50 -0600 Subject: [PATCH 3/8] Minor changes --- scico/test/functional/test_tvnorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/test/functional/test_tvnorm.py b/scico/test/functional/test_tvnorm.py index 740601353..232665aca 100644 --- a/scico/test/functional/test_tvnorm.py +++ b/scico/test/functional/test_tvnorm.py @@ -18,10 +18,10 @@ def test_tvnorm(): y = x_gt + σ * noise λ = 5e-2 - f = loss.SquaredL2Loss(y=y) - g = λ * functional.L1Norm() + C = linop.FiniteDifference(input_shape=x_gt.shape, circular=True) + g = λ * functional.L1Norm() solver = ADMM( f=f, g_list=[g], From b7427f7fa5cde3b68dc67b1f30bd657715295b7d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 3 Nov 2023 11:47:39 -0600 Subject: [PATCH 4/8] New implementation of TV norm and approximage prox --- scico/functional/_tvnorm.py | 138 ++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 scico/functional/_tvnorm.py diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py new file mode 100644 index 000000000..ed2a2992f --- /dev/null +++ b/scico/functional/_tvnorm.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2023 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Anisotropic total variation norm.""" + +import warnings +from typing import Optional, Tuple, Union + +from jax import jit, lax + +from scico import numpy as snp +from scico.linop import FiniteDifference, VerticalStack, CircularConvolve +from scico.numpy import Array, BlockArray, count_nonzero +from scico.numpy.linalg import norm +from scico.numpy.util import no_nan_divide + +from ._functional import Functional +from ._norm import L1Norm + + +class AnisotropicTVNorm(Functional): + r"""The anisotropic total variation (TV) norm. + + The anisotropic total variation (TV) norm computed by + + .. code-block:: python + + ATV = scico.functional.AnisotropicTVNorm() + x_norm = ATV(x) + + is equivalent to + + .. code-block:: python + + C = linop.FiniteDifference(input_shape=x.shape, circular=True) + L1 = functional.L1Norm() + x_norm = L1(C @ x) + + The scaled proximal operator is computed using an approximation that + holds for small scaling parameters :cite:`kamilov-2016-parallel`. + This does not imply that it can only be applied to problems requiring + a small regularization parameter since most proximal algorithms + include an additional algorithm parameter that also plays a role in + the parameter of the proximal operator. For example, in :class:`.PGM` + and :class:`.AcceleratedPGM`, the scaled proximal operator parameter + is the regularization parameter divided by the `L0` algorithm + parameter, and for :class:`.ADMM`, the scaled proximal operator + parameters are the regularization parameters divided by the entries + in the `rho_list` algorithm parameter. + """ + + has_eval = True + has_prox = True + + def __init__(self, ndims: Optional[int] = None): + r""" + Args: + ndims: Number of (trailing) dimensions of the input over + which to apply the finite difference operator. If + ``None``, differences are evaluated along all axes. + """ + self.ndims = ndims + self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter + self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter + self.l1norm = L1Norm() + self.G = None + self.W = None + + def __call__(self, x: Array) -> float: + r"""Compute the anisotropic TV norm of an array.""" + if self.G is None or self.G.shape[1] != x.shape: + if self.ndims is None: + ndims = x.ndim + else: + ndims = self.ndims + axes = tuple(range(ndims)) + self.G = FiniteDifference(x.shape, input_dtype=x.dtype, axes=axes, circular=True) + return snp.sum(snp.abs(self.G @ x)) + + @staticmethod + def _shape(idx: int, ndims: int) -> Tuple: + """Construct a shape tuple. + + Construct a tuple of size `ndims` with all unit entries except + for index `idx`, which has a -1 entry. + """ + return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1) + + def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: + r"""Approximate proximal operator of the anisotripic TV norm. + + Approximation of the proximal operator of the anisotropic TV norm, + computed via the method described in :cite:`kamilov-2016-parallel`. + + Args: + v: Input array :math:`\mb{v}`. + lam: Proximal parameter :math:`\lam`. + kwargs: Additional arguments that may be used by derived + classes. + """ + if self.ndims is None: + ndims = v.ndim + else: + ndims = self.ndims + K = 2 * ndims + + if self.W is None or self.W.shape[1] != v.shape: + C0 = VerticalStack( # Stack of lowpass filter operators for each axis + [ + CircularConvolve( + self.h0.reshape(AnisotropicTVNorm._shape(k, ndims)), + v.shape, + ndims=self.ndims, + ) + for k in range(ndims) + ] + ) + C1 = VerticalStack( # Stack of highpass filter operators for each axis + [ + CircularConvolve( + self.h1.reshape(AnisotropicTVNorm._shape(k, ndims)), + v.shape, + ndims=self.ndims, + ) + for k in range(ndims) + ] + ) + # single-level shift-invariant Haar transform + self.W = VerticalStack((C0, C1), jit=True) + + Wv = self.W @ v + # Apply 𝑙1 shrinkage to highpass component of shift-invariant Haar transform + Wv = Wv.at[1].set(self.l1norm.prox(Wv[1], snp.sqrt(2) * K * lam)) + return (1.0 / K) * self.W.T @ Wv From c0c96337fcc529a26aef20655587cb6f9a72bd5b Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 3 Nov 2023 11:51:20 -0600 Subject: [PATCH 5/8] Clean up --- scico/functional/_tvnorm.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index ed2a2992f..2f1f1088d 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -7,16 +7,11 @@ """Anisotropic total variation norm.""" -import warnings -from typing import Optional, Tuple, Union - -from jax import jit, lax +from typing import Optional, Tuple from scico import numpy as snp -from scico.linop import FiniteDifference, VerticalStack, CircularConvolve -from scico.numpy import Array, BlockArray, count_nonzero -from scico.numpy.linalg import norm -from scico.numpy.util import no_nan_divide +from scico.linop import CircularConvolve, FiniteDifference, VerticalStack +from scico.numpy import Array from ._functional import Functional from ._norm import L1Norm @@ -64,7 +59,7 @@ def __init__(self, ndims: Optional[int] = None): ``None``, differences are evaluated along all axes. """ self.ndims = ndims - self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter + self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter self.l1norm = L1Norm() self.G = None From f251c60dfa3f0dc1f17f4d14c71e0c8075791072 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 3 Nov 2023 11:52:07 -0600 Subject: [PATCH 6/8] Typo fix --- scico/functional/_tvnorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index 2f1f1088d..a1a7a0ff7 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -86,7 +86,7 @@ def _shape(idx: int, ndims: int) -> Tuple: return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1) def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: - r"""Approximate proximal operator of the anisotripic TV norm. + r"""Approximate proximal operator of the isotropic TV norm. Approximation of the proximal operator of the anisotropic TV norm, computed via the method described in :cite:`kamilov-2016-parallel`. From feb4b7765426860f7ac3420830007bd0aed17b09 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 4 Nov 2023 07:14:36 -0600 Subject: [PATCH 7/8] Minor change --- scico/functional/_tvnorm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index a1a7a0ff7..f0e726cb7 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -73,7 +73,9 @@ def __call__(self, x: Array) -> float: else: ndims = self.ndims axes = tuple(range(ndims)) - self.G = FiniteDifference(x.shape, input_dtype=x.dtype, axes=axes, circular=True) + self.G = FiniteDifference( + x.shape, input_dtype=x.dtype, axes=axes, circular=True, jit=True + ) return snp.sum(snp.abs(self.G @ x)) @staticmethod From 7fe98b9fe2c140febb6bcf539e9be11f9c87ebe0 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 4 Nov 2023 08:02:58 -0600 Subject: [PATCH 8/8] Add change log entry --- CHANGES.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index 7dc048a0a..265b9dd38 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,6 +6,8 @@ SCICO Release Notes Version 0.0.5 (unreleased) ---------------------------- +• New functional ``functional.AnisotropicTVNorm`` with proximal operator + approximation. • New integrated Radon/X-ray transform ``linop.XRayTransform``. • Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and ``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes