Skip to content

Commit

Permalink
Merge branch 'main' into tv-norm
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg authored Nov 3, 2023
2 parents a7e82ba + 27e2aec commit b5e8fc9
Show file tree
Hide file tree
Showing 91 changed files with 946 additions and 643 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ SCICO Release Notes
Version 0.0.5 (unreleased)
----------------------------

• New integrated Radon/X-ray transform ``linop.XRayTransform``.
• Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and
``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes
to ``XRayTransform``.
• Rename ``AbelProjector`` to ``AbelTransform``.
• Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.19.

Expand Down
2 changes: 1 addition & 1 deletion data
Submodule data updated 43 files
+ docs/figures/scico-tomo-overview.png
+3 −3 notebooks/ct_abel_tv_admm.ipynb
+13 −14 notebooks/ct_abel_tv_admm_tune.ipynb
+8 −12 notebooks/ct_astra_3d_tv_admm.ipynb
+4 −4 notebooks/ct_astra_modl_train_foam2.ipynb
+5 −6 notebooks/ct_astra_noreg_pcg.ipynb
+3 −3 notebooks/ct_astra_odp_train_foam2.ipynb
+7 −9 notebooks/ct_astra_tv_admm.ipynb
+10 −12 notebooks/ct_astra_weighted_tv_admm.ipynb
+12 −12 notebooks/ct_fan_svmbir_ppp_bm3d_admm_prox.ipynb
+740 −0 notebooks/ct_multi_cs_tv_admm.ipynb
+714 −0 notebooks/ct_multi_tv_admm.ipynb
+180 −225 notebooks/ct_projector_comparison.ipynb
+5 −5 notebooks/ct_svmbir_ppp_bm3d_admm_cg.ipynb
+11 −11 notebooks/ct_svmbir_ppp_bm3d_admm_prox.ipynb
+7 −8 notebooks/ct_svmbir_tv_multi.ipynb
+1 −4 notebooks/deconv_circ_tv_admm.ipynb
+3 −5 notebooks/deconv_microscopy_allchn_tv_admm.ipynb
+0 −1 notebooks/deconv_microscopy_tv_admm.ipynb
+1 −1 notebooks/deconv_modl_train_foam1.ipynb
+4 −4 notebooks/deconv_odp_train_foam1.ipynb
+1 −3 notebooks/deconv_ppp_bm3d_admm.ipynb
+1 −3 notebooks/deconv_ppp_bm3d_pgm.ipynb
+1 −3 notebooks/deconv_ppp_bm4d_admm.ipynb
+1 −3 notebooks/deconv_ppp_dncnn_admm.ipynb
+1 −3 notebooks/deconv_ppp_dncnn_padmm.ipynb
+1 −4 notebooks/deconv_tv_admm.ipynb
+0 −3 notebooks/deconv_tv_admm_tune.ipynb
+2 −5 notebooks/deconv_tv_padmm.ipynb
+2 −5 notebooks/demosaic_ppp_bm3d_admm.ipynb
+2 −4 notebooks/denoise_dncnn_universal.ipynb
+0 −3 notebooks/denoise_l1tv_admm.ipynb
+0 −3 notebooks/denoise_tv_admm.ipynb
+1 −4 notebooks/denoise_tv_multi.ipynb
+0 −8 notebooks/denoise_tv_pgm.ipynb
+2 −4 notebooks/diffusercam_tv_admm.ipynb
+6 −2 notebooks/index.ipynb
+3 −4 notebooks/sparsecode_admm.ipynb
+2 −4 notebooks/sparsecode_conv_admm.ipynb
+2 −4 notebooks/sparsecode_conv_md_admm.ipynb
+3 −4 notebooks/sparsecode_pgm.ipynb
+5 −11 notebooks/sparsecode_poisson_pgm.ipynb
+1 −4 notebooks/superres_ppp_dncnn_admm.ipynb
3 changes: 2 additions & 1 deletion docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ Computed Tomography
examples/ct_astra_odp_train_foam2
examples/ct_astra_unet_train_foam2
examples/ct_projector_comparison

examples/ct_multi_cs_tv_admm
examples/ct_multi_tv_admm

Deconvolution
^^^^^^^^^^^^^
Expand Down
6 changes: 3 additions & 3 deletions docs/source/inverse.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ SCICO provides the :class:`.Operator` and :class:`.LinearOperator`
classes, which may be subclassed by users, in order to implement the
forward operator, :math:`A`. It also has several built-in operators,
most of which are linear, e.g., finite convolutions, discrete Fourier
transforms, optical propagators, Abel transforms, and Radon
transforms. For example,
transforms, optical propagators, Abel transforms, and X-ray transforms
(the same as Radon transforms in 2D). For example,

.. code:: python
input_shape = (512, 512)
angles = np.linspace(0, 2 * np.pi, 180, endpoint=False)
channels = 512
A = scico.linop.radon_svmbir.ParallelBeamProjector(input_shape, angles, channels)
A = scico.linop.xray.svmbir.XRayTransform(input_shape, angles, channels)
defines a tomographic projection operator.

Expand Down
22 changes: 17 additions & 5 deletions docs/source/notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,25 @@ via interfaces to the `bm3d <https://pypi.org/project/bm3d/>`__ and
when the full benefits of JAX-based code are required.


Tomographic Projectors
----------------------

The :class:`.radon_svmbir.TomographicProjector` class is implemented
Tomographic Projectors/Radon Transforms
---------------------------------------

Note that the tomographic projections that are frequently referred
to as Radon transforms are referred to as X-ray transforms in SCICO.
While the Radon transform is far more well-known than the X-ray
transform, which is the same as the Radon transform for projections
in two dimensions, these two transform differ in higher numbers of
dimensions, and it is the X-ray transform that is the appropriate
mathematical model for beam attenuation based imaging in three or
more dimensions.

SCICO includes three different implementations of X-ray transforms.
Of these, :class:`.linop.XRayTransform` is an integral component of
SCICO, while the other two depend on external packages.
The :class:`.xray.svmbir.XRayTransform` class is implemented
via an interface to the `svmbir
<https://svmbir.readthedocs.io/en/latest/>`__ package. The
:class:`.radon_astra.TomographicProjector` class is implemented via an
:class:`.xray.astra.XRayTransform` class is implemented via an
interface to the `ASTRA toolbox
<https://www.astra-toolbox.com/>`__. This toolbox does provide some
GPU acceleration support, but efficiency is expected to be lower than
Expand Down
2 changes: 1 addition & 1 deletion examples/examples_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
astra-toolbox
colour_demosaicing
xdesign>=0.5.5
ray[tune]>=2.0.0
ray[tune,train]>=2.5.0
hyperopt
bm3d>=4.0.0
bm4d>=4.2.2
7 changes: 5 additions & 2 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@ Computed Tomography
`ct_astra_unet_train_foam2.py <ct_astra_unet_train_foam2.py>`_
CT Training and Reconstructions with UNet
`ct_projector_comparison.py <ct_projector_comparison.py>`_
X-ray Projector Comparison

X-ray Transform Comparison
`ct_multi_cs_tv_admm.py <ct_multi_cs_tv_admm.py>`_
TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors, Common Sinogram)
`ct_multi_tv_admm.py <ct_multi_tv_admm.py>`_
TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors)

Deconvolution
^^^^^^^^^^^^^
Expand Down
6 changes: 3 additions & 3 deletions examples/scripts/ct_abel_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.examples import create_circular_phantom
from scico.linop.abel import AbelProjector
from scico.linop.abel import AbelTransform
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

Expand All @@ -40,7 +40,7 @@
"""
Set up the forward operator and create a test measurement.
"""
A = AbelProjector(x_gt.shape)
A = AbelTransform(x_gt.shape)
y = A @ x_gt
np.random.seed(12345)
y = y + np.random.normal(size=y.shape).astype(np.float32)
Expand All @@ -57,7 +57,7 @@
better performance than isotropic TV for this problem, is used here.
"""
f = loss.SquaredL2Loss(y=y, A=A)
λ = 2.35e1 # L1 norm regularization parameter
λ = 2.35e1 # ℓ1 norm regularization parameter
g = λ * functional.L1Norm() # Note the use of anisotropic TV
C = linop.FiniteDifference(input_shape=x_gt.shape)

Expand Down
27 changes: 13 additions & 14 deletions examples/scripts/ct_abel_tv_admm_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
`ray.tune` class API is used in this example.
This script is hard-coded to run on CPU only to avoid the large number of
warnings that are emitted when GPU resources are requested but not available,
and due to the difficulty of supressing these warnings in a way that does
not force use of the CPU only. To enable GPU usage, comment out the
`os.environ` statements near the beginning of the script, and change the
value of the "gpu" entry in the `resources` dict from 0 to 1. Note that
two environment variables are set to suppress the warnings because
`JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but this change
has yet to be correctly implemented
warnings that are emitted when GPU resources are requested but not
available, and due to the difficulty of supressing these warnings in a
way that does not force use of the CPU only. To enable GPU usage, comment
out the `os.environ` statements near the beginning of the script, and
change the value of the "gpu" entry in the `resources` dict from 0 to 1.
Note that two environment variables are set to suppress the warnings
because `JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but
this change has yet to be correctly implemented
(see [google/jax#6805](https://github.com/google/jax/issues/6805) and
[google/jax#10272](https://github.com/google/jax/pull/10272).
"""
Expand All @@ -34,12 +34,11 @@

import numpy as np

import jax

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.examples import create_circular_phantom
from scico.linop.abel import AbelProjector
from scico.linop.abel import AbelTransform
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.ray import tune

Expand All @@ -53,7 +52,7 @@
"""
Set up the forward operator and create a test measurement.
"""
A = AbelProjector(x_gt.shape)
A = AbelTransform(x_gt.shape)
y = A @ x_gt
np.random.seed(12345)
y = y + np.random.normal(size=y.shape).astype(np.float32)
Expand Down Expand Up @@ -82,10 +81,10 @@ def setup(self, config, x_gt, x0, y):
this case). The remaining parameters are objects that are passed
to the evaluation function via the ray object store.
"""
# Put main arrays on jax device.
self.x_gt, self.x0, self.y = jax.device_put([x_gt, x0, y])
# Get arrays passed by tune call.
self.x_gt, self.x0, self.y = snp.array(x_gt), snp.array(x0), snp.array(y)
# Set up problem to be solved.
self.A = AbelProjector(self.x_gt.shape)
self.A = AbelTransform(self.x_gt.shape)
self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
self.C = linop.FiniteDifference(input_shape=self.x_gt.shape)
self.reset_config(config)
Expand Down
23 changes: 10 additions & 13 deletions examples/scripts/ct_astra_3d_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,47 +15,42 @@
$$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x}
\|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$
where $A$ is the Radon transform, $\mathbf{y}$ is the sinogram, $C$ is
a 3D finite difference operator, and $\mathbf{x}$ is the desired
image.
where $A$ is the X-ray transform (the CT forward projection operator),
$\mathbf{y}$ is the sinogram, $C$ is a 3D finite difference operator,
and $\mathbf{x}$ is the desired image.
"""


import numpy as np

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.examples import create_tangle_phantom
from scico.linop.radon_astra import TomographicProjector
from scico.linop.xray.astra import XRayTransform
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

"""
Create a ground truth image and projector.
"""

Nx = 128
Ny = 256
Nz = 64

tangle = create_tangle_phantom(Nx, Ny, Nz)
tangle = jax.device_put(tangle)
tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz))

n_projection = 10 # number of projections
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = TomographicProjector(
tangle.shape, [1.0, 1.0], [Nz, max(Nx, Ny)], angles
) # Radon transform operator
A = XRayTransform(tangle.shape, [1.0, 1.0], [Nz, max(Nx, Ny)], angles) # CT projection operator
y = A @ tangle # sinogram


"""
Set up ADMM solver object.
"""
λ = 2e0 # L1 norm regularization parameter
λ = 2e0 # ℓ2,1 norm regularization parameter
ρ = 5e0 # ADMM penalty parameter
maxiter = 25 # number of ADMM iterations
cg_tol = 1e-4 # CG relative tolerance
Expand All @@ -82,6 +77,7 @@
itstat_options={"display": True, "period": 5},
)


"""
Run the solver.
"""
Expand All @@ -95,6 +91,7 @@
% (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon))
)


"""
Show the recovered image.
"""
Expand Down
8 changes: 4 additions & 4 deletions examples/scripts/ct_astra_modl_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop.radon_astra import TomographicProjector
from scico.linop.xray.astra import XRayTransform

"""
Prepare parallel processing. Set an arbitrary processor count (only
Expand All @@ -81,12 +81,12 @@
Build CT projection operator.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = TomographicProjector(
A = XRayTransform(
input_shape=(N, N),
detector_spacing=1,
det_count=N,
angles=angles,
) # Radon transform operator
) # CT projection operator
A = (1.0 / N) * A # normalized


Expand Down Expand Up @@ -168,7 +168,7 @@
stats_object_ini = None

checkpoint_files = []
for (dirpath, dirnames, filenames) in os.walk(workdir2):
for dirpath, dirnames, filenames in os.walk(workdir2):
checkpoint_files = [fn for fn in filenames if str.split(fn, "_")[0] == "checkpoint"]

if len(checkpoint_files) > 0:
Expand Down
11 changes: 5 additions & 6 deletions examples/scripts/ct_astra_noreg_pcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,37 @@
$$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x}
\|_2^2 \;,$$
where $A$ is the Radon transform, $\mathbf{y}$ is the sinogram, and
$\mathbf{x}$ is the reconstructed image.
where $A$ is the X-ray transform (the CT forward projection operator),
$\mathbf{y}$ is the sinogram, and $\mathbf{x}$ is the reconstructed image.
"""

from time import time

import numpy as np

import jax
import jax.numpy as jnp

from xdesign import Foam, discrete_phantom

from scico import loss, plot
from scico.linop import CircularConvolve
from scico.linop.radon_astra import TomographicProjector
from scico.linop.xray.astra import XRayTransform
from scico.solver import cg

"""
Create a ground truth image.
"""
N = 256 # phantom size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU
x_gt = jnp.array(x_gt) # convert to jax type


"""
Configure a CT projection operator and generate synthetic measurements.
"""
n_projection = N # matches the phantom size so this is not few-view CT
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = 1 / N * TomographicProjector(x_gt.shape, 1, N, angles) # Radon transform operator
A = 1 / N * XRayTransform(x_gt.shape, 1, N, angles) # CT projection operator
y = A @ x_gt # sinogram


Expand Down
6 changes: 3 additions & 3 deletions examples/scripts/ct_astra_odp_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop.radon_astra import TomographicProjector
from scico.linop.xray.astra import XRayTransform

"""
Prepare parallel processing. Set an arbitrary processor count (only
Expand All @@ -85,12 +85,12 @@
Build CT projection operator.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = TomographicProjector(
A = XRayTransform(
input_shape=(N, N),
detector_spacing=1,
det_count=N,
angles=angles,
) # Radon transform operator
) # CT projection operator
A = (1.0 / N) * A # normalized


Expand Down
Loading

0 comments on commit b5e8fc9

Please sign in to comment.