diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index c0aaa3118..a255fbb7f 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -26,6 +26,10 @@
 from numpyro.util import _validate_model, check_model_guide_match, find_stack_level
 
 
+def _apply_vmap(fn, keys):
+    return vmap(fn)(keys)
+
+
 class ELBO:
     """
     Base class for all ELBO objectives.
@@ -36,7 +40,8 @@ class ELBO:
         (gradient) estimators.
     :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
         num_particles-many particles in parallel. If False use `jax.lax.map`.
-        Defaults to True.
+        Defaults to True. You can also pass a callable to specify a custom vectorization
+        strategy, for example `jax.pmap`.
     """
 
     """
@@ -49,6 +54,22 @@ class ELBO:
     def __init__(self, num_particles=1, vectorize_particles=True):
         self.num_particles = num_particles
         self.vectorize_particles = vectorize_particles
+        self.vectorize_particles_fn = self._assign_vectorize_particles_fn(
+            vectorize_particles
+        )
+
+    def _assign_vectorize_particles_fn(self, vectorize_particles):
+        """Assigns a vectorization function to self.vectorize_particles_fn."""
+        if callable(vectorize_particles):
+            return vectorize_particles
+        elif vectorize_particles is True:
+            return _apply_vmap
+        elif vectorize_particles is False:
+            return jax.lax.map
+        else:
+            raise ValueError(
+                "`vectorize_particles` needs to be a boolean or a callable."
+            )
 
     def loss(
         self,
@@ -123,7 +144,8 @@ class Trace_ELBO(ELBO):
         (gradient) estimators.
     :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
         num_particles-many particles in parallel. If False use `jax.lax.map`.
-        Defaults to True.
+        Defaults to True. You can also pass a callable to specify a custom vectorization
+        strategy, for example `jax.pmap`.
     :param multi_sample_guide: Whether to make an assumption that the guide proposes
         multiple samples.
     """
@@ -228,10 +250,9 @@ def get_model_density(key, latent):
             return {"loss": -elbo, "mutable_state": mutable_state}
         else:
             rng_keys = random.split(rng_key, self.num_particles)
-            if self.vectorize_particles:
-                elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
-            else:
-                elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys)
+            elbos, mutable_state = self.vectorize_particles_fn(
+                single_particle_elbo, rng_keys
+            )
             return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}
 
 
@@ -362,10 +383,9 @@ def single_particle_elbo(rng_key):
             return {"loss": -elbo, "mutable_state": mutable_state}
         else:
             rng_keys = random.split(rng_key, self.num_particles)
-            if self.vectorize_particles:
-                elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
-            else:
-                elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys)
+            elbos, mutable_state = self.vectorize_particles_fn(
+                single_particle_elbo, rng_keys
+            )
             return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}
 
 
@@ -387,7 +407,8 @@ class RenyiELBO(ELBO):
         used to form the objective (gradient) estimator. Default is 2.
     :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
         num_particles-many particles in parallel. If False use `jax.lax.map`.
-        Defaults to True.
+        Defaults to True. You can also pass a callable to specify a custom vectorization
+        strategy, for example `jax.pmap`.
 
     Example::
 
@@ -504,10 +525,9 @@ def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
             )
 
         rng_keys = random.split(rng_key, self.num_particles)
-        if self.vectorize_particles:
-            elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys)
-        else:
-            elbos, common_plate_scale = jax.lax.map(single_particle_elbo, rng_keys)
+        elbos, common_plate_scale = self.vectorize_particles_fn(
+            single_particle_elbo, rng_keys
+        )
         assert common_plate_scale.shape == (self.num_particles,)
         assert elbos.shape[0] == self.num_particles
         scaled_elbos = (1.0 - self.alpha) * elbos
@@ -853,10 +873,9 @@ def single_particle_elbo(rng_key):
             return -single_particle_elbo(rng_key)
         else:
             rng_keys = random.split(rng_key, self.num_particles)
-            if self.vectorize_particles:
-                return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
-            else:
-                return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys))
+            return -jnp.mean(
+                self.vectorize_particles_fn(single_particle_elbo, rng_keys)
+            )
 
 
 def get_importance_trace_enum(
@@ -1043,7 +1062,10 @@ class TraceEnum_ELBO(ELBO):
     can_infer_discrete = True
 
     def __init__(
-        self, num_particles=1, max_plate_nesting=float("inf"), vectorize_particles=True
+        self,
+        num_particles=1,
+        max_plate_nesting=float("inf"),
+        vectorize_particles=True,
     ):
         self.max_plate_nesting = max_plate_nesting
         super().__init__(
@@ -1221,7 +1243,6 @@ def single_particle_elbo(rng_key):
             return -single_particle_elbo(rng_key)
         else:
             rng_keys = random.split(rng_key, self.num_particles)
-            if self.vectorize_particles:
-                return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
-            else:
-                return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys))
+            return -jnp.mean(
+                self.vectorize_particles_fn(single_particle_elbo, rng_keys)
+            )
diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py
index f52c1cef2..1ae8012ff 100644
--- a/test/infer/test_svi.py
+++ b/test/infer/test_svi.py
@@ -25,6 +25,7 @@
     TraceGraph_ELBO,
     TraceMeanField_ELBO,
 )
+from numpyro.infer.elbo import _apply_vmap
 from numpyro.primitives import mutable as numpyro_mutable
 from numpyro.util import fori_loop
 
@@ -163,7 +164,20 @@ def get_renyi(n=N, k=K, fix_indices=True):
     assert_allclose(atol, 0.0, atol=1e-5)
 
 
-def test_vectorized_particle():
+def test_assign_vectorize_particles_fn():
+    elbo = Trace_ELBO()
+    assert elbo._assign_vectorize_particles_fn(True) == _apply_vmap
+    assert elbo._assign_vectorize_particles_fn(False) == jax.lax.map
+    assert elbo._assign_vectorize_particles_fn(jax.pmap) == jax.pmap
+    assert callable(elbo._assign_vectorize_particles_fn(lambda x: x))
+
+
+@pytest.mark.parametrize(
+    argnames="vectorize_particles",
+    argvalues=[True, False, jax.pmap, lambda x: x],
+    ids=["vmap", "lax", "pmap", "custom"],
+)
+def test_vectorized_particle(vectorize_particles):
     data = jnp.array([1.0] * 8 + [0.0] * 2)
 
     def model(data):
@@ -176,13 +190,16 @@ def guide(data):
         beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
         numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
 
-    vmap_results = SVI(
-        model, guide, optim.Adam(0.1), Trace_ELBO(vectorize_particles=True)
+    results = SVI(
+        model,
+        guide,
+        optim.Adam(0.1),
+        Trace_ELBO(vectorize_particles=vectorize_particles),
     ).run(random.PRNGKey(0), 100, data)
     map_results = SVI(
         model, guide, optim.Adam(0.1), Trace_ELBO(vectorize_particles=False)
     ).run(random.PRNGKey(0), 100, data)
-    assert_allclose(vmap_results.losses, map_results.losses, atol=1e-5)
+    assert_allclose(results.losses, map_results.losses, atol=1e-5)
 
 
 @pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
@@ -219,8 +236,17 @@ def body_fn(i, val):
     )
 
 
-@pytest.mark.parametrize("progress_bar", [True, False])
-def test_run(progress_bar):
+@pytest.mark.parametrize(
+    argnames="vectorize_particles",
+    argvalues=[True, False, jax.pmap, lambda x: x],
+    ids=["vmap", "lax", "pmap", "custom"],
+)
+@pytest.mark.parametrize(
+    argnames="progress_bar",
+    argvalues=[True, False],
+    ids=["progress_bar", "no_progress_bar"],
+)
+def test_run(vectorize_particles, progress_bar):
     data = jnp.array([1.0] * 8 + [0.0] * 2)
 
     def model(data):
@@ -239,7 +265,12 @@ def guide(data):
         )
         numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
 
-    svi = SVI(model, guide, optim.Adam(0.05), Trace_ELBO())
+    svi = SVI(
+        model,
+        guide,
+        optim.Adam(0.05),
+        Trace_ELBO(vectorize_particles=vectorize_particles),
+    )
     svi_result = svi.run(random.PRNGKey(1), 1000, data, progress_bar=progress_bar)
     params, losses = svi_result.params, svi_result.losses
     assert losses.shape == (1000,)
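
Usage note (not part of the patch above): a minimal sketch of how the new callable option for
`vectorize_particles` could be exercised once this change is in place. The model, guide, data,
and the `chunked_vmap` helper below are illustrative assumptions rather than code from the
NumPyro repository; the only requirement on the callable is the `(fn, keys) -> results`
signature shared by `_apply_vmap` and `jax.lax.map`::

    import jax
    import jax.numpy as jnp
    from jax import random

    import numpyro
    import numpyro.distributions as dist
    from numpyro import optim
    from numpyro.distributions import constraints
    from numpyro.infer import SVI, Trace_ELBO


    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        with numpyro.plate("N", data.shape[0]):
            numpyro.sample("obs", dist.Bernoulli(f), obs=data)


    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))


    def chunked_vmap(fn, keys):
        # Hypothetical custom strategy: vmap over two halves of the particle
        # keys and concatenate the per-particle outputs, so only half of the
        # particles are evaluated in parallel at a time.
        half = keys.shape[0] // 2
        out1 = jax.vmap(fn)(keys[:half])
        out2 = jax.vmap(fn)(keys[half:])
        return jax.tree_util.tree_map(
            lambda a, b: jnp.concatenate([a, b]), out1, out2
        )


    data = jnp.array([1.0] * 8 + [0.0] * 2)
    elbo = Trace_ELBO(num_particles=4, vectorize_particles=chunked_vmap)
    svi = SVI(model, guide, optim.Adam(0.05), elbo)
    svi_result = svi.run(random.PRNGKey(0), 1000, data)

Any callable with that signature should work; for instance, `jax.pmap` is only appropriate
when `num_particles` does not exceed `jax.local_device_count()`.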