Add AutoGuideList #1644
Use Pyro implementation as an example to
- implement numpyro.infer.autoguide.AutoGuideList
- modify numpyro.handlers.block
numpyro/handlers.py
site-level metadata returns whether it should be exposed.
:param bool hide_all: whether to hide all sites.
:param bool expose_all: whether to expose all sites.
:param list expose: list of site names to hide.
Could you introduce only expose, for simplicity? I'm not sure why we need the other arguments.
Typo: "a list of site names to expose."
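The hide/expose semantics discussed here can be sketched in plain Python. This is only an illustration, not NumPyro's implementation: make_hide_predicate and the site dictionaries are hypothetical, and the sketch assumes that hide and expose are mutually exclusive and that every site is hidden when neither is given.

```python
def make_hide_predicate(hide=None, expose=None):
    """Return a function that maps a site dict to True when the site is hidden.

    Hypothetical helper for illustration only.
    """
    if hide is not None and expose is not None:
        raise ValueError("Specify at most one of `hide` and `expose`.")
    if hide is not None:
        hidden = set(hide)
        return lambda site: site["name"] in hidden
    if expose is not None:
        exposed = set(expose)
        return lambda site: site["name"] not in exposed
    # Assumption: with no arguments, every site is hidden.
    return lambda site: True


sites = [{"name": "a"}, {"name": "b"}, {"name": "y"}]

hide_a = make_hide_predicate(hide=["a"])
expose_a = make_hide_predicate(expose=["a"])

# Names that remain visible under each predicate.
visible_when_hiding_a = [s["name"] for s in sites if not hide_a(s)]
visible_when_exposing_a = [s["name"] for s in sites if not expose_a(s)]
```

Under these assumptions, expose=["a"] leaves only site a visible, which is why the docstring should describe expose as the list of names to expose rather than to hide.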
Here is a summary of the recent changes:
from jax import random
import jax.numpy as jnp

import numpyro
from numpyro import handlers, optim
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoGuideList, AutoNormal, AutoDAIS

sigma = 0.05

def model(x, y=None):
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(jnp.ones((2, 1)), 1))
    mu = a + x @ b
    with numpyro.plate("N", len(x), dim=-2):
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

N = 500
a = 1
b = jnp.asarray([[-0.5], [-1]])
x = random.normal(random.PRNGKey(0), (N, 2))
y = a + x @ b + sigma * random.normal(random.PRNGKey(1), (N, 1))

guide = AutoGuideList(model)
guide.append(
    AutoNormal(
        numpyro.handlers.block(
            numpyro.handlers.seed(model, rng_seed=0), expose=["a"]
        )
    )
)
guide.append(
    AutoDAIS(
        numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), hide=["a"])
    )
)

optimiser = numpyro.optim.Adam(step_size=0.1)
svi = SVI(model, guide, optimiser, Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), num_steps=1_000, x=x, y=y)
---------------------------------------------------------------------------
UnexpectedTracerError Traceback (most recent call last)
Cell In[7], line 42
39 optimiser = numpyro.optim.Adam(step_size=0.1)
40 svi = SVI(model, guide, optimiser, Trace_ELBO())
---> 42 svi_result = svi.run(random.PRNGKey(0), num_steps=1_000, x=x, y=y)
File ~/numpyro/numpyro/infer/svi.py:365, in SVI.run(self, rng_key, num_steps, progress_bar, stable_update, init_state, init_params, *args, **kwargs)
363 batch = max(num_steps // 20, 1)
364 for i in t:
--> 365 svi_state, loss = jit(body_fn)(svi_state, None)
366 losses.append(loss)
367 if i % batch == 0:
[... skipping hidden 12 frame]
File ~/numpyro/numpyro/infer/svi.py:353, in SVI.run.<locals>.body_fn(svi_state, _)
351 svi_state, loss = self.stable_update(svi_state, *args, **kwargs)
352 else:
--> 353 svi_state, loss = self.update(svi_state, *args, **kwargs)
354 return svi_state, loss
File ~/numpyro/numpyro/infer/svi.py:266, in SVI.update(self, svi_state, *args, **kwargs)
254 rng_key, rng_key_step = random.split(svi_state.rng_key)
255 loss_fn = _make_loss_fn(
256 self.loss,
257 rng_key_step,
(...)
264 mutable_state=svi_state.mutable_state,
265 )
--> 266 (loss_val, mutable_state), optim_state = self.optim.eval_and_update(
267 loss_fn, svi_state.optim_state
268 )
269 return SVIState(optim_state, mutable_state, rng_key), loss_val
File ~/numpyro/numpyro/optim.py:80, in _NumPyroOptim.eval_and_update(self, fn, state)
65 """
66 Performs an optimization step for the objective function `fn`.
67 For most optimizers, the update is performed based on the gradient
(...)
77 :return: a pair of the output of objective function and the new optimizer state.
78 """
79 params = self.get_params(state)
---> 80 (out, aux), grads = value_and_grad(fn, has_aux=True)(params)
81 return (out, aux), self.update(grads, state)
[... skipping hidden 8 frame]
File ~/numpyro/numpyro/infer/svi.py:61, in _make_loss_fn.<locals>.loss_fn(params)
58 return result["loss"], result["mutable_state"]
59 else:
60 return (
---> 61 elbo.loss(
62 rng_key, params, model, guide, *args, **kwargs, **static_kwargs
63 ),
64 None,
65 )
File ~/numpyro/numpyro/infer/elbo.py:68, in ELBO.loss(self, rng_key, param_map, model, guide, *args, **kwargs)
53 def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
54 """
55 Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
56
(...)
66 :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
67 """
---> 68 return self.loss_with_mutable_state(
69 rng_key, param_map, model, guide, *args, **kwargs
70 )["loss"]
File ~/numpyro/numpyro/infer/elbo.py:167, in Trace_ELBO.loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs)
164 # Return (-elbo) since by convention we do gradient descent on a loss and
165 # the ELBO is a lower bound that needs to be maximized.
166 if self.num_particles == 1:
--> 167 elbo, mutable_state = single_particle_elbo(rng_key)
168 return {"loss": -elbo, "mutable_state": mutable_state}
169 else:
File ~/numpyro/numpyro/infer/elbo.py:129, in Trace_ELBO.loss_with_mutable_state.<locals>.single_particle_elbo(rng_key)
127 seeded_model = seed(model, model_seed)
128 seeded_guide = seed(guide, guide_seed)
--> 129 guide_log_density, guide_trace = log_density(
130 seeded_guide, args, kwargs, param_map
131 )
132 mutable_params = {
133 name: site["value"]
134 for name, site in guide_trace.items()
135 if site["type"] == "mutable"
136 }
137 params.update(mutable_params)
File ~/numpyro/numpyro/infer/util.py:62, in log_density(model, model_args, model_kwargs, params)
50 """
51 (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
52 latent values ``params``.
(...)
59 :return: log of joint density and a corresponding model trace
60 """
61 model = substitute(model, data=params)
---> 62 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
63 log_joint = jnp.zeros(())
64 for site in model_trace.values():
File ~/numpyro/numpyro/handlers.py:172, in trace.get_trace(self, *args, **kwargs)
164 def get_trace(self, *args, **kwargs):
165 """
166 Run the wrapped callable and return the recorded trace.
167
(...)
170 :return: `OrderedDict` containing the execution trace.
171 """
--> 172 self(*args, **kwargs)
173 return self.trace
File ~/numpyro/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/numpyro/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/numpyro/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/numpyro/numpyro/infer/autoguide.py:271, in AutoGuideList.__call__(self, *args, **kwargs)
269 result = {}
270 for part in self._guides:
--> 271 result.update(part(*args, **kwargs))
272 return result
File ~/numpyro/numpyro/infer/autoguide.py:669, in AutoContinuous.__call__(self, *args, **kwargs)
665 if self.prototype_trace is None:
666 # run model to inspect the model structure
667 self._setup_prototype(*args, **kwargs)
--> 669 latent = self._sample_latent(*args, **kwargs)
671 # unpack continuous latent samples
672 result = {}
File ~/numpyro/numpyro/infer/autoguide.py:953, in AutoDAIS._sample_latent(self, *args, **kwargs)
950 return (z, v, log_factor), None
952 v_0 = eps[-1] # note the return value of scan doesn't depend on eps[-1]
--> 953 (z, _, log_factor), _ = jax.lax.scan(scan_body, (z_0, v_0, 0.0), (eps, betas))
955 numpyro.factor("{}_factor".format(self.prefix), log_factor)
957 return z
[... skipping hidden 9 frame]
File ~/numpyro/numpyro/infer/autoguide.py:944, in AutoDAIS._sample_latent.<locals>.scan_body(carry, eps_beta)
942 z_half = z_prev + v_prev * eta * inv_mass_matrix
943 q_grad = (1.0 - beta) * grad(base_z_dist.log_prob)(z_half)
--> 944 p_grad = beta * grad(log_density)(z_half)
945 v_hat = v_prev + eta * (q_grad + p_grad)
946 z = z_half + v_hat * eta * inv_mass_matrix
[... skipping hidden 10 frame]
File ~/numpyro/numpyro/infer/autoguide.py:876, in AutoDAIS._sample_latent.<locals>.log_density(x)
874 x_unpack = self._unpack_latent(x)
875 with numpyro.handlers.block():
--> 876 return -self._potential_fn(x_unpack)
File ~/numpyro/numpyro/infer/util.py:291, in potential_energy(model, model_args, model_kwargs, params, enum)
287 substituted_model = substitute(
288 model, substitute_fn=partial(_unconstrain_reparam, params)
289 )
290 # no param is needed for log_density computation because we already substitute
--> 291 log_joint, model_trace = log_density_(
292 substituted_model, model_args, model_kwargs, {}
293 )
294 return -log_joint
File ~/numpyro/numpyro/infer/util.py:62, in log_density(model, model_args, model_kwargs, params)
50 """
51 (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
52 latent values ``params``.
(...)
59 :return: log of joint density and a corresponding model trace
60 """
61 model = substitute(model, data=params)
---> 62 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
63 log_joint = jnp.zeros(())
64 for site in model_trace.values():
File ~/numpyro/numpyro/handlers.py:172, in trace.get_trace(self, *args, **kwargs)
164 def get_trace(self, *args, **kwargs):
165 """
166 Run the wrapped callable and return the recorded trace.
167
(...)
170 :return: `OrderedDict` containing the execution trace.
171 """
--> 172 self(*args, **kwargs)
173 return self.trace
File ~/numpyro/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/numpyro/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
[... skipping similar frames: Messenger.__call__ at line 105 (3 times)]
File ~/numpyro/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
Cell In[7], line 13, in model(x, y)
12 def model(x, y=None):
---> 13 a = numpyro.sample("a", dist.Normal(0, 1))
14 b = numpyro.sample("b", dist.Normal(jnp.ones((2, 1)), 1))
15 mu = a + x @ b
File ~/numpyro/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
207 initial_msg = {
208 "type": "sample",
209 "name": name,
(...)
218 "infer": {} if infer is None else infer,
219 }
221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
223 return msg["value"]
File ~/numpyro/numpyro/primitives.py:47, in apply_stack(msg)
45 pointer = 0
46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47 handler.process_message(msg)
48 # When a Messenger sets the "stop" field of a message,
49 # it prevents any Messengers above it on the stack from being applied.
50 if msg.get("stop"):
File ~/numpyro/numpyro/handlers.py:796, in seed.process_message(self, msg)
793 if (msg["kwargs"]["rng_key"] is not None) or (msg["value"] is not None):
794 # no need to create a new key when value is available
795 return
--> 796 self.rng_key, rng_key_sample = random.split(self.rng_key)
797 msg["kwargs"]["rng_key"] = rng_key_sample
File ~/numpyro/venv/lib/python3.9/site-packages/jax/_src/random.py:240, in split(key, num)
229 def split(key: KeyArray, num: int = 2) -> KeyArray:
230 """Splits a PRNG key into `num` new keys by adding a leading axis.
231
232 Args:
(...)
238 An array-like object of `num` new PRNG keys.
239 """
--> 240 key, wrapped = _check_prng_key(key)
241 return _return_prng_keys(wrapped, _split(key, num))
File ~/numpyro/venv/lib/python3.9/site-packages/jax/_src/random.py:80, in _check_prng_key(key)
75 if config.jax_enable_custom_prng:
76 warnings.warn(
77 'Raw arrays as random keys to jax.random functions are deprecated. '
78 'Assuming valid threefry2x32 key for now.',
79 FutureWarning)
---> 80 return prng.random_wrap(key, impl=default_prng_impl()), True
81 else:
82 raise TypeError(f'unexpected PRNG key type {type(key)}')
File ~/numpyro/venv/lib/python3.9/site-packages/jax/_src/prng.py:864, in random_wrap(base_arr, impl)
862 def random_wrap(base_arr, *, impl):
863 _check_prng_key_data(impl, base_arr)
--> 864 return random_wrap_p.bind(base_arr, impl=impl)
[... skipping hidden 2 frame]
File ~/numpyro/venv/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py:1579, in DynamicJaxprTracer._assert_live(self)
1577 def _assert_live(self) -> None:
1578 if not self._trace.main.jaxpr_stack: # type: ignore
-> 1579 raise core.escaped_tracer_error(self, None)
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was scan_body at /Users/tare/numpyro/numpyro/infer/autoguide.py:937 traced for scan.
------------------------------
The leaked intermediate value was created on line /Users/tare/numpyro/numpyro/handlers.py:796 (process_message).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/Users/tare/numpyro/numpyro/primitives.py:105 (__call__)
/var/folders/1s/d7zzs88n6n709f0w22bql1100000gn/T/ipykernel_23159/1142004465.py:14 (model)
/Users/tare/numpyro/numpyro/primitives.py:222 (sample)
/Users/tare/numpyro/numpyro/primitives.py:47 (apply_stack)
/Users/tare/numpyro/numpyro/handlers.py:796 (process_message)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
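The final frames place the leak in seed.process_message, which rebinds self.rng_key while scan_body is being traced. The same pattern can be reproduced without NumPyro. In the sketch below, StatefulKey is a hypothetical stand-in for the seed handler, not NumPyro code, and the only assumption is that jax is installed.

```python
import jax
import jax.numpy as jnp


class StatefulKey:
    """Hypothetical stand-in for a handler that keeps its PRNG key on `self`."""

    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)

    def next_key(self):
        # Rebinding self.key is a side effect; under tracing it stores a tracer.
        self.key, subkey = jax.random.split(self.key)
        return subkey


holder = StatefulKey(0)


def scan_body(carry, _):
    # The key is drawn from state that lives outside the traced function.
    noise = jax.random.normal(holder.next_key())
    return carry + noise, None


leak_detected = False
try:
    jax.lax.scan(scan_body, jnp.zeros(()), None, length=3)
    # If scan itself did not complain, reusing the stored key afterwards should.
    holder.next_key()
except jax.errors.UnexpectedTracerError:
    leak_detected = True
```

Passing the key through the scan carry instead of storing it on an object would avoid the side effect in this toy case; whether an equivalent change is practical inside AutoDAIS is a separate question.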
Re AutoDAIS: I'm not sure if AutoGuideList can work with DAIS because the guide requires model evaluation. I'm also not sure if the algorithm can be used in AutoGuideList. @martinjankowiak do you have some ideas? Hopefully the tracer issue can be resolved after #1657. I can take a stab at it later this week. But I think it does not block your PR.
I think it'd be quite difficult/awkward since
Looks great to me overall! Thank you. I just have a small comment.
            )
        else:
            result.update(
                part.sample_posterior(rng_key, params, sample_shape=sample_shape)
Could you change the signature of the other algorithms to match AutoDelta's? That signature is more general, I think.
To reduce the chance of misunderstanding, could you please clarify what you mean?
The reason for the if-else clause is that AutoDelta and AutoSemiDAIS differ from the other autoguides in the signature of sample_posterior():
- AutoDelta and AutoSemiDAIS have sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs)
- other autoguides have sample_posterior(self, rng_key, params, sample_shape=())
I suppose we could break things if we change the signature of sample_posterior(), since sample_shape might have been passed as a positional argument.
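The compatibility concern can be illustrated with two plain functions that mimic the signatures listed above. The function names and placeholder bodies are hypothetical; only the parameter lists follow the discussion.

```python
def sample_posterior_simple(rng_key, params, sample_shape=()):
    """Mimics the signature used by most autoguides (placeholder body)."""
    return {"extra_args": (), "sample_shape": sample_shape}


def sample_posterior_general(rng_key, params, *args, sample_shape=(), **kwargs):
    """Mimics the AutoDelta / AutoSemiDAIS signature (placeholder body)."""
    return {"extra_args": args, "sample_shape": sample_shape}


# A caller that passes sample_shape positionally works with the simple form ...
old_style = sample_posterior_simple("key", {}, (10,))
# ... but the same call feeds the value into *args under the general form,
# leaving sample_shape at its default without any error.
new_style = sample_posterior_general("key", {}, (10,))
```

The silent reinterpretation, rather than a loud failure, is what makes a direct signature change risky for existing callers.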
> since sample_shape might have been passed as a positional argument
Good point! Maybe we can keep your current implementation here as-is and add *, before sample_shape in the other classes. After the next release, we'll make the signatures consistent. Users of the next release will be informed of the change because they will no longer be able to pass sample_shape as a positional argument.
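The effect of the proposed bare * can be sketched as follows. The function is a hypothetical placeholder, not the NumPyro method; it only shows how Python treats a keyword-only parameter.

```python
def sample_posterior(rng_key, params, *, sample_shape=()):
    """Placeholder body; only the signature matters here."""
    return sample_shape


# Keyword use keeps working.
keyword_result = sample_posterior("key", {}, sample_shape=(10,))

# Positional use is rejected by Python itself with a TypeError.
try:
    sample_posterior("key", {}, (10,))
    positional_rejected = False
except TypeError:
    positional_rejected = True
```

Because the failure is immediate, callers relying on positional sample_shape would notice before a later move to the more general signature.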
Sounds reasonable to me! I'm happy to create another PR to address this if you want.
Yes, thank you!
numpyro/infer/autoguide.py
result = {}
for part in self._guides:
    if isinstance(part, numpyro.infer.autoguide.AutoDelta) or isinstance(
        part, numpyro.infer.autoguide.AutoSemiDAIS
Probably you can raise ValueError in the constructor if the guide is either AutoDAIS or AutoSemiDAIS. I'm not sure if the algorithm gives reasonable results under AutoGuideList.
also AutoSurrogateLikelihoodDAIS
- append() raises ValueError when guide is AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS
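A guard of this kind might look roughly like the sketch below. All class names are hypothetical stand-ins rather than the NumPyro classes, and the sketch assumes the unsupported types are listed explicitly in a tuple.

```python
class FakeAutoNormal:
    """Hypothetical stand-in for a supported guide."""


class FakeAutoDAIS:
    """Hypothetical stand-in for a DAIS-style guide."""


class GuideListSketch:
    # Assumption: unsupported guide types are enumerated explicitly.
    _unsupported = (FakeAutoDAIS,)

    def __init__(self):
        self._guides = []

    def append(self, part):
        # Reject unsupported guides up front instead of failing during inference.
        if isinstance(part, self._unsupported):
            raise ValueError(
                f"{type(part).__name__} is not supported in a guide list."
            )
        self._guides.append(part)


guides = GuideListSketch()
guides.append(FakeAutoNormal())

try:
    guides.append(FakeAutoDAIS())
    rejected = False
except ValueError:
    rejected = True
```

Checking in append() rather than at call time surfaces the problem when the guide is assembled, before any optimization step runs.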
@fehiepsi Sorry for the additional commit; I realized I hadn't added
Thank you, @tare!
Hello! This PR eventually aims to address #1638.
Using the Pyro implementation (please see guides.py and block_messenger.py) as an example, I implemented numpyro.infer.autoguide.AutoGuideList and modified numpyro.handlers.block with the following main differences:
- add alias from numpyro.infer.autoguide.AutoGuideList
- block handler differs from Pyro as I thought changing the positions of the current parameters could break things
Here is a simple example using AutoGuideList
I think I encountered the same problem as mentioned here. To overcome (I think?) the problem, I wrapped the NumPyro model in a seed handler before wrapping it in a block handler. I don't know whether there is a better way.
Should we implement AutoCallable from Pyro?
No tests have been implemented yet.