Skip to content

Commit

Permalink
Merge pull request #2 from shnaqvi/tv-norm-alt-ver
Browse files Browse the repository at this point in the history
Alternative implementation of TV norm
  • Loading branch information
shnaqvi authored Nov 5, 2023
2 parents b5e8fc9 + 7fe98b9 commit 2963523
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 109 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ SCICO Release Notes
Version 0.0.5 (unreleased)
----------------------------

• New functional ``functional.AnisotropicTVNorm`` with proximal operator
approximation.
• New integrated Radon/X-ray transform ``linop.XRayTransform``.
• Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and
``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes
Expand Down
4 changes: 2 additions & 2 deletions scico/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
L21Norm,
NuclearNorm,
L1MinusL2Norm,
TV2DNorm,
)
from ._tvnorm import AnisotropicTVNorm
from ._indicator import NonNegativeIndicator, L2BallIndicator
from ._denoiser import BM3D, BM4D, DnCNN
from ._dist import SetDistance, SquaredSetDistance


__all__ = [
"AnisotropicTVNorm",
"Functional",
"ScaledFunctional",
"SeparableFunctional",
Expand All @@ -47,7 +48,6 @@
"BM3D",
"BM4D",
"DnCNN",
"TV2DNorm",
]

# Imported items in __all__ appear to originate in top-level functional module
Expand Down
107 changes: 0 additions & 107 deletions scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,110 +479,3 @@ def prox(
svdU, svdS, svdV = snp.linalg.svd(v, full_matrices=False)
svdS = snp.maximum(0, svdS - lam)
return svdU @ snp.diag(svdS) @ svdV


class TV2DNorm(Functional):
r"""The anisotropic total variation (TV) norm.
For a :math:`M \times N` matrix, :math:`\mb{A}`, by default,
.. math::
\norm{\mb{A}}_{\text{TV}} = \sum_{n=1}^N \sum_{m=1}^M
\abs{\nabla{A}_{m,n}} \;.
The proximal operator of this norm is currently only defined for 2
dimensional data.
For `BlockArray` inputs, the TV norm follows the reduction rules
described in :class:`BlockArray`.
"""

has_eval = True
has_prox = True

def __init__(self, tau: float = 1.0):
r"""
Args:
tau: Parameter :math:`\tau` in the norm definition.
"""
self.tau = tau

def __call__(self, x: Union[Array, BlockArray]) -> float:
r"""Return the TV norm of an array."""
y = 0
gradOp = FiniteDifference(x.shape, input_dtype=x.dtype, circular=True)
grads = gradOp @ x
for g in grads:
y += snp.abs(g)
return self.tau * snp.sum(y)

def prox(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
r"""Proximal operator of the TV norm.
Approximate the proximal operator of the anisotropic TV norm via
the method described in :cite:`kamilov-2016-parallel`.
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lam`.
kwargs: Additional arguments that may be used by derived
classes.
"""
D = 2
K = 2 * D
thresh = snp.sqrt(2) * K * self.tau * lam

y = snp.zeros_like(v)
for ax in range(2):
y = y.at[:].add(
self.iht2(
self.ht2_shrink(v, axis=ax, shift=False, thresh=thresh), axis=ax, shift=False
)
)
y = y.at[:].add(
self.iht2(
self.ht2_shrink(v, axis=ax, shift=True, thresh=thresh), axis=ax, shift=True
)
)
y = y.at[:].divide(K)
return y

def ht2_shrink(self, x, axis, shift, thresh):
r"""Forward Discrete Haar Wavelet transform in 2D"""
w = snp.zeros_like(x)
C = 1 / snp.sqrt(2)
if shift:
x = snp.roll(x, -1, axis=axis)

m = x.shape[axis] // 2
if not axis:
w = w.at[:m, :].set(C * (x[1::2, :] + x[::2, :]))
w = w.at[m:, :].set(self.shrink(C * (x[1::2, :] - x[::2, :]), thresh))
else:
w = w.at[:, :m].set(C * (x[:, 1::2] + x[:, ::2]))
w = w.at[:, m:].set(self.shrink(C * (x[:, 1::2] - x[:, ::2]), thresh))
return w

def iht2(self, w, axis, shift):
r"""Inverse Discrete Haar Wavelet transform in 2D"""
y = snp.zeros_like(w)
C = 1 / snp.sqrt(2)
m = w.shape[axis] // 2
if not axis:
y = y.at[::2, :].set(C * (w[:m, :] - w[m:, :]))
y = y.at[1::2, :].set(C * (w[:m, :] + w[m:, :]))
else:
y = y.at[:, ::2].set(C * (w[:, :m] - w[:, m:]))
y = y.at[:, 1::2].set(C * (w[:, :m] + w[:, m:]))

if shift:
y = snp.roll(y, 1, axis)
return y

def shrink(self, x, tau):
r"""Wavelet shrinkage operator"""
threshed = snp.maximum(snp.abs(x) - tau, 0)
threshed = threshed.at[:].multiply(snp.sign(x))
return threshed
135 changes: 135 additions & 0 deletions scico/functional/_tvnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Anisotropic total variation norm."""

from typing import Optional, Tuple

from scico import numpy as snp
from scico.linop import CircularConvolve, FiniteDifference, VerticalStack
from scico.numpy import Array

from ._functional import Functional
from ._norm import L1Norm


class AnisotropicTVNorm(Functional):
r"""The anisotropic total variation (TV) norm.
The anisotropic total variation (TV) norm computed by
.. code-block:: python
ATV = scico.functional.AnisotropicTVNorm()
x_norm = ATV(x)
is equivalent to
.. code-block:: python
C = linop.FiniteDifference(input_shape=x.shape, circular=True)
L1 = functional.L1Norm()
x_norm = L1(C @ x)
The scaled proximal operator is computed using an approximation that
holds for small scaling parameters :cite:`kamilov-2016-parallel`.
This does not imply that it can only be applied to problems requiring
a small regularization parameter since most proximal algorithms
include an additional algorithm parameter that also plays a role in
the parameter of the proximal operator. For example, in :class:`.PGM`
and :class:`.AcceleratedPGM`, the scaled proximal operator parameter
is the regularization parameter divided by the `L0` algorithm
parameter, and for :class:`.ADMM`, the scaled proximal operator
parameters are the regularization parameters divided by the entries
in the `rho_list` algorithm parameter.
"""

has_eval = True
has_prox = True

def __init__(self, ndims: Optional[int] = None):
r"""
Args:
ndims: Number of (trailing) dimensions of the input over
which to apply the finite difference operator. If
``None``, differences are evaluated along all axes.
"""
self.ndims = ndims
self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter
self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter
self.l1norm = L1Norm()
self.G = None
self.W = None

def __call__(self, x: Array) -> float:
r"""Compute the anisotropic TV norm of an array."""
if self.G is None or self.G.shape[1] != x.shape:
if self.ndims is None:
ndims = x.ndim
else:
ndims = self.ndims
axes = tuple(range(ndims))
self.G = FiniteDifference(
x.shape, input_dtype=x.dtype, axes=axes, circular=True, jit=True
)
return snp.sum(snp.abs(self.G @ x))

@staticmethod
def _shape(idx: int, ndims: int) -> Tuple:
"""Construct a shape tuple.
Construct a tuple of size `ndims` with all unit entries except
for index `idx`, which has a -1 entry.
"""
return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1)

def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
r"""Approximate proximal operator of the isotropic TV norm.
Approximation of the proximal operator of the anisotropic TV norm,
computed via the method described in :cite:`kamilov-2016-parallel`.
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lam`.
kwargs: Additional arguments that may be used by derived
classes.
"""
if self.ndims is None:
ndims = v.ndim
else:
ndims = self.ndims
K = 2 * ndims

if self.W is None or self.W.shape[1] != v.shape:
C0 = VerticalStack( # Stack of lowpass filter operators for each axis
[
CircularConvolve(
self.h0.reshape(AnisotropicTVNorm._shape(k, ndims)),
v.shape,
ndims=self.ndims,
)
for k in range(ndims)
]
)
C1 = VerticalStack( # Stack of highpass filter operators for each axis
[
CircularConvolve(
self.h1.reshape(AnisotropicTVNorm._shape(k, ndims)),
v.shape,
ndims=self.ndims,
)
for k in range(ndims)
]
)
# single-level shift-invariant Haar transform
self.W = VerticalStack((C0, C1), jit=True)

Wv = self.W @ v
# Apply 𝑙1 shrinkage to highpass component of shift-invariant Haar transform
Wv = Wv.at[1].set(self.l1norm.prox(Wv[1], snp.sqrt(2) * K * lam))
return (1.0 / K) * self.W.T @ Wv
43 changes: 43 additions & 0 deletions scico/test/functional/test_tvnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np

import scico.random
from scico import functional, linop, loss, metric
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.optimize.pgm import AcceleratedPGM


def test_tvnorm():

N = 128
g = np.linspace(0, 2 * np.pi, N)
x_gt = np.sin(2 * g)
x_gt[x_gt > 0.5] = 0.5
x_gt[x_gt < -0.5] = -0.5
σ = 0.02
noise, key = scico.random.randn(x_gt.shape, seed=0)
y = x_gt + σ * noise

λ = 5e-2
f = loss.SquaredL2Loss(y=y)

C = linop.FiniteDifference(input_shape=x_gt.shape, circular=True)
g = λ * functional.L1Norm()
solver = ADMM(
f=f,
g_list=[g],
C_list=[C],
rho_list=[1e1],
x0=y,
maxiter=50,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}),
itstat_options={"display": True, "period": 10},
)
x_tvdn = solver.solve()

h = λ * functional.AnisotropicTVNorm()
solver = AcceleratedPGM(
f=f, g=h, L0=2e2, x0=y, maxiter=50, itstat_options={"display": True, "period": 10}
)
x_approx = solver.solve()

assert metric.snr(x_tvdn, x_approx) > 45

0 comments on commit 2963523

Please sign in to comment.