diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index c0aaa31182..d40a88ed69 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -34,9 +34,9 @@ class ELBO: :param num_particles: The number of particles/samples used to form the ELBO (gradient) estimators. - :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the - num_particles-many particles in parallel. If False use `jax.lax.map`. - Defaults to True. + :param vectorize_particles: Callable to vectorize the computation of ELBOs over + the num_particles-many particles. Defaults to `jax.lax.map`. Other options are + `jax.vmap` and `jax.pmap`. """ """ @@ -46,7 +46,7 @@ class ELBO: """ can_infer_discrete = False - def __init__(self, num_particles=1, vectorize_particles=True): + def __init__(self, num_particles=1, vectorize_particles=jax.lax.map): self.num_particles = num_particles self.vectorize_particles = vectorize_particles @@ -121,15 +121,15 @@ class Trace_ELBO(ELBO): :param num_particles: The number of particles/samples used to form the ELBO (gradient) estimators. - :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the - num_particles-many particles in parallel. If False use `jax.lax.map`. - Defaults to True. + :param vectorize_particles: Callable to vectorize the computation of ELBOs over + the num_particles-many particles. Defaults to `jax.lax.map`. Other options are + `jax.vmap` and `jax.pmap`. :param multi_sample_guide: Whether to make an assumption that the guide proposes multiple samples. """ def __init__( - self, num_particles=1, vectorize_particles=True, multi_sample_guide=False + self, num_particles=1, vectorize_particles=jax.lax.map, multi_sample_guide=False ): self.multi_sample_guide = multi_sample_guide super().__init__( @@ -228,10 +228,9 @@ def get_model_density(key, latent): return {"loss": -elbo, "mutable_state": mutable_state} else: rng_keys = random.split(rng_key, self.num_particles) - if self.vectorize_particles: - elbos, mutable_state = vmap(single_particle_elbo)(rng_keys) - else: - elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys) + elbos, mutable_state = self.vectorize_particles( + single_particle_elbo, rng_keys + ) return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state} @@ -362,10 +361,9 @@ def single_particle_elbo(rng_key): return {"loss": -elbo, "mutable_state": mutable_state} else: rng_keys = random.split(rng_key, self.num_particles) - if self.vectorize_particles: - elbos, mutable_state = vmap(single_particle_elbo)(rng_keys) - else: - elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys) + elbos, mutable_state = self.vectorize_particles( + single_particle_elbo, rng_keys + ) return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state} @@ -385,9 +383,9 @@ class RenyiELBO(ELBO): Here :math:`\alpha \neq 1`. Default is 0. :param num_particles: The number of particles/samples used to form the objective (gradient) estimator. Default is 2. - :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the - num_particles-many particles in parallel. If False use `jax.lax.map`. - Defaults to True. + :param vectorize_particles: Callable to vectorize the computation of ELBOs over + the num_particles-many particles. Defaults to `jax.lax.map`. Other options are + `jax.vmap` and `jax.pmap`. Example:: @@ -504,10 +502,9 @@ def loss(self, rng_key, param_map, model, guide, *args, **kwargs): ) rng_keys = random.split(rng_key, self.num_particles) - if self.vectorize_particles: - elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys) - else: - elbos, common_plate_scale = jax.lax.map(single_particle_elbo, rng_keys) + elbos, common_plate_scale = self.vectorize_particles( + single_particle_elbo, rng_keys + ) assert common_plate_scale.shape == (self.num_particles,) assert elbos.shape[0] == self.num_particles scaled_elbos = (1.0 - self.alpha) * elbos @@ -853,10 +850,7 @@ def single_particle_elbo(rng_key): return -single_particle_elbo(rng_key) else: rng_keys = random.split(rng_key, self.num_particles) - if self.vectorize_particles: - return -jnp.mean(vmap(single_particle_elbo)(rng_keys)) - else: - return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys)) + return -jnp.mean(self.vectorize_particles(single_particle_elbo, rng_keys)) def get_importance_trace_enum( @@ -1043,7 +1037,10 @@ class TraceEnum_ELBO(ELBO): can_infer_discrete = True def __init__( - self, num_particles=1, max_plate_nesting=float("inf"), vectorize_particles=True + self, + num_particles=1, + max_plate_nesting=float("inf"), + vectorize_particles=jax.lax.map, ): self.max_plate_nesting = max_plate_nesting super().__init__( @@ -1221,7 +1218,4 @@ def single_particle_elbo(rng_key): return -single_particle_elbo(rng_key) else: rng_keys = random.split(rng_key, self.num_particles) - if self.vectorize_particles: - return -jnp.mean(vmap(single_particle_elbo)(rng_keys)) - else: - return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys)) + return -jnp.mean(self.vectorize_particles(single_particle_elbo, rng_keys))