Skip to content

Commit

Permalink
Allow callable for vectorization over particles on SVI (#1902)
Browse files Browse the repository at this point in the history
* init

* compatible implementation

* extend vectorization test

* add vectorization to test run

* add function assignment test
  • Loading branch information
juanitorduz authored Nov 10, 2024
1 parent b9bacee commit c8a0990
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 31 deletions.
69 changes: 45 additions & 24 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
from numpyro.util import _validate_model, check_model_guide_match, find_stack_level


def _apply_vmap(fn, keys):
return vmap(fn)(keys)


class ELBO:
"""
Base class for all ELBO objectives.
Expand All @@ -36,7 +40,8 @@ class ELBO:
(gradient) estimators.
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
num_particles-many particles in parallel. If False use `jax.lax.map`.
Defaults to True.
Defaults to True. You can also pass a callable to specify a custom vectorization
strategy, for example `jax.pmap`.
"""

"""
Expand All @@ -49,6 +54,22 @@ class ELBO:
def __init__(self, num_particles=1, vectorize_particles=True):
    """Store the particle count and resolve the particle-vectorization strategy.

    ``vectorize_particles`` may be ``True`` (use ``jax.vmap``), ``False``
    (use ``jax.lax.map``), or a callable implementing a custom strategy.
    """
    self.num_particles = num_particles
    # Keep the raw option around for introspection, and resolve it once to a
    # concrete ``(fn, keys) -> results`` function used by the loss methods.
    self.vectorize_particles = vectorize_particles
    self.vectorize_particles_fn = self._assign_vectorize_particles_fn(vectorize_particles)

def _assign_vectorize_particles_fn(self, vectorize_particles):
"""Assigns a vectorization function to self.vectorize_particles_fn."""
if callable(vectorize_particles):
return vectorize_particles
elif vectorize_particles is True:
return _apply_vmap
elif vectorize_particles is False:
return jax.lax.map
else:
raise ValueError(
"`vectorize_particles` needs to be a boolean or a callable."
)

def loss(
self,
Expand Down Expand Up @@ -123,7 +144,8 @@ class Trace_ELBO(ELBO):
(gradient) estimators.
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
num_particles-many particles in parallel. If False use `jax.lax.map`.
Defaults to True.
Defaults to True. You can also pass a callable to specify a custom vectorization
strategy, for example `jax.pmap`.
:param multi_sample_guide: Whether to make an assumption that the guide proposes
multiple samples.
"""
Expand Down Expand Up @@ -228,10 +250,9 @@ def get_model_density(key, latent):
return {"loss": -elbo, "mutable_state": mutable_state}
else:
rng_keys = random.split(rng_key, self.num_particles)
if self.vectorize_particles:
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
else:
elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys)
elbos, mutable_state = self.vectorize_particles_fn(
single_particle_elbo, rng_keys
)
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}


Expand Down Expand Up @@ -362,10 +383,9 @@ def single_particle_elbo(rng_key):
return {"loss": -elbo, "mutable_state": mutable_state}
else:
rng_keys = random.split(rng_key, self.num_particles)
if self.vectorize_particles:
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
else:
elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys)
elbos, mutable_state = self.vectorize_particles_fn(
single_particle_elbo, rng_keys
)
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}


Expand All @@ -387,7 +407,8 @@ class RenyiELBO(ELBO):
used to form the objective (gradient) estimator. Default is 2.
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
num_particles-many particles in parallel. If False use `jax.lax.map`.
Defaults to True.
Defaults to True. You can also pass a callable to specify a custom vectorization
strategy, for example `jax.pmap`.
Example::
Expand Down Expand Up @@ -504,10 +525,9 @@ def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
)

rng_keys = random.split(rng_key, self.num_particles)
if self.vectorize_particles:
elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys)
else:
elbos, common_plate_scale = jax.lax.map(single_particle_elbo, rng_keys)
elbos, common_plate_scale = self.vectorize_particles_fn(
single_particle_elbo, rng_keys
)
assert common_plate_scale.shape == (self.num_particles,)
assert elbos.shape[0] == self.num_particles
scaled_elbos = (1.0 - self.alpha) * elbos
Expand Down Expand Up @@ -853,10 +873,9 @@ def single_particle_elbo(rng_key):
return -single_particle_elbo(rng_key)
else:
rng_keys = random.split(rng_key, self.num_particles)
if self.vectorize_particles:
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
else:
return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys))
return -jnp.mean(
self.vectorize_particles_fn(single_particle_elbo, rng_keys)
)


def get_importance_trace_enum(
Expand Down Expand Up @@ -1043,7 +1062,10 @@ class TraceEnum_ELBO(ELBO):
can_infer_discrete = True

def __init__(
self, num_particles=1, max_plate_nesting=float("inf"), vectorize_particles=True
self,
num_particles=1,
max_plate_nesting=float("inf"),
vectorize_particles=True,
):
self.max_plate_nesting = max_plate_nesting
super().__init__(
Expand Down Expand Up @@ -1221,7 +1243,6 @@ def single_particle_elbo(rng_key):
return -single_particle_elbo(rng_key)
else:
rng_keys = random.split(rng_key, self.num_particles)
if self.vectorize_particles:
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
else:
return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys))
return -jnp.mean(
self.vectorize_particles_fn(single_particle_elbo, rng_keys)
)
45 changes: 38 additions & 7 deletions test/infer/test_svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
TraceGraph_ELBO,
TraceMeanField_ELBO,
)
from numpyro.infer.elbo import _apply_vmap
from numpyro.primitives import mutable as numpyro_mutable
from numpyro.util import fori_loop

Expand Down Expand Up @@ -163,7 +164,20 @@ def get_renyi(n=N, k=K, fix_indices=True):
assert_allclose(atol, 0.0, atol=1e-5)


def test_vectorized_particle():
def test_assign_vectorize_particles_fn():
    """The resolver maps booleans to the builtin strategies and passes callables through."""
    elbo = Trace_ELBO()
    resolve = elbo._assign_vectorize_particles_fn
    assert resolve(True) == _apply_vmap
    assert resolve(False) == jax.lax.map
    assert resolve(jax.pmap) == jax.pmap
    assert callable(resolve(lambda x: x))


@pytest.mark.parametrize(
argnames="vectorize_particles",
argvalues=[True, False, jax.pmap, lambda x: x],
ids=["vmap", "lax", "pmap", "custom"],
)
def test_vectorized_particle(vectorize_particles):
data = jnp.array([1.0] * 8 + [0.0] * 2)

def model(data):
Expand All @@ -176,13 +190,16 @@ def guide(data):
beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

vmap_results = SVI(
model, guide, optim.Adam(0.1), Trace_ELBO(vectorize_particles=True)
results = SVI(
model,
guide,
optim.Adam(0.1),
Trace_ELBO(vectorize_particles=vectorize_particles),
).run(random.PRNGKey(0), 100, data)
map_results = SVI(
model, guide, optim.Adam(0.1), Trace_ELBO(vectorize_particles=False)
).run(random.PRNGKey(0), 100, data)
assert_allclose(vmap_results.losses, map_results.losses, atol=1e-5)
assert_allclose(results.losses, map_results.losses, atol=1e-5)


@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
Expand Down Expand Up @@ -219,8 +236,17 @@ def body_fn(i, val):
)


@pytest.mark.parametrize("progress_bar", [True, False])
def test_run(progress_bar):
@pytest.mark.parametrize(
argnames="vectorize_particles",
argvalues=[True, False, jax.pmap, lambda x: x],
ids=["vmap", "lax", "pmap", "custom"],
)
@pytest.mark.parametrize(
argnames="progress_bar",
argvalues=[True, False],
ids=["progress_bar", "no_progress_bar"],
)
def test_run(vectorize_particles, progress_bar):
data = jnp.array([1.0] * 8 + [0.0] * 2)

def model(data):
Expand All @@ -239,7 +265,12 @@ def guide(data):
)
numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

svi = SVI(model, guide, optim.Adam(0.05), Trace_ELBO())
svi = SVI(
model,
guide,
optim.Adam(0.05),
Trace_ELBO(vectorize_particles=vectorize_particles),
)
svi_result = svi.run(random.PRNGKey(1), 1000, data, progress_bar=progress_bar)
params, losses = svi_result.params, svi_result.losses
assert losses.shape == (1000,)
Expand Down

0 comments on commit c8a0990

Please sign in to comment.