Skip to content

Commit

Permalink
Add JAX 3D X-ray CT projector (#529)
Browse files Browse the repository at this point in the history
Co-authored-by: Brendt Wohlberg <[email protected]>
  • Loading branch information
Michael-T-McCann and bwohlberg authored Sep 16, 2024
1 parent f46ff25 commit 5d5b2b9
Show file tree
Hide file tree
Showing 19 changed files with 1,117 additions and 198 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ venv.bak/
# Rope project settings
.ropeproject

# VS Code settings
.vscode/

# mkdocs documentation
/site

Expand Down
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Version 0.0.6 (unreleased)
----------------------------

• Significant changes to ``linop.xray.astra`` API.
• New integrated 3D X-ray transform via ``linop.xray.XRayTransform3D``.
• New functional ``functional.IsotropicTVNorm`` and faster implementation
of ``functional.AnisotropicTVNorm``.
• New linear operators ``linop.ProjectedGradient``, ``linop.PolarGradient``,
Expand Down
4 changes: 3 additions & 1 deletion docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ Computed Tomography
examples/ct_astra_modl_train_foam2
examples/ct_astra_odp_train_foam2
examples/ct_astra_unet_train_foam2
examples/ct_projector_comparison
examples/ct_projector_comparison_2d
examples/ct_projector_comparison_3d
examples/ct_multi_cs_tv_admm
examples/ct_multi_tv_admm

Deconvolution
Expand Down
4 changes: 4 additions & 0 deletions examples/scriptcheck.sh
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ for f in $SCRIPTPATH/scripts/*.py; do
printf "%s\n" skipped
continue
fi
if [ $SKIP_GPU -eq 1 ] && grep -q 'ct_projector_comparison_3d' <<< $f; then
printf "%s\n" skipped
continue
fi

# Create temporary copy of script with all algorithm maxiter values set
# to small number and final input statements commented out.
Expand Down
6 changes: 4 additions & 2 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ Computed Tomography
CT Training and Reconstructions with ODP
`ct_astra_unet_train_foam2.py <ct_astra_unet_train_foam2.py>`_
CT Training and Reconstructions with UNet
`ct_projector_comparison.py <ct_projector_comparison.py>`_
X-ray Transform Comparison
`ct_projector_comparison_2d.py <ct_projector_comparison_2d.py>`_
2D X-ray Transform Comparison
`ct_projector_comparison_3d.py <ct_projector_comparison_3d.py>`_
3D X-ray Transform Comparison
`ct_multi_tv_admm.py <ct_multi_tv_admm.py>`_
TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors)

Expand Down
6 changes: 2 additions & 4 deletions examples/scripts/ct_multi_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.linop.xray import Parallel2dProjector, XRayTransform, astra, svmbir
from scico.linop.xray import XRayTransform2D, astra, svmbir
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

Expand All @@ -54,9 +54,7 @@
"svmbir": svmbir.XRayTransform(
x_gt.shape, 2 * np.pi - angles, det_count, delta_pixel=1.0, delta_channel=det_spacing
), # svmbir
"scico": XRayTransform(
Parallel2dProjector((N, N), angles, det_count=det_count, dx=1 / det_spacing)
), # scico
"scico": XRayTransform2D((N, N), angles, det_count=det_count, dx=1 / det_spacing), # scico
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@


r"""
X-ray Transform Comparison
==========================
2D X-ray Transform Comparison
=============================
This example compares SCICO's native X-ray transform algorithm
This example compares SCICO's native 2D X-ray transform algorithm
to that of the ASTRA toolbox.
"""

Expand All @@ -22,7 +22,7 @@

import scico.linop.xray.astra as astra
from scico import plot
from scico.linop import Parallel2dProjector, XRayTransform
from scico.linop.xray import XRayTransform2D
from scico.util import Timer

"""
Expand All @@ -46,7 +46,7 @@

projectors = {}
timer.start("scico_init")
projectors["scico"] = XRayTransform(Parallel2dProjector((N, N), angles))
projectors["scico"] = XRayTransform2D((N, N), angles)
timer.stop("scico_init")

timer.start("astra_init")
Expand Down
200 changes: 200 additions & 0 deletions examples/scripts/ct_projector_comparison_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.


r"""
3D X-ray Transform Comparison
=============================
This example shows how to define a SCICO native 3D X-ray transform using
ASTRA toolbox conventions and vice versa.
"""

import numpy as np

import jax
import jax.numpy as jnp

import scico.linop.xray.astra as astra
from scico import plot
from scico.examples import create_block_phantom
from scico.linop.xray import XRayTransform3D
from scico.util import ContextTimer, Timer

"""
Create a ground truth image and set detector dimensions.
"""
N = 64
# use rectangular volume to check whether axes are handled correctly
in_shape = (N + 1, N + 2, N + 3)
x = create_block_phantom(in_shape)
x = jnp.array(x)

# use rectangular detector to check whether axes are handled correctly
out_shape = (N, N + 1)


"""
Set up SCICO projection.
"""
num_angles = 3


rot_X = 90.0 - 16.0
rot_Y = np.linspace(0, 180, num_angles, endpoint=False)
angles = np.stack(np.broadcast_arrays(rot_X, rot_Y), axis=-1)
matrices = XRayTransform3D.matrices_from_euler_angles(
in_shape, out_shape, "XY", angles, degrees=True
)

"""
Specify geometry using SCICO conventions and project.
"""
num_repeats = 3

timer_scico = Timer()
with ContextTimer(timer_scico, "init"):
H_scico = XRayTransform3D(in_shape, matrices, out_shape)

with ContextTimer(timer_scico, "first_fwd"):
y_scico = H_scico @ x
jax.block_until_ready(y_scico)

with ContextTimer(timer_scico, "avg_fwd"):
for _ in range(num_repeats):
y_scico = H_scico @ x
jax.block_until_ready(y_scico)
timer_scico.td["avg_fwd"] /= num_repeats

with ContextTimer(timer_scico, "first_back"):
HTy_scico = H_scico.T @ y_scico

with ContextTimer(timer_scico, "avg_back"):
for _ in range(num_repeats):
HTy_scico = H_scico.T @ y_scico
jax.block_until_ready(HTy_scico)
timer_scico.td["avg_back"] /= num_repeats


"""
Convert SCICO geometry to ASTRA and project.
"""

vectors_from_scico = astra.convert_from_scico_geometry(in_shape, matrices, out_shape)

timer_astra = Timer()
with ContextTimer(timer_astra, "init"):
H_astra_from_scico = astra.XRayTransform3D(
input_shape=in_shape, det_count=out_shape, vectors=vectors_from_scico
)

with ContextTimer(timer_astra, "first_fwd"):
y_astra_from_scico = H_astra_from_scico @ x
jax.block_until_ready(y_astra_from_scico)

with ContextTimer(timer_astra, "avg_fwd"):
for _ in range(num_repeats):
y_astra_from_scico = H_astra_from_scico @ x
jax.block_until_ready(y_astra_from_scico)
timer_astra.td["avg_fwd"] /= num_repeats

with ContextTimer(timer_astra, "first_back"):
HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico

with ContextTimer(timer_astra, "avg_back"):
for _ in range(num_repeats):
HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico
jax.block_until_ready(HTy_astra_from_scico)
timer_astra.td["avg_back"] /= num_repeats


"""
Specify geometry with ASTRA conventions and project.
"""

angles = np.random.rand(num_angles) * 180 # random projection angles
det_spacing = [1.0, 1.0]
vectors = astra.angle_to_vector(det_spacing, angles)

H_astra = astra.XRayTransform3D(input_shape=in_shape, det_count=out_shape, vectors=vectors)

y_astra = H_astra @ x
HTy_astra = H_astra.T @ y_astra


"""
Convert ASTRA geometry to SCICO and project.
"""

P_from_astra = astra._astra_to_scico_geometry(H_astra.vol_geom, H_astra.proj_geom)
H_scico_from_astra = XRayTransform3D(in_shape, P_from_astra, out_shape)

y_scico_from_astra = H_scico_from_astra @ x
HTy_scico_from_astra = H_scico_from_astra.T @ y_scico_from_astra


"""
Print timing results.
"""
print(f"init astra {timer_astra.td['init']:.2e} s")
print(f"init scico {timer_scico.td['init']:.2e} s")
print("")
for tstr in ("first", "avg"):
for dstr in ("fwd", "back"):
for timer, pstr in zip((timer_astra, timer_scico), ("astra", "scico")):
print(f"{tstr:5s} {dstr:4s} {pstr} {timer.td[tstr + '_' + dstr]:.2e} s")
print()


"""
Show projections.
"""
fig, ax = plot.subplots(nrows=3, ncols=2, figsize=(8, 10))
plot.imview(y_scico[0], title="SCICO projections", cbar=None, fig=fig, ax=ax[0, 0])
plot.imview(y_scico[1], cbar=None, fig=fig, ax=ax[1, 0])
plot.imview(y_scico[2], cbar=None, fig=fig, ax=ax[2, 0])
plot.imview(y_astra_from_scico[:, 0], title="ASTRA projections", cbar=None, fig=fig, ax=ax[0, 1])
plot.imview(y_astra_from_scico[:, 1], cbar=None, fig=fig, ax=ax[1, 1])
plot.imview(y_astra_from_scico[:, 2], cbar=None, fig=fig, ax=ax[2, 1])
fig.suptitle("Using SCICO conventions")
fig.tight_layout()
fig.show()

fig, ax = plot.subplots(nrows=3, ncols=2, figsize=(8, 10))
plot.imview(y_scico_from_astra[0], title="SCICO projections", cbar=None, fig=fig, ax=ax[0, 0])
plot.imview(y_scico_from_astra[1], cbar=None, fig=fig, ax=ax[1, 0])
plot.imview(y_scico_from_astra[2], cbar=None, fig=fig, ax=ax[2, 0])
plot.imview(y_astra[:, 0], title="ASTRA projections", cbar=None, fig=fig, ax=ax[0, 1])
plot.imview(y_astra[:, 1], cbar=None, fig=fig, ax=ax[1, 1])
plot.imview(y_astra[:, 2], cbar=None, fig=fig, ax=ax[2, 1])
fig.suptitle("Using ASTRA conventions")
fig.tight_layout()
fig.show()


"""
Show back projections.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(8, 5))
plot.imview(HTy_scico[N // 2], title="SCICO back projection", cbar=None, fig=fig, ax=ax[0])
plot.imview(
HTy_astra_from_scico[N // 2], title="ASTRA back projection", cbar=None, fig=fig, ax=ax[1]
)
fig.suptitle("Using SCICO conventions")
fig.tight_layout()
fig.show()

fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(8, 5))
plot.imview(
HTy_scico_from_astra[N // 2], title="SCICO back projection", cbar=None, fig=fig, ax=ax[0]
)
plot.imview(HTy_astra[N // 2], title="ASTRA back projection", cbar=None, fig=fig, ax=ax[1])
fig.suptitle("Using ASTRA conventions")
fig.tight_layout()
fig.show()


input("\nWaiting for input to close figures and exit")
4 changes: 2 additions & 2 deletions examples/scripts/ct_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.linop.xray import Parallel2dProjector, XRayTransform
from scico.linop.xray import XRayTransform2D
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

Expand All @@ -46,7 +46,7 @@
"""
n_projection = 45 # number of projections
angles = np.linspace(0, np.pi, n_projection) + np.pi / 2.0 # evenly spaced projection angles
A = XRayTransform(Parallel2dProjector((N, N), angles)) # CT projection operator
A = XRayTransform2D((N, N), angles) # CT projection operator
y = A @ x_gt # sinogram


Expand Down
3 changes: 2 additions & 1 deletion examples/scripts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ Computed Tomography
- ct_astra_modl_train_foam2.py
- ct_astra_odp_train_foam2.py
- ct_astra_unet_train_foam2.py
- ct_projector_comparison.py
- ct_projector_comparison_2d.py
- ct_projector_comparison_3d.py
- ct_multi_tv_admm.py

Deconvolution
Expand Down
Loading

0 comments on commit 5d5b2b9

Please sign in to comment.