Skip to content

Commit

Permalink
Address PR review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Nov 2, 2023
1 parent 2c99267 commit db7b668
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions examples/scripts/ct_abel_tv_admm_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
`ray.tune` class API is used in this example.
This script is hard-coded to run on CPU only to avoid the large number of
warnings that are emitted when GPU resources are requested but not available,
and due to the difficulty of supressing these warnings in a way that does
not force use of the CPU only. To enable GPU usage, comment out the
`os.environ` statements near the beginning of the script, and change the
value of the "gpu" entry in the `resources` dict from 0 to 1. Note that
two environment variables are set to suppress the warnings because
`JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but this change
has yet to be correctly implemented
warnings that are emitted when GPU resources are requested but not
available, and due to the difficulty of supressing these warnings in a
way that does not force use of the CPU only. To enable GPU usage, comment
out the `os.environ` statements near the beginning of the script, and
change the value of the "gpu" entry in the `resources` dict from 0 to 1.
Note that two environment variables are set to suppress the warnings
because `JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but
this change has yet to be correctly implemented
(see [google/jax#6805](https://github.com/google/jax/issues/6805) and
[google/jax#10272](https://github.com/google/jax/pull/10272).
"""
Expand Down Expand Up @@ -82,7 +82,7 @@ def setup(self, config, x_gt, x0, y):
to the evaluation function via the ray object store.
"""
# Get arrays passed by tune call.
self.x_gt, self.x0, self.y = x_gt, x0, y
self.x_gt, self.x0, self.y = snp.array(x_gt), snp.array(x0), snp.array(y)
# Set up problem to be solved.
self.A = AbelProjector(self.x_gt.shape)
self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
Expand Down

0 comments on commit db7b668

Please sign in to comment.