diff --git a/examples/scripts/ct_multi_cs_tv_admm.py b/examples/scripts/ct_multi_cs_tv_admm.py index f7fcfcc10..35e8f6abc 100644 --- a/examples/scripts/ct_multi_cs_tv_admm.py +++ b/examples/scripts/ct_multi_cs_tv_admm.py @@ -23,8 +23,6 @@ import numpy as np -import jax - from xdesign import Foam, discrete_phantom import scico.numpy as snp @@ -38,8 +36,7 @@ """ N = 512 # phantom size np.random.seed(1234) -x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) +x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)) """ diff --git a/examples/scripts/ct_multi_tv_admm.py b/examples/scripts/ct_multi_tv_admm.py index 8ed284c8d..6ab77cffe 100644 --- a/examples/scripts/ct_multi_tv_admm.py +++ b/examples/scripts/ct_multi_tv_admm.py @@ -22,8 +22,6 @@ import numpy as np -import jax - from xdesign import Foam, discrete_phantom import scico.numpy as snp @@ -37,8 +35,7 @@ """ N = 512 # phantom size np.random.seed(1234) -x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) +x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)) """ diff --git a/examples/scripts/ct_tv_admm.py b/examples/scripts/ct_tv_admm.py index 6aa3474a7..4b77eeb07 100644 --- a/examples/scripts/ct_tv_admm.py +++ b/examples/scripts/ct_tv_admm.py @@ -24,8 +24,6 @@ import numpy as np -import jax - from mpl_toolkits.axes_grid1 import make_axes_locatable from xdesign import Foam, discrete_phantom @@ -40,8 +38,7 @@ """ N = 512 # phantom size np.random.seed(1234) -x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU +x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)) """