Skip to content

Commit

Permalink
compatible implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Nov 10, 2024
1 parent adb8ef2 commit 6c6760e
Showing 1 changed file with 44 additions and 17 deletions.
61 changes: 44 additions & 17 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
from numpyro.util import _validate_model, check_model_guide_match, find_stack_level


def _apply_vmap(fn, keys):
return vmap(fn)(keys)


class ELBO:
"""
Base class for all ELBO objectives.
Expand All @@ -34,9 +38,10 @@ class ELBO:
:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
:param vectorize_particles: Callable to vectorize the computation of ELBOs over
the num_particles-many particles. Defaults to `jax.lax.map`. Other options are
`jax.vmap` and `jax.pmap`.
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
num_particles-many particles in parallel. If False use `jax.lax.map`.
Defaults to True. You can also pass a callable to specify a custom vectorization
strategy, for example `jax.pmap`.
"""

"""
Expand All @@ -46,9 +51,25 @@ class ELBO:
"""
can_infer_discrete = False

def __init__(self, num_particles=1, vectorize_particles=jax.lax.map):
def __init__(self, num_particles=1, vectorize_particles=True):
self.num_particles = num_particles
self.vectorize_particles = vectorize_particles
self.vectorize_particles_fn = self._assign_vectorize_particles_fn(
vectorize_particles
)

def _assign_vectorize_particles_fn(self, vectorize_particles):
"""Assigns a vectorization function to self.vectorize_particles_fn."""
if callable(vectorize_particles):
return vectorize_particles
elif vectorize_particles is True:
return _apply_vmap
elif vectorize_particles is False:
return jax.lax.map
else:
raise ValueError(
"`vectorize_particles` needs to be a boolean or a callable."
)

def loss(
self,
Expand Down Expand Up @@ -121,15 +142,16 @@ class Trace_ELBO(ELBO):
:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
:param vectorize_particles: Callable to vectorize the computation of ELBOs over
the num_particles-many particles. Defaults to `jax.lax.map`. Other options are
`jax.vmap` and `jax.pmap`.
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
num_particles-many particles in parallel. If False use `jax.lax.map`.
Defaults to True. You can also pass a callable to specify a custom vectorization
strategy, for example `jax.pmap`.
:param multi_sample_guide: Whether to make an assumption that the guide proposes
multiple samples.
"""

def __init__(
self, num_particles=1, vectorize_particles=jax.lax.map, multi_sample_guide=False
self, num_particles=1, vectorize_particles=True, multi_sample_guide=False
):
self.multi_sample_guide = multi_sample_guide
super().__init__(
Expand Down Expand Up @@ -228,7 +250,7 @@ def get_model_density(key, latent):
return {"loss": -elbo, "mutable_state": mutable_state}
else:
rng_keys = random.split(rng_key, self.num_particles)
elbos, mutable_state = self.vectorize_particles(
elbos, mutable_state = self.vectorize_particles_fn(
single_particle_elbo, rng_keys
)
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}
Expand Down Expand Up @@ -361,7 +383,7 @@ def single_particle_elbo(rng_key):
return {"loss": -elbo, "mutable_state": mutable_state}
else:
rng_keys = random.split(rng_key, self.num_particles)
elbos, mutable_state = self.vectorize_particles(
elbos, mutable_state = self.vectorize_particles_fn(
single_particle_elbo, rng_keys
)
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}
Expand All @@ -383,9 +405,10 @@ class RenyiELBO(ELBO):
Here :math:`\alpha \neq 1`. Default is 0.
:param num_particles: The number of particles/samples
used to form the objective (gradient) estimator. Default is 2.
:param vectorize_particles: Callable to vectorize the computation of ELBOs over
the num_particles-many particles. Defaults to `jax.lax.map`. Other options are
`jax.vmap` and `jax.pmap`.
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
num_particles-many particles in parallel. If False use `jax.lax.map`.
Defaults to True. You can also pass a callable to specify a custom vectorization
strategy, for example `jax.pmap`.
Example::
Expand Down Expand Up @@ -502,7 +525,7 @@ def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
)

rng_keys = random.split(rng_key, self.num_particles)
elbos, common_plate_scale = self.vectorize_particles(
elbos, common_plate_scale = self.vectorize_particles_fn(
single_particle_elbo, rng_keys
)
assert common_plate_scale.shape == (self.num_particles,)
Expand Down Expand Up @@ -850,7 +873,9 @@ def single_particle_elbo(rng_key):
return -single_particle_elbo(rng_key)
else:
rng_keys = random.split(rng_key, self.num_particles)
return -jnp.mean(self.vectorize_particles(single_particle_elbo, rng_keys))
return -jnp.mean(
self.vectorize_particles_fn(single_particle_elbo, rng_keys)
)


def get_importance_trace_enum(
Expand Down Expand Up @@ -1040,7 +1065,7 @@ def __init__(
self,
num_particles=1,
max_plate_nesting=float("inf"),
vectorize_particles=jax.lax.map,
vectorize_particles=True,
):
self.max_plate_nesting = max_plate_nesting
super().__init__(
Expand Down Expand Up @@ -1218,4 +1243,6 @@ def single_particle_elbo(rng_key):
return -single_particle_elbo(rng_key)
else:
rng_keys = random.split(rng_key, self.num_particles)
return -jnp.mean(self.vectorize_particles(single_particle_elbo, rng_keys))
return -jnp.mean(
self.vectorize_particles_fn(single_particle_elbo, rng_keys)
)

0 comments on commit 6c6760e

Please sign in to comment.