Skip to content

Commit

Permalink
Add docstrings, typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-T-McCann committed Sep 13, 2024
1 parent 1b215e9 commit bcba849
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 17 deletions.
40 changes: 24 additions & 16 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import jax.numpy as jnp
from jax.typing import ArrayLike

import scico.numpy as snp
from scico.numpy.util import is_scalar_equiv
from scico.typing import Shape
from scipy.spatial.transform import Rotation
Expand Down Expand Up @@ -115,17 +116,19 @@ def __init__(
adj_fn=self.back_project,
)

def project(self, im):
def project(self, im: ArrayLike) -> snp.Array:
"""Compute X-ray projection."""
return XRayTransform2D._project(im, self.x0, self.dx, self.y0, self.ny, self.angles)

def back_project(self, y):
def back_project(self, y: ArrayLike) -> snp.Array:
"""Compute X-ray back projection"""
return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles)

@staticmethod
@partial(jax.jit, static_argnames=["ny"])
def _project(im, x0, dx, y0, ny, angles):
def _project(
im: ArrayLike, x0: ArrayLike, dx: ArrayLike, y0: float, ny: int, angles: ArrayLike
) -> snp.Array:
r"""
Args:
im: Input array, (M, N).
Expand Down Expand Up @@ -155,7 +158,9 @@ def _project(im, x0, dx, y0, ny, angles):

@staticmethod
@partial(jax.jit, static_argnames=["nx"])
def _back_project(y, x0, dx, nx, y0, angles):
def _back_project(
y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike
) -> ArrayLike:
r"""
Args:
y: Input projection, (num_angles, N).
Expand Down Expand Up @@ -184,7 +189,9 @@ def _back_project(y, x0, dx, nx, y0, angles):
@staticmethod
@partial(jax.jit, static_argnames=["nx"])
@partial(jax.vmap, in_axes=(None, None, None, 0, None))
def _calc_weights(x0, dx, nx, angle, y0):
def _calc_weights(
x0: ArrayLike, dx: ArrayLike, nx: Shape, angle: float, y0: float
) -> snp.Array:
"""
Args:
Expand Down Expand Up @@ -263,28 +270,27 @@ def __init__(
det_shape: Shape of detector.
"""

self.input_shape = input_shape
self.input_shape: Shape = input_shape
self.matrices = matrices
self.det_shape = det_shape
self.output_shape = (len(matrices), *det_shape)

super().__init__(
input_shape=self.input_shape,
input_shape=input_shape,
output_shape=self.output_shape,
eval_fn=self.project,
adj_fn=self.back_project,
)

def project(self, im):
def project(self, im: ArrayLike) -> snp.Array:
"""Compute X-ray projection."""
return XRayTransform3D._project(im, self.matrices, self.det_shape)

def back_project(self, proj):
def back_project(self, proj: ArrayLike) -> snp.Array:
"""Compute X-ray back projection"""
return XRayTransform3D._back_project(proj, self.matrices, self.input_shape)

@staticmethod
def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> ArrayLike:
def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> snp.Array:
r"""
Args:
im: Input image.
Expand Down Expand Up @@ -312,7 +318,7 @@ def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> ArrayLike:
@partial(jax.jit, donate_argnames="proj")
def _project_single(
im: ArrayLike, matrix: ArrayLike, proj: ArrayLike, slice_offset: int = 0
) -> ArrayLike:
) -> snp.Array:
r"""
Args:
im: Input image.
Expand Down Expand Up @@ -359,7 +365,7 @@ def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> A
@partial(jax.jit, donate_argnames="HTy")
def _back_project_single(
y: ArrayLike, matrix: ArrayLike, HTy: ArrayLike, slice_offset: int = 0
) -> ArrayLike:
) -> snp.Array:
ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = XRayTransform3D._calc_weights(
HTy.shape, matrix, y.shape, slice_offset
)
Expand All @@ -370,7 +376,9 @@ def _back_project_single(
return HTy

@staticmethod
def _calc_weights(input_shape, matrix, output_shape, slice_offset: int = 0):
def _calc_weights(
input_shape: Shape, matrix: snp.Array, output_shape: Shape, slice_offset: int = 0
) -> snp.Array:
# pixel (0, 0, 0) has its center at (0.5, 0.5, 0.5)
x = jnp.mgrid[: input_shape[0], : input_shape[1], : input_shape[2]] + 0.5 # (3, ...)
x = x.at[0].add(slice_offset)
Expand Down Expand Up @@ -405,7 +413,7 @@ def matrices_from_euler_angles(
degrees: bool = False,
voxel_spacing: ArrayLike = None,
det_spacing: ArrayLike = None,
):
) -> snp.Array:
"""
Create a set of projection matrices from Euler angles. The
input voxels will undergo the specified rotation and then be
Expand Down Expand Up @@ -450,6 +458,6 @@ def matrices_from_euler_angles(
# add translation to line up the centers
x0 = np.array(input_shape) / 2
t = -np.einsum("vmn,n->vm", matrices, x0) + np.array(output_shape) / 2
matrices = np.concatenate((matrices, t[..., np.newaxis]), axis=2)
matrices = snp.concatenate((matrices, t[..., np.newaxis]), axis=2)

return matrices
13 changes: 12 additions & 1 deletion scico/linop/xray/astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def set_astra_gpu_index(idx: Union[int, Sequence[int]]):
def _project_coords(
x_volume: np.ndarray, vol_geom: VolumeGeometry, proj_geom: ProjectionGeometry
) -> np.ndarray:
"""
Transform volume (logical) coordinates into world coordinates based
on ASTRA geometry objects.
Args:
x_volume: (..., 3) vector(s) of volume (AKA logical) coordinates
vol_geom: ASTRA volume geometry object.
proj_geom: ASTRA projection geometry object.
"""
det_shape = (proj_geom["DetectorRowCount"], proj_geom["DetectorColCount"])
x_world = volume_coords_to_world_coords(x_volume, vol_geom=vol_geom)
x_dets = []
Expand Down Expand Up @@ -110,7 +119,7 @@ def project_world_coordinates(
return ind_ij


def volume_coords_to_world_coords(idx: np.ndarray, vol_geom: VolumeGeometry):
def volume_coords_to_world_coords(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray:
"""Convert a volume coordinate into a world coordinate.
Convert a volume coordinate into a world coordinate using ASTRA
Expand All @@ -131,6 +140,7 @@ def volume_coords_to_world_coords(idx: np.ndarray, vol_geom: VolumeGeometry):


def _volume_index_to_astra_world_2d(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray:
"""Convert a 2D volume coordinate into a 2D world coordinate."""
coord = idx[..., [2, 1]] # x:col, y:row,
nx = np.array( # (x, y) order
(
Expand All @@ -150,6 +160,7 @@ def _volume_index_to_astra_world_2d(idx: np.ndarray, vol_geom: VolumeGeometry) -


def _volume_index_to_astra_world_3d(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray:
"""Convert a 3D volume coordinate into a 3D world coordinate."""
coord = idx[..., [2, 1, 0]] # x:col, y:row, z:slice
nx = np.array( # (x, y, z) order
(
Expand Down

0 comments on commit bcba849

Please sign in to comment.