Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove ensure_on_device #37

Merged
merged 3 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion notebooks/ct_abel_tv_admm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@
"outputs": [],
"source": [
"f = loss.SquaredL2Loss(y=y, A=A)\n",
"λ = 2.35e1 # L1 norm regularization parameter\n",
"λ = 2.35e1 # ℓ1 norm regularization parameter\n",
"g = λ * functional.L1Norm() # Note the use of anisotropic TV\n",
"C = linop.FiniteDifference(input_shape=x_gt.shape)"
]
Expand Down
21 changes: 10 additions & 11 deletions notebooks/ct_abel_tv_admm_tune.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
"`ray.tune` class API is used in this example.\n",
"\n",
"This script is hard-coded to run on CPU only to avoid the large number of\n",
"warnings that are emitted when GPU resources are requested but not available,\n",
"and due to the difficulty of supressing these warnings in a way that does\n",
"not force use of the CPU only. To enable GPU usage, comment out the\n",
"`os.environ` statements near the beginning of the script, and change the\n",
"value of the \"gpu\" entry in the `resources` dict from 0 to 1. Note that\n",
"two environment variables are set to suppress the warnings because\n",
"`JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but this change\n",
"has yet to be correctly implemented\n",
"warnings that are emitted when GPU resources are requested but not\n",
"available, and due to the difficulty of supressing these warnings in a\n",
"way that does not force use of the CPU only. To enable GPU usage, comment\n",
"out the `os.environ` statements near the beginning of the script, and\n",
"change the value of the \"gpu\" entry in the `resources` dict from 0 to 1.\n",
"Note that two environment variables are set to suppress the warnings\n",
"because `JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but\n",
"this change has yet to be correctly implemented\n",
"(see [google/jax#6805](https://github.com/google/jax/issues/6805) and\n",
"[google/jax#10272](https://github.com/google/jax/pull/10272)."
]
Expand Down Expand Up @@ -49,7 +49,6 @@
"\n",
"import numpy as np\n",
"\n",
"import jax\n",
"\n",
"import scico.numpy as snp\n",
"from scico import functional, linop, loss, metric, plot\n",
Expand Down Expand Up @@ -177,8 +176,8 @@
" this case). The remaining parameters are objects that are passed\n",
" to the evaluation function via the ray object store.\n",
" \"\"\"\n",
" # Put main arrays on jax device.\n",
" self.x_gt, self.x0, self.y = jax.device_put([x_gt, x0, y])\n",
" # Get arrays passed by tune call.\n",
" self.x_gt, self.x0, self.y = snp.array(x_gt), snp.array(x0), snp.array(y)\n",
" # Set up problem to be solved.\n",
" self.A = AbelProjector(self.x_gt.shape)\n",
" self.f = loss.SquaredL2Loss(y=self.y, A=self.A)\n",
Expand Down
8 changes: 3 additions & 5 deletions notebooks/ct_astra_3d_tv_admm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@
"source": [
"import numpy as np\n",
"\n",
"import jax\n",
"\n",
"from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
"\n",
"import scico.numpy as snp\n",
"from scico import functional, linop, loss, metric, plot\n",
"from scico.examples import create_tangle_phantom\n",
"from scico.linop.radon_astra import TomographicProjector\n",
Expand Down Expand Up @@ -76,8 +75,7 @@
"Ny = 256\n",
"Nz = 64\n",
"\n",
"tangle = create_tangle_phantom(Nx, Ny, Nz)\n",
"tangle = jax.device_put(tangle)\n",
"tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz))\n",
"\n",
"n_projection = 10 # number of projections\n",
"angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles\n",
Expand Down Expand Up @@ -110,7 +108,7 @@
},
"outputs": [],
"source": [
"λ = 2e0 # L1 norm regularization parameter\n",
"λ = 2e0 # ℓ2,1 norm regularization parameter\n",
"ρ = 5e0 # ADMM penalty parameter\n",
"maxiter = 25 # number of ADMM iterations\n",
"cg_tol = 1e-4 # CG relative tolerance\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/ct_astra_modl_train_foam2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@
"stats_object_ini = None\n",
"\n",
"checkpoint_files = []\n",
"for (dirpath, dirnames, filenames) in os.walk(workdir2):\n",
"for dirpath, dirnames, filenames in os.walk(workdir2):\n",
" checkpoint_files = [fn for fn in filenames if str.split(fn, \"_\")[0] == \"checkpoint\"]\n",
"\n",
"if len(checkpoint_files) > 0:\n",
Expand Down
3 changes: 1 addition & 2 deletions notebooks/ct_astra_noreg_pcg.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
"\n",
"import numpy as np\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"from xdesign import Foam, discrete_phantom\n",
Expand Down Expand Up @@ -75,7 +74,7 @@
"source": [
"N = 256 # phantom size\n",
"x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)\n",
"x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU"
"x_gt = jnp.array(x_gt) # convert to jax type"
]
},
{
Expand Down
6 changes: 2 additions & 4 deletions notebooks/ct_astra_tv_admm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@
"source": [
"import numpy as np\n",
"\n",
"import jax\n",
"\n",
"from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
"from xdesign import Foam, discrete_phantom\n",
"\n",
Expand Down Expand Up @@ -75,7 +73,7 @@
"N = 512 # phantom size\n",
"np.random.seed(1234)\n",
"x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)\n",
"x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU"
"x_gt = snp.array(x_gt) # convert to jax type"
]
},
{
Expand Down Expand Up @@ -130,7 +128,7 @@
},
"outputs": [],
"source": [
"λ = 2e0 # L1 norm regularization parameter\n",
"λ = 2e0 # ℓ1 norm regularization parameter\n",
"ρ = 5e0 # ADMM penalty parameter\n",
"maxiter = 25 # number of ADMM iterations\n",
"cg_tol = 1e-4 # CG relative tolerance\n",
Expand Down
8 changes: 3 additions & 5 deletions notebooks/ct_astra_weighted_tv_admm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
"source": [
"import numpy as np\n",
"\n",
"import jax\n",
"\n",
"from xdesign import Soil, discrete_phantom\n",
"\n",
"import scico.numpy as snp\n",
Expand Down Expand Up @@ -79,7 +77,7 @@
"x_gt = discrete_phantom(Soil(porosity=0.80), size=384)\n",
"x_gt = np.ascontiguousarray(np.pad(x_gt, (64, 64)))\n",
"x_gt = np.clip(x_gt, 0, np.inf) # clip to positive values\n",
"x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU"
"x_gt = snp.array(x_gt) # convert to jax type"
]
},
{
Expand Down Expand Up @@ -149,7 +147,7 @@
"counts = np.random.poisson(Io * snp.exp(-𝛼 * A @ x_gt))\n",
"counts = np.clip(counts, a_min=1, a_max=np.inf) # replace any 0s count with 1\n",
"y = -1 / 𝛼 * np.log(counts / Io)\n",
"y = jax.device_put(y) # convert back to float32"
"y = snp.array(y) # convert back to float32 as a jax array"
]
},
{
Expand Down Expand Up @@ -1409,7 +1407,7 @@
"source": [
"lambda_weighted = 5e1\n",
"\n",
"weights = jax.device_put(counts / Io)\n",
"weights = snp.array(counts / Io)\n",
"f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights))\n",
"\n",
"admm_weighted = ADMM(\n",
Expand Down
18 changes: 9 additions & 9 deletions notebooks/ct_fan_svmbir_ppp_bm3d_admm_prox.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@
"source": [
"import numpy as np\n",
"\n",
"import jax\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import svmbir\n",
"from matplotlib.ticker import MaxNLocator\n",
Expand Down Expand Up @@ -236,7 +234,7 @@
"id": "74b2d0bb",
"metadata": {},
"source": [
"Push arrays to device."
"Convert numpy arrays to jax arrays."
]
},
{
Expand All @@ -254,8 +252,10 @@
},
"outputs": [],
"source": [
"y_fan, x0_fan, weights_fan = jax.device_put([y_fan, x_mrf_fan, weights_fan])\n",
"x0_parallel = jax.device_put(x_mrf_parallel)"
"y_fan = snp.array(y_fan)\n",
"x0_fan = snp.array(x_mrf_fan)\n",
"weights_fan = snp.array(weights_fan)\n",
"x0_parallel = snp.array(x_mrf_parallel)"
]
},
{
Expand Down Expand Up @@ -334,7 +334,7 @@
" x0=x0_fan,\n",
" maxiter=20,\n",
" subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 100}),\n",
" itstat_options={\"display\": True},\n",
" itstat_options={\"display\": True, \"period\": 5},\n",
")\n",
"solver_extloss_parallel = ADMM(\n",
" f=None,\n",
Expand All @@ -344,7 +344,7 @@
" x0=x0_parallel,\n",
" maxiter=20,\n",
" subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 100}),\n",
" itstat_options={\"display\": True},\n",
" itstat_options={\"display\": True, \"period\": 5},\n",
")"
]
},
Expand Down Expand Up @@ -814,7 +814,7 @@
" fig=fig,\n",
" ax=ax[0],\n",
")\n",
"ax[0].set_ylim([5e-3, 1e0])\n",
"ax[0].set_ylim([5e-3, 5e0])\n",
"ax[0].xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"plot.plot(\n",
" snp.vstack((hist_extloss_fan.Prml_Rsdl, hist_extloss_fan.Dual_Rsdl)).T,\n",
Expand All @@ -825,7 +825,7 @@
" fig=fig,\n",
" ax=ax[1],\n",
")\n",
"ax[1].set_ylim([5e-3, 1e0])\n",
"ax[1].set_ylim([5e-3, 5e0])\n",
"ax[1].xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"fig.show()"
]
Expand Down
26 changes: 10 additions & 16 deletions notebooks/ct_projector_comparison.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
"X-ray Projector Comparison\n",
"==========================\n",
"\n",
"This example compares SCICO's native X-ray projection algorithm\n",
"to that of the ASTRA Toolbox."
"This example compares SCICO's native X-ray projection algorithm to that\n",
"of the ASTRA Toolbox."
]
},
{
Expand Down Expand Up @@ -65,12 +65,9 @@
"outputs": [],
"source": [
"N = 512\n",
"\n",
"\n",
"det_count = int(jnp.ceil(jnp.sqrt(2 * N**2)))\n",
"\n",
"x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)\n",
"x_gt = jax.device_put(x_gt)"
"x_gt = jnp.array(x_gt)"
]
},
{
Expand Down Expand Up @@ -99,7 +96,6 @@
"num_angles = 500\n",
"angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)\n",
"\n",
"\n",
"timer = Timer()\n",
"\n",
"projectors = {}\n",
Expand Down Expand Up @@ -192,7 +188,7 @@
"10% slower when both are run the CPU.\n",
"\n",
"On our server, using the GPU:\n",
"``` \n",
"```\n",
"Label Accum. Current\n",
"-------------------------------------------\n",
"astra_avg_proj 4.62e-02 s Stopped\n",
Expand Down Expand Up @@ -317,7 +313,7 @@
"\n",
"y = np.zeros(H.output_shape, dtype=np.float32)\n",
"y[num_angles // 3, det_count // 2] = 1.0\n",
"y = jax.device_put(y)\n",
"y = jnp.array(y)\n",
"\n",
"HTys = {}\n",
"for name, H in projectors.items():\n",
Expand Down Expand Up @@ -369,11 +365,10 @@
"source": [
"Display back projection timing results.\n",
"\n",
"On our server, the SCICO back projection is slow\n",
"the first time it is run, probably due to JIT overhead.\n",
"After the first run, it is an order of magnitude\n",
"faster than ASTRA when both are run on the GPU,\n",
"and about three times faster when both are run on the CPU.\n",
"On our server, the SCICO back projection is slow the first time it is\n",
"run, probably due to JIT overhead. After the first run, it is an order of\n",
"magnitude faster than ASTRA when both are run on the GPU, and about three\n",
"times faster when both are run on the CPU.\n",
"\n",
"On our server, using the GPU:\n",
"```\n",
Expand Down Expand Up @@ -433,8 +428,7 @@
"id": "88a22207",
"metadata": {},
"source": [
"Show back projections of a single detector element,\n",
"i.e., a line."
"Show back projections of a single detector element, i.e., a line."
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions notebooks/ct_svmbir_ppp_bm3d_admm_cg.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
"source": [
"import numpy as np\n",
"\n",
"import jax\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import svmbir\n",
"from xdesign import Foam, discrete_phantom\n",
Expand Down Expand Up @@ -205,7 +203,9 @@
},
"outputs": [],
"source": [
"y, x0, weights = jax.device_put([y, x_mrf, weights])\n",
"y = snp.array(y)\n",
"x0 = snp.array(x_mrf)\n",
"weights = snp.array(weights)\n",
"\n",
"ρ = 15 # ADMM penalty parameter\n",
"σ = density * 0.18 # denoiser sigma\n",
Expand Down
16 changes: 8 additions & 8 deletions notebooks/ct_svmbir_ppp_bm3d_admm_prox.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@
"source": [
"import numpy as np\n",
"\n",
"import jax\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import svmbir\n",
"from matplotlib.ticker import MaxNLocator\n",
Expand Down Expand Up @@ -201,7 +199,7 @@
"id": "5a19be40",
"metadata": {},
"source": [
"Push arrays to device."
"Convert numpy arrays to jax arrays."
]
},
{
Expand All @@ -219,7 +217,9 @@
},
"outputs": [],
"source": [
"y, x0, weights = jax.device_put([y, x_mrf, weights])"
"y = snp.array(y)\n",
"x0 = snp.array(x_mrf)\n",
"weights = snp.array(weights)"
]
},
{
Expand Down Expand Up @@ -286,7 +286,7 @@
" x0=x0,\n",
" maxiter=20,\n",
" subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 100}),\n",
" itstat_options={\"display\": True},\n",
" itstat_options={\"display\": True, \"period\": 5},\n",
")"
]
},
Expand Down Expand Up @@ -509,7 +509,7 @@
" x0=x0,\n",
" maxiter=20,\n",
" subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 100}),\n",
" itstat_options={\"display\": True},\n",
" itstat_options={\"display\": True, \"period\": 5},\n",
")"
]
},
Expand Down Expand Up @@ -803,7 +803,7 @@
" fig=fig,\n",
" ax=ax[0],\n",
")\n",
"ax[0].set_ylim([5e-3, 1e0])\n",
"ax[0].set_ylim([5e-3, 5e0])\n",
"ax[0].xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"plot.plot(\n",
" snp.vstack((hist_extloss.Prml_Rsdl, hist_extloss.Dual_Rsdl)).T,\n",
Expand All @@ -814,7 +814,7 @@
" fig=fig,\n",
" ax=ax[1],\n",
")\n",
"ax[1].set_ylim([5e-3, 1e0])\n",
"ax[1].set_ylim([5e-3, 5e0])\n",
"ax[1].xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"fig.show()"
]
Expand Down
Loading