Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AutoGuideList #1644

Merged
merged 9 commits into from
Oct 3, 2023
Merged

Add AutoGuideList #1644

merged 9 commits into from
Oct 3, 2023

Conversation

tare
Copy link
Contributor

@tare tare commented Sep 22, 2023

Hello! This PR eventually aims to address #1638.

Using the Pyro implementation (please see guides.py and block_messenger.py) as an example I implemented numpyro.infer.autoguide.AutoGuideList and modified numpyro.handlers.block with the following main differences:

  • discarded the deprecated add alias from numpyro.infer.autoguide.AutoGuideList
  • the order of parameters in the constructor of the block handler differs from Pyro as I thought changing the positions of the current parameters could break things

Here is a simple example usingAutoGuideList

from jax import random
import jax.numpy as jnp

import numpyro
from numpyro import handlers, optim
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO, Predictive
from numpyro.infer.autoguide import AutoNormal, AutoDelta, AutoGuideList

N, dim = 3000, 3
data = random.normal(random.PRNGKey(0), (N, dim))
true_coefs = jnp.arange(1.0, dim + 1.0)
logits = jnp.sum(true_coefs * data, axis=-1)
labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))


def model(data, labels):
    coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
    offset = numpyro.sample("offset", dist.Uniform(-1, 1))
    logits = offset + jnp.sum(coefs * data, axis=-1)
    return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)


adam = optim.Adam(0.01)
rng_key_init = random.PRNGKey(0)

guide = AutoGuideList(model)
guide.append(
    AutoNormal(
        numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=0), hide=["coefs"])
    )
)
guide.append(
    AutoDelta(
        numpyro.handlers.block(
            numpyro.handlers.seed(model, rng_seed=1), expose=["coefs"]
        )
    )
)

svi = SVI(model, guide, adam, Trace_ELBO())
svi_result = svi.run(rng_key_init, 2000, data, labels)

params = svi_result.params

predictive = Predictive(guide, params=params, num_samples=1000)
posterior_samples = predictive(random.PRNGKey(2), data=data)

I think I encountered the same problem as mentioned here. To overcome (I think?) the problem, I wrapped the NumPyro model in a seed handler before wrapping it in a block handler. I don't know whether there is a better way.

Should we implement AutoCallable from Pyro?

No tests have been implemented yet.

Use Pyro implementation as an example to
- implement numpyro.infer.autoguide.AutoGuideList
- modify numpyro.handlers.block
@tare tare marked this pull request as draft September 22, 2023 13:40
@tare tare changed the title Add initial AutoGuideList implementation Add AutoGuideList Sep 22, 2023
site-level metadata returns whether it should be exposed.
:param bool hide_all: whether to hide all sites.
:param bool expose_all: whether to expose all sites.
:param list expose: list of site names to hide.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you only introduce expose for simplicity? I'm not sure why we need other arguments.

typo: a list of site names to expose.

@tare
Copy link
Contributor Author

tare commented Sep 29, 2023

Here is a summary of the recent changes:

  • Simplified the signature and logic of the block handler
  • Changed the implementation of sample_posterior because AutoDelta differs from the other autoguides when it comes to the signature of sample_posterior
  • Added __getitem__, __len__, and __iter__ magic methods
  • Added various tests related to block
  • Added various tests related to AutoGuideList

AutoDAIS does not seem to behave well in AutoGuideList when block and seed are used

from jax import random
import jax.numpy as jnp

import numpyro
from numpyro import handlers, optim
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoGuideList, AutoNormal, AutoDAIS

sigma = 0.05

def model(x, y=None):
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(jnp.ones((2, 1)), 1))
    mu = a + x @ b
    with numpyro.plate("N", len(x), dim=-2):
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

N = 500
a = 1
b = jnp.asarray([[-0.5], [-1]])
x = random.normal(random.PRNGKey(0), (N, 2))
y = a + x @ b + sigma * random.normal(random.PRNGKey(1), (N,1))

guide = AutoGuideList(model)
guide.append(
    AutoNormal(
        numpyro.handlers.block(
            numpyro.handlers.seed(model, rng_seed=0), expose=["a"]
        )
    )
)
guide.append(
    AutoDAIS(
        numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), hide=["a"])
    )
)

optimiser = numpyro.optim.Adam(step_size=0.1)
svi = SVI(model, guide, optimiser, Trace_ELBO())

svi_result = svi.run(random.PRNGKey(0), num_steps=1_000, x=x, y=y)
---------------------------------------------------------------------------
UnexpectedTracerError                     Traceback (most recent call last)
Cell In[7], line 42
     39 optimiser = numpyro.optim.Adam(step_size=0.1)
     40 svi = SVI(model, guide, optimiser, Trace_ELBO())
---> 42 svi_result = svi.run(random.PRNGKey(0), num_steps=1_000, x=x, y=y)

File ~/numpyro/numpyro/infer/svi.py:365, in SVI.run(self, rng_key, num_steps, progress_bar, stable_update, init_state, init_params, *args, **kwargs)
    363 batch = max(num_steps // 20, 1)
    364 for i in t:
--> 365     svi_state, loss = jit(body_fn)(svi_state, None)
    366     losses.append(loss)
    367     if i % batch == 0:

    [... skipping hidden 12 frame]

File ~/numpyro/numpyro/infer/svi.py:353, in SVI.run.<locals>.body_fn(svi_state, _)
    351     svi_state, loss = self.stable_update(svi_state, *args, **kwargs)
    352 else:
--> 353     svi_state, loss = self.update(svi_state, *args, **kwargs)
    354 return svi_state, loss

File ~/numpyro/numpyro/infer/svi.py:266, in SVI.update(self, svi_state, *args, **kwargs)
    254 rng_key, rng_key_step = random.split(svi_state.rng_key)
    255 loss_fn = _make_loss_fn(
    256     self.loss,
    257     rng_key_step,
   (...)
    264     mutable_state=svi_state.mutable_state,
    265 )
--> 266 (loss_val, mutable_state), optim_state = self.optim.eval_and_update(
    267     loss_fn, svi_state.optim_state
    268 )
    269 return SVIState(optim_state, mutable_state, rng_key), loss_val

File ~/numpyro/numpyro/optim.py:80, in _NumPyroOptim.eval_and_update(self, fn, state)
     65 """
     66 Performs an optimization step for the objective function `fn`.
     67 For most optimizers, the update is performed based on the gradient
   (...)
     77 :return: a pair of the output of objective function and the new optimizer state.
     78 """
     79 params = self.get_params(state)
---> 80 (out, aux), grads = value_and_grad(fn, has_aux=True)(params)
     81 return (out, aux), self.update(grads, state)

    [... skipping hidden 8 frame]

File ~/numpyro/numpyro/infer/svi.py:61, in _make_loss_fn.<locals>.loss_fn(params)
     58     return result["loss"], result["mutable_state"]
     59 else:
     60     return (
---> 61         elbo.loss(
     62             rng_key, params, model, guide, *args, **kwargs, **static_kwargs
     63         ),
     64         None,
     65     )

File ~/numpyro/numpyro/infer/elbo.py:68, in ELBO.loss(self, rng_key, param_map, model, guide, *args, **kwargs)
     53 def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
     54     """
     55     Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
     56 
   (...)
     66     :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
     67     """
---> 68     return self.loss_with_mutable_state(
     69         rng_key, param_map, model, guide, *args, **kwargs
     70     )["loss"]

File ~/numpyro/numpyro/infer/elbo.py:167, in Trace_ELBO.loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs)
    164 # Return (-elbo) since by convention we do gradient descent on a loss and
    165 # the ELBO is a lower bound that needs to be maximized.
    166 if self.num_particles == 1:
--> 167     elbo, mutable_state = single_particle_elbo(rng_key)
    168     return {"loss": -elbo, "mutable_state": mutable_state}
    169 else:

File ~/numpyro/numpyro/infer/elbo.py:129, in Trace_ELBO.loss_with_mutable_state.<locals>.single_particle_elbo(rng_key)
    127 seeded_model = seed(model, model_seed)
    128 seeded_guide = seed(guide, guide_seed)
--> 129 guide_log_density, guide_trace = log_density(
    130     seeded_guide, args, kwargs, param_map
    131 )
    132 mutable_params = {
    133     name: site["value"]
    134     for name, site in guide_trace.items()
    135     if site["type"] == "mutable"
    136 }
    137 params.update(mutable_params)

File ~/numpyro/numpyro/infer/util.py:62, in log_density(model, model_args, model_kwargs, params)
     50 """
     51 (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
     52 latent values ``params``.
   (...)
     59 :return: log of joint density and a corresponding model trace
     60 """
     61 model = substitute(model, data=params)
---> 62 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
     63 log_joint = jnp.zeros(())
     64 for site in model_trace.values():

File ~/numpyro/numpyro/handlers.py:172, in trace.get_trace(self, *args, **kwargs)
    164 def get_trace(self, *args, **kwargs):
    165     """
    166     Run the wrapped callable and return the recorded trace.
    167 
   (...)
    170     :return: `OrderedDict` containing the execution trace.
    171     """
--> 172     self(*args, **kwargs)
    173     return self.trace

File ~/numpyro/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File ~/numpyro/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File ~/numpyro/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File ~/numpyro/numpyro/infer/autoguide.py:271, in AutoGuideList.__call__(self, *args, **kwargs)
    269 result = {}
    270 for part in self._guides:
--> 271     result.update(part(*args, **kwargs))
    272 return result

File ~/numpyro/numpyro/infer/autoguide.py:669, in AutoContinuous.__call__(self, *args, **kwargs)
    665 if self.prototype_trace is None:
    666     # run model to inspect the model structure
    667     self._setup_prototype(*args, **kwargs)
--> 669 latent = self._sample_latent(*args, **kwargs)
    671 # unpack continuous latent samples
    672 result = {}

File ~/numpyro/numpyro/infer/autoguide.py:953, in AutoDAIS._sample_latent(self, *args, **kwargs)
    950     return (z, v, log_factor), None
    952 v_0 = eps[-1]  # note the return value of scan doesn't depend on eps[-1]
--> 953 (z, _, log_factor), _ = jax.lax.scan(scan_body, (z_0, v_0, 0.0), (eps, betas))
    955 numpyro.factor("{}_factor".format(self.prefix), log_factor)
    957 return z

    [... skipping hidden 9 frame]

File ~/numpyro/numpyro/infer/autoguide.py:944, in AutoDAIS._sample_latent.<locals>.scan_body(carry, eps_beta)
    942 z_half = z_prev + v_prev * eta * inv_mass_matrix
    943 q_grad = (1.0 - beta) * grad(base_z_dist.log_prob)(z_half)
--> 944 p_grad = beta * grad(log_density)(z_half)
    945 v_hat = v_prev + eta * (q_grad + p_grad)
    946 z = z_half + v_hat * eta * inv_mass_matrix

    [... skipping hidden 10 frame]

File ~/numpyro/numpyro/infer/autoguide.py:876, in AutoDAIS._sample_latent.<locals>.log_density(x)
    874 x_unpack = self._unpack_latent(x)
    875 with numpyro.handlers.block():
--> 876     return -self._potential_fn(x_unpack)

File ~/numpyro/numpyro/infer/util.py:291, in potential_energy(model, model_args, model_kwargs, params, enum)
    287 substituted_model = substitute(
    288     model, substitute_fn=partial(_unconstrain_reparam, params)
    289 )
    290 # no param is needed for log_density computation because we already substitute
--> 291 log_joint, model_trace = log_density_(
    292     substituted_model, model_args, model_kwargs, {}
    293 )
    294 return -log_joint

File ~/numpyro/numpyro/infer/util.py:62, in log_density(model, model_args, model_kwargs, params)
     50 """
     51 (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
     52 latent values ``params``.
   (...)
     59 :return: log of joint density and a corresponding model trace
     60 """
     61 model = substitute(model, data=params)
---> 62 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
     63 log_joint = jnp.zeros(())
     64 for site in model_trace.values():

File ~/numpyro/numpyro/handlers.py:172, in trace.get_trace(self, *args, **kwargs)
    164 def get_trace(self, *args, **kwargs):
    165     """
    166     Run the wrapped callable and return the recorded trace.
    167 
   (...)
    170     :return: `OrderedDict` containing the execution trace.
    171     """
--> 172     self(*args, **kwargs)
    173     return self.trace

File ~/numpyro/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File ~/numpyro/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 105 (3 times)]

File ~/numpyro/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Cell In[7], line 13, in model(x, y)
     12 def model(x, y=None):
---> 13     a = numpyro.sample("a", dist.Normal(0, 1))
     14     b = numpyro.sample("b", dist.Normal(jnp.ones((2, 1)), 1))
     15     mu = a + x @ b

File ~/numpyro/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
    207 initial_msg = {
    208     "type": "sample",
    209     "name": name,
   (...)
    218     "infer": {} if infer is None else infer,
    219 }
    221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
    223 return msg["value"]

File ~/numpyro/numpyro/primitives.py:47, in apply_stack(msg)
     45 pointer = 0
     46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47     handler.process_message(msg)
     48     # When a Messenger sets the "stop" field of a message,
     49     # it prevents any Messengers above it on the stack from being applied.
     50     if msg.get("stop"):

File ~/numpyro/numpyro/handlers.py:796, in seed.process_message(self, msg)
    793 if (msg["kwargs"]["rng_key"] is not None) or (msg["value"] is not None):
    794     # no need to create a new key when value is available
    795     return
--> 796 self.rng_key, rng_key_sample = random.split(self.rng_key)
    797 msg["kwargs"]["rng_key"] = rng_key_sample

File ~/numpyro/venv/lib/python3.9/site-packages/jax/_src/random.py:240, in split(key, num)
    229 def split(key: KeyArray, num: int = 2) -> KeyArray:
    230   """Splits a PRNG key into `num` new keys by adding a leading axis.
    231 
    232   Args:
   (...)
    238     An array-like object of `num` new PRNG keys.
    239   """
--> 240   key, wrapped = _check_prng_key(key)
    241   return _return_prng_keys(wrapped, _split(key, num))

File ~/numpyro/venv/lib/python3.9/site-packages/jax/_src/random.py:80, in _check_prng_key(key)
     75   if config.jax_enable_custom_prng:
     76     warnings.warn(
     77         'Raw arrays as random keys to jax.random functions are deprecated. '
     78         'Assuming valid threefry2x32 key for now.',
     79         FutureWarning)
---> 80   return prng.random_wrap(key, impl=default_prng_impl()), True
     81 else:
     82   raise TypeError(f'unexpected PRNG key type {type(key)}')

File ~/numpyro/venv/lib/python3.9/site-packages/jax/_src/prng.py:864, in random_wrap(base_arr, impl)
    862 def random_wrap(base_arr, *, impl):
    863   _check_prng_key_data(impl, base_arr)
--> 864   return random_wrap_p.bind(base_arr, impl=impl)

    [... skipping hidden 2 frame]

File ~/numpyro/venv/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py:1579, in DynamicJaxprTracer._assert_live(self)
   1577 def _assert_live(self) -> None:
   1578   if not self._trace.main.jaxpr_stack:  # type: ignore
-> 1579     raise core.escaped_tracer_error(self, None)

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was scan_body at /Users/tare/numpyro/numpyro/infer/autoguide.py:937 traced for scan.
------------------------------
The leaked intermediate value was created on line /Users/tare/numpyro/numpyro/handlers.py:796 (process_message). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/Users/tare/numpyro/numpyro/primitives.py:105 (__call__)
/var/folders/1s/d7zzs88n6n709f0w22bql1100000gn/T/ipykernel_23159/1142004465.py:14 (model)
/Users/tare/numpyro/numpyro/primitives.py:222 (sample)
/Users/tare/numpyro/numpyro/primitives.py:47 (apply_stack)
/Users/tare/numpyro/numpyro/handlers.py:796 (process_message)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

@fehiepsi
Copy link
Member

fehiepsi commented Oct 2, 2023

Re AutoDais: I'm not sure if AutoGuideList can work with DAIS because the guide requires model evaluation. I'm also not sure if the algorithm can be used in AutoGuideList. @martinjankowiak do you have some ideas?

Hopefully the tracer issue can be resolved after #1657. I can take a stab at it later of the week. But I think it does not block your PR.

@martinjankowiak
Copy link
Collaborator

Re AutoDais: I'm not sure if AutoGuideList can work with DAIS because the guide requires model evaluation. I'm also not > sure if the algorithm can be used in AutoGuideList. @martinjankowiak do you have some ideas?

i think it'd be quite difficult/awkward since AutoDAIS needs a model density. e.g. things would get especially strange if you tried to conjoin two copies of AutoDAIS. probably it should be disallowed from AutoGuideList

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great to me overall! Thank you. I just have a small comment.

)
else:
result.update(
part.sample_posterior(rng_key, params, sample_shape=sample_shape)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you usage signature of other algorithm to be the same as AutoDelta? That signature is more general I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To reduce the chance of misunderstanding, can you please clarify what you mean.

The reason for the if-else clause is that AutoDelta and AutoSemiDAIS differ from the other autoguides when it comes to the signature of sample_posterior():

  • AutoDelta and AutoSemiDAIS have sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs)
  • other autoguides have sample_posterior(self, rng_key, params, sample_shape=())

I suppose we could break things if we change the signature of sample_posterior() since sample_shape might have been passed as a positional argument.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since sample_shape might have been passed as a positional argument

Good point! Maybe we can keep your current implementation here as-is and adding *, before sample_shape in other classes. After the next release, we'll make the signature consistent. Users of the next release will keep informed of the change because they are no longer able to use sample_shape as a positional argument.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds reasonable to me! I'm happy to create another PR to address this if you want.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thank you!

result = {}
for part in self._guides:
if isinstance(part, numpyro.infer.autoguide.AutoDelta) or isinstance(
part, numpyro.infer.autoguide.AutoSemiDAIS
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably you can raise ValueError in the constructor if the guide is either AutoDAIS or AutoSemiDAIS. I'm not sure if the algorithm gives reasonable results under AutoGuideList.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also AutoSurrogateLikelihoodDAIS

- append() raises ValueError when guide is AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS
@fehiepsi fehiepsi marked this pull request as ready for review October 2, 2023 21:27
@tare
Copy link
Contributor Author

tare commented Oct 2, 2023

@fehiepsi Sorry for the additional commit; I realized I hadn't added AutoGuideList to __all__.

@fehiepsi
Copy link
Member

fehiepsi commented Oct 3, 2023

Thank you, @tare!

@fehiepsi fehiepsi merged commit 1a96406 into pyro-ppl:master Oct 3, 2023
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants