From 8d43a149edcca516cb1d098b73a5fb0c7c98fc5b Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 12 Aug 2024 15:42:12 -0600 Subject: [PATCH 1/2] Fix docs figure display in furo theme "dark" mode (#547) * Invert figure colours in dark mode * Fix figure font size and line width * Add detector orientation to figures * Improve figure captions * Address review comment --- docs/source/_static/scico.css | 4 + docs/source/pyfigures/cylindgrad.py | 27 +++--- docs/source/pyfigures/polargrad.py | 14 +-- docs/source/pyfigures/spheregrad.py | 27 +++--- docs/source/pyfigures/xray_2d_geom.py | 132 ++++++++++++++++++++++---- docs/source/pyfigures/xray_3d_ang.py | 2 +- scico/linop/xray/__init__.py | 12 +-- scico/linop/xray/astra.py | 7 +- 8 files changed, 164 insertions(+), 61 deletions(-) diff --git a/docs/source/_static/scico.css b/docs/source/_static/scico.css index ea0618c0d..592ccb4ce 100644 --- a/docs/source/_static/scico.css +++ b/docs/source/_static/scico.css @@ -1,5 +1,9 @@ /* furo theme customization */ +body[data-theme="dark"] figure img { + filter: invert(100%); +} + .sidebar-drawer { width: fit-content !important; } diff --git a/docs/source/pyfigures/cylindgrad.py b/docs/source/pyfigures/cylindgrad.py index 9b26220ac..d2d6f9cf2 100644 --- a/docs/source/pyfigures/cylindgrad.py +++ b/docs/source/pyfigures/cylindgrad.py @@ -26,21 +26,24 @@ fig = plot.plt.figure(figsize=(20, 6)) ax = fig.add_subplot(1, 3, 1, projection="3d") ax.quiver(g0, g1, g2, ang[0], ang[1], ang[2], colors=clr, length=0.9) -ax.set_title("Angular local coordinate axis") -ax.set_xlabel("$x$") -ax.set_ylabel("$y$") -ax.set_zlabel("$z$") +ax.set_title("Angular local coordinate axis", fontsize=18) +ax.set_xlabel("$x$", fontsize=15) +ax.set_ylabel("$y$", fontsize=15) +ax.set_zlabel("$z$", fontsize=15) +ax.tick_params(labelsize=15) ax = fig.add_subplot(1, 3, 2, projection="3d") ax.quiver(g0, g1, g2, rad[0], rad[1], rad[2], colors=clr, length=0.9) -ax.set_title("Radial local coordinate axis") -ax.set_xlabel("$x$") -ax.set_ylabel("$y$") -ax.set_zlabel("$z$") +ax.set_title("Radial local coordinate axis", fontsize=18) +ax.set_xlabel("$x$", fontsize=15) +ax.set_ylabel("$y$", fontsize=15) +ax.set_zlabel("$z$", fontsize=15) +ax.tick_params(labelsize=15) ax = fig.add_subplot(1, 3, 3, projection="3d") ax.quiver(g0, g1, g2, axi[0], axi[1], axi[2], colors=clr[0], length=0.9) -ax.set_title("Axial local coordinate axis") -ax.set_xlabel("$x$") -ax.set_ylabel("$y$") -ax.set_zlabel("$z$") +ax.set_title("Axial local coordinate axis", fontsize=18) +ax.set_xlabel("$x$", fontsize=15) +ax.set_ylabel("$y$", fontsize=15) +ax.set_zlabel("$z$", fontsize=15) +ax.tick_params(labelsize=15) fig.tight_layout() fig.show() diff --git a/docs/source/pyfigures/polargrad.py b/docs/source/pyfigures/polargrad.py index 23da202f2..a0a6b119a 100644 --- a/docs/source/pyfigures/polargrad.py +++ b/docs/source/pyfigures/polargrad.py @@ -20,15 +20,17 @@ fig, ax = plot.plt.subplots(nrows=1, ncols=2, figsize=(13, 6)) ax[0].quiver(g0, g1, ang[0], ang[1], clr) -ax[0].set_title("Angular local coordinate axis") -ax[0].set_xlabel("$x$") -ax[0].set_ylabel("$y$") +ax[0].set_title("Angular local coordinate axis", fontsize=16) +ax[0].set_xlabel("$x$", fontsize=14) +ax[0].set_ylabel("$y$", fontsize=14) +ax[0].tick_params(labelsize=14) ax[0].xaxis.set_ticks((-10, -5, 0, 5, 10)) ax[0].yaxis.set_ticks((-10, -5, 0, 5, 10)) ax[1].quiver(g0, g1, rad[0], rad[1], clr) -ax[1].set_title("Radial local coordinate axis") -ax[1].set_xlabel("$x$") -ax[1].set_ylabel("$y$") +ax[1].set_title("Radial local coordinate axis", fontsize=16) +ax[1].set_xlabel("$x$", fontsize=14) +ax[1].set_ylabel("$y$", fontsize=14) +ax[1].tick_params(labelsize=14) ax[1].xaxis.set_ticks((-10, -5, 0, 5, 10)) ax[1].yaxis.set_ticks((-10, -5, 0, 5, 10)) fig.tight_layout() diff --git a/docs/source/pyfigures/spheregrad.py b/docs/source/pyfigures/spheregrad.py index ea2149d92..2f5e8ffc8 100644 --- a/docs/source/pyfigures/spheregrad.py +++ b/docs/source/pyfigures/spheregrad.py @@ -27,21 +27,24 @@ fig = plot.plt.figure(figsize=(20, 6)) ax = fig.add_subplot(1, 3, 1, projection="3d") ax.quiver(g0, g1, g2, azi[0], azi[1], azi[2], colors=clr, length=0.9) -ax.set_title("Azimuthal local coordinate axis") -ax.set_xlabel("$x$") -ax.set_ylabel("$y$") -ax.set_zlabel("$z$") +ax.set_title("Azimuthal local coordinate axis", fontsize=18) +ax.set_xlabel("$x$", fontsize=15) +ax.set_ylabel("$y$", fontsize=15) +ax.set_zlabel("$z$", fontsize=15) +ax.tick_params(labelsize=15) ax = fig.add_subplot(1, 3, 2, projection="3d") ax.quiver(g0, g1, g2, pol[0], pol[1], pol[2], colors=clr, length=0.9) -ax.set_title("Polar local coordinate axis") -ax.set_xlabel("$x$") -ax.set_ylabel("$y$") -ax.set_zlabel("$z$") +ax.set_title("Polar local coordinate axis", fontsize=18) +ax.set_xlabel("$x$", fontsize=15) +ax.set_ylabel("$y$", fontsize=15) +ax.set_zlabel("$z$", fontsize=15) +ax.tick_params(labelsize=15) ax = fig.add_subplot(1, 3, 3, projection="3d") ax.quiver(g0, g1, g2, rad[0], rad[1], rad[2], colors=clr, length=0.9) -ax.set_title("Radial local coordinate axis") -ax.set_xlabel("$x$") -ax.set_ylabel("$y$") -ax.set_zlabel("$z$") +ax.set_title("Radial local coordinate axis", fontsize=18) +ax.set_xlabel("$x$", fontsize=15) +ax.set_ylabel("$y$", fontsize=15) +ax.set_zlabel("$z$", fontsize=15) +ax.tick_params(labelsize=15) fig.tight_layout() fig.show() diff --git a/docs/source/pyfigures/xray_2d_geom.py b/docs/source/pyfigures/xray_2d_geom.py index c1714851f..e44a6a10b 100644 --- a/docs/source/pyfigures/xray_2d_geom.py +++ b/docs/source/pyfigures/xray_2d_geom.py @@ -1,24 +1,32 @@ import numpy as np +import matplotlib as mpl import matplotlib.patches as patches import matplotlib.pyplot as plt +mpl.rcParams["savefig.transparent"] = True + + c = 1.0 / np.sqrt(2.0) e = 1e-2 style = "Simple, tail_width=0.5, head_width=4, head_length=8" fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(21, 7)) + +# all plots for n in range(3): ax[n].set_aspect(1.0) ax[n].set_xlim(-1.1, 1.1) ax[n].set_ylim(-1.1, 1.1) ax[n].set_xticks(np.linspace(-1.0, 1.0, 5)) ax[n].set_yticks(np.linspace(-1.0, 1.0, 5)) - ax[n].tick_params(axis="x", labelsize=12) - ax[n].tick_params(axis="y", labelsize=12) - ax[n].set_xlabel("axis 1", fontsize=14) - ax[n].set_ylabel("axis 0", fontsize=14) + ax[n].tick_params(axis="x", labelsize=14) + ax[n].tick_params(axis="y", labelsize=14) + ax[n].set_xlabel("axis 1", fontsize=16) + ax[n].set_ylabel("axis 0", fontsize=16) + # scico +ax[0].set_title("scico", fontsize=18) plist = [ patches.FancyArrowPatch((-1.0, 0.0), (-0.5, 0.0), arrowstyle=style, color="r"), patches.FancyArrowPatch((-c, -c), (-c / 2.0, -c / 2.0), arrowstyle=style, color="r"), @@ -31,32 +39,90 @@ arrowstyle=style, color="r", ), - patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=180, theta2=-45.0, color="b", ls="dotted"), + patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=180, theta2=-45.0, color="b", lw=2, ls="dotted"), patches.FancyArrowPatch((c - e, -c - e), (c + e, -c + e), arrowstyle=style, color="b"), ] for p in plist: ax[0].add_patch(p) -ax[0].text(-0.88, 0.02, r"$\theta=0$", color="r", fontsize=14) -ax[0].text(-3 * c / 4 - 0.01, -3 * c / 4 - 0.1, r"$\theta=\frac{\pi}{4}$", color="r", fontsize=14) -ax[0].text(0.03, -0.8, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=14) -ax[0].set_title("scico", fontsize=14) + +ax[0].text(-0.88, 0.02, r"$\theta=0$", color="r", fontsize=16) +ax[0].text(-3 * c / 4 - 0.01, -3 * c / 4 - 0.1, r"$\theta=\frac{\pi}{4}$", color="r", fontsize=16) +ax[0].text(0.03, -0.8, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=16) + +ax[0].plot((1.0, 1.0), (-0.375, 0.375), color="orange", lw=2) +ax[0].arrow( + 0.94, + 0.375, + 0.0, + -0.75, + color="orange", + lw=1.0, + ls="--", + head_width=0.03, + length_includes_head=True, +) +ax[0].text(0.7, 0.0, r"$\theta=0$", color="orange", ha="left", fontsize=16) +ax[0].plot((-0.375, 0.375), (1.0, 1.0), color="orange", lw=2) +ax[0].arrow( + -0.375, + 0.94, + 0.75, + 0.0, + color="orange", + lw=1.0, + ls="--", + head_width=0.03, + length_includes_head=True, +) +ax[0].text(0.0, 0.82, r"$\theta=\frac{\pi}{2}$", color="orange", ha="center", fontsize=16) + # astra +ax[1].set_title("astra", fontsize=18) plist = [ patches.FancyArrowPatch((0.0, -1.0), (0.0, -0.5), arrowstyle=style, color="r"), patches.FancyArrowPatch((c, -c), (c / 2.0, -c / 2.0), arrowstyle=style, color="r"), patches.FancyArrowPatch((1.0, 0.0), (0.5, 0.0), arrowstyle=style, color="r"), - patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=-90, theta2=45.0, color="b", ls="dotted"), + patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=-90, theta2=45.0, color="b", lw=2, ls="dotted"), patches.FancyArrowPatch((c + e, c - e), (c - e, c + e), arrowstyle=style, color="b"), ] for p in plist: ax[1].add_patch(p) -ax[1].text(0.02, -0.75, r"$\theta=0$", color="r", fontsize=14) -ax[1].text(3 * c / 4 + 0.01, -3 * c / 4 + 0.01, r"$\theta=\frac{\pi}{4}$", color="r", fontsize=14) -ax[1].text(0.65, 0.05, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=14) -ax[1].set_title("astra", fontsize=14) + +ax[1].text(0.02, -0.75, r"$\theta=0$", color="r", fontsize=16) +ax[1].text(3 * c / 4 + 0.01, -3 * c / 4 + 0.01, r"$\theta=\frac{\pi}{4}$", color="r", fontsize=16) +ax[1].text(0.65, 0.05, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=16) + +ax[1].plot((-0.375, 0.375), (1.0, 1.0), color="orange", lw=2) +ax[1].arrow( + -0.375, + 0.94, + 0.75, + 0.0, + color="orange", + lw=1.0, + ls="--", + head_width=0.03, + length_includes_head=True, +) +ax[1].text(0.0, 0.82, r"$\theta=0$", color="orange", ha="center", fontsize=16) +ax[1].plot((-1.0, -1.0), (-0.375, 0.375), color="orange", lw=2) +ax[1].arrow( + -0.94, + -0.375, + 0.0, + 0.75, + color="orange", + lw=1.0, + ls="--", + head_width=0.03, + length_includes_head=True, +) +ax[1].text(-0.9, 0.0, r"$\theta=\frac{\pi}{2}$", color="orange", ha="left", fontsize=16) + # svmbir +ax[2].set_title("svmbir", fontsize=18) plist = [ patches.FancyArrowPatch((-1.0, 0.0), (-0.5, 0.0), arrowstyle=style, color="r"), patches.FancyArrowPatch((-c, c), (-c / 2.0, c / 2.0), arrowstyle=style, color="r"), @@ -69,15 +135,43 @@ arrowstyle=style, color="r", ), - patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=45, theta2=180, color="b", ls="dotted"), + patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=45, theta2=180, color="b", lw=2, ls="dotted"), patches.FancyArrowPatch((c - e, c + e), (c + e, c - e), arrowstyle=style, color="b"), ] for p in plist: ax[2].add_patch(p) -ax[2].text(-0.88, 0.02, r"$\theta=0$", color="r", fontsize=14) -ax[2].text(-3 * c / 4 + 0.01, 3 * c / 4 + 0.01, r"$\theta=\frac{\pi}{4}$", color="r", fontsize=14) -ax[2].text(0.03, 0.75, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=14) -ax[2].set_title("svmbir", fontsize=14) +ax[2].text(-0.88, 0.02, r"$\theta=0$", color="r", fontsize=16) +ax[2].text(-3 * c / 4 + 0.01, 3 * c / 4 + 0.01, r"$\theta=\frac{\pi}{4}$", color="r", fontsize=16) +ax[2].text(0.03, 0.75, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=16) + +ax[2].plot((1.0, 1.0), (-0.375, 0.375), color="orange", lw=2) +ax[2].arrow( + 0.94, + 0.375, + 0.0, + -0.75, + color="orange", + lw=1.0, + ls="--", + head_width=0.03, + length_includes_head=True, +) +ax[2].text(0.7, 0.0, r"$\theta=0$", color="orange", ha="left", fontsize=16) + +ax[2].plot((-0.375, 0.375), (-1.0, -1.0), color="orange", lw=2) +ax[2].arrow( + 0.375, + -0.94, + -0.75, + 0.0, + color="orange", + lw=1.0, + ls="--", + head_width=0.03, + length_includes_head=True, +) +ax[2].text(0.0, -0.82, r"$\theta=\frac{\pi}{2}$", color="orange", ha="center", fontsize=16) + fig.tight_layout() fig.show() diff --git a/docs/source/pyfigures/xray_3d_ang.py b/docs/source/pyfigures/xray_3d_ang.py index 5b0c677db..d2d6bf8e3 100644 --- a/docs/source/pyfigures/xray_3d_ang.py +++ b/docs/source/pyfigures/xray_3d_ang.py @@ -25,7 +25,7 @@ patches.FancyArrowPatch((0.0, -1.0), (0.0, -0.5), arrowstyle=style, color="r"), patches.FancyArrowPatch((c, -c), (c / 2.0, -c / 2.0), arrowstyle=style, color="r"), patches.FancyArrowPatch((1.0, 0.0), (0.5, 0.0), arrowstyle=style, color="r"), - patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=-90, theta2=45.0, color="b", ls="dotted"), + patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=-90, theta2=45.0, color="b", lw=2, ls="dotted"), patches.FancyArrowPatch((c + e, c - e), (c - e, c + e), arrowstyle=style, color="b"), ] for p in plist: diff --git a/scico/linop/xray/__init__.py b/scico/linop/xray/__init__.py index b909f5855..906fb2c50 100644 --- a/scico/linop/xray/__init__.py +++ b/scico/linop/xray/__init__.py @@ -26,14 +26,10 @@ :align: center :include-source: False :show-source-link: False - :caption: Comparison of 2D X-ray projector geometries. The red arrows - are directed towards the detector, which is oriented with pixel - indices ordered in the same direction as clockwise rotation (e.g. - in the "scico" geometry, the :math:`\theta=0` projection - corresponds to row sums ordered from the top to the bottom of the - figure, while the :math:`\theta=\pi` projection - corresponds to row sums ordered from the bottom to the top of the - figure). + :caption: Comparison of 2D X-ray projector geometries. The radial + arrows are directed towards the locations of the corresponding + detectors, with the direction of increasing pixel indices indicated + by the arrows on the dotted lines parallel to the detectors. | diff --git a/scico/linop/xray/astra.py b/scico/linop/xray/astra.py index 294b73244..b7eec9191 100644 --- a/scico/linop/xray/astra.py +++ b/scico/linop/xray/astra.py @@ -252,9 +252,10 @@ class XRayTransform3D(LinearOperator): # pragma: no cover :align: center :include-source: False :show-source-link: False - :caption: Red arrows indicate the direction of the beam towards - the detector (orange) and the arrows parallel to the detector - indicate the direction of increasing pixel indices. + :caption: Each radial arrow indicates the direction of the beam + towards the detector (indicated in orange in the "light" + display mode) and the arrow parallel to the detector indicates + the direction of increasing pixel indices. In this case the `z` axis is in the same direction as the vertical/row axis of the detector and its projection corresponds to From 36408f5f435123d895c73bbaac281a50f90371b4 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 12 Aug 2024 17:24:41 -0600 Subject: [PATCH 2/2] Bug fix (#548) --- scico/functional/_tvnorm.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index e44bb07cc..fdc2dadca 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -147,11 +147,15 @@ def _prox_operators( # Replicate-pad to the right (resulting in a zero after finite differencing) # on all axes subject to finite differencing. pad_width = [(0, 1) if i in axes else (0, 0) for i, s in enumerate(input_shape)] # type: ignore - P = Pad(input_shape, pad_width=pad_width, mode="edge", jit=True) + P = Pad( + input_shape, input_dtype=input_dtype, pad_width=pad_width, mode="edge", jit=True + ) # fused boundary extend and forward transform linop WP = W @ P # crop operation that is inverse of the padding operation - C = Crop(crop_width=pad_width, input_shape=w_input_shape, jit=True) + C = Crop( + crop_width=pad_width, input_shape=w_input_shape, input_dtype=input_dtype, jit=True + ) # fused adjoint transform and crop linop CWT = C @ W.T return WP, CWT, ndims, slce