From 1528036d0023a37fe3fd9934365f48f11be3caab Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 12 Nov 2023 07:37:57 -0500 Subject: [PATCH] Allow AutoSemiDAIS to work without global variable (#1665) * support vae for dais * revise docstring * use local params for base dist and fix kumaraswamy vmap tests --- numpyro/distributions/batch_util.py | 13 +--- numpyro/infer/autoguide.py | 112 +++++++++++++++++++++------- test/infer/test_autoguide.py | 24 ++++++ 3 files changed, 110 insertions(+), 39 deletions(-) diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index 0b54fb641..0292e02e1 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -261,17 +261,8 @@ def _vmap_over_kumaraswamy(dist: Kumaraswamy, concentration0=None, concentration dist_axes = _default_vmap_over( dist, concentration0=concentration0, concentration1=concentration1 ) - if isinstance(dist.base_dist, Uniform): - dist_axes.base_dist = vmap_over(dist.base_dist, low=None, high=None) - else: - assert isinstance(dist.base_dist, ExpandedDistribution) - dist_axes.base_dist = vmap_over(dist.base_dist, base_dist=None) - - dist_axes.transforms = [ - vmap_over(dist.transforms[0], exponent=concentration0), - vmap_over(dist.transforms[1], loc=None, scale=None), - vmap_over(dist.transforms[2], exponent=concentration1), - ] + dist_axes.concentration0 = concentration0 + dist_axes.concentration1 = concentration1 return dist_axes diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index fa72710a0..f1e4488b1 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -1250,13 +1250,21 @@ def local_model(theta): during partial momentum refreshments in HMC. Defaults to 0.9. :param float init_scale: Initial scale for the standard deviation of the variational distribution for each (unconstrained transformed) local latent variable. Defaults to 0.1. + :param str subsample_plate: Optional name of the subsample plate site. This is required + when the model has a subsample plate without `subsample_size` specified or + the model has a subsample plate with `subsample_size` equal to the plate size. + :param bool use_global_dais_params: Whether parameters controlling DAIS dynamic + (HMC step size, HMC mass matrix, etc.) should be global (i.e. common to all + data points in the subsample plate) or local (i.e. each data point in the + subsample plate has individual parameters). Note that we do not use global + parameters for the base distribution. """ def __init__( self, model, local_model, - global_guide, + global_guide=None, local_guide=None, *, prefix="auto", @@ -1265,6 +1273,8 @@ def __init__( eta_max=0.1, gamma_init=0.9, init_scale=0.1, + subsample_plate=None, + use_global_dais_params=False, ): # init_loc_fn is only used to inspect the model. super().__init__(model, prefix=prefix, init_loc_fn=init_to_uniform) @@ -1289,6 +1299,8 @@ def __init__( self.gamma_init = gamma_init self.K = K self.init_scale = init_scale + self.subsample_plate = subsample_plate + self.use_global_dais_params = use_global_dais_params def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) @@ -1301,6 +1313,17 @@ def _setup_prototype(self, *args, **kwargs): and isinstance(site["args"][1], int) and site["args"][0] > site["args"][1] } + if self.subsample_plate is not None: + subsample_plates[self.subsample_plate] = self.prototype_trace[ + self.subsample_plate + ] + elif not subsample_plates: + # Consider all plates as subsample plates. + subsample_plates = { + name: site + for name, site in self.prototype_trace.items() + if site["type"] == "plate" + } num_plates = len(subsample_plates) assert ( num_plates == 1 @@ -1344,6 +1367,8 @@ def _setup_prototype(self, *args, **kwargs): UnpackTransform(unpack_latent), out_axes=subsample_axes ) plate_full_size, plate_subsample_size = subsample_plates[plate_name]["args"] + if plate_subsample_size is None: + plate_subsample_size = plate_full_size self._local_latent_dim = jnp.size(local_init_latent) // plate_subsample_size self._local_plate = (plate_name, plate_full_size, plate_subsample_size) @@ -1451,37 +1476,68 @@ def fn(x): D, K = self._local_latent_dim, self.K with numpyro.plate(plate_name, N, subsample_size=subsample_size) as idx: - eta0 = numpyro.param( - "{}_eta0".format(self.prefix), - jnp.ones(N) * self.eta_init, - constraint=constraints.interval(0, self.eta_max), - event_dim=0, - ) - eta_coeff = numpyro.param( - "{}_eta_coeff".format(self.prefix), jnp.zeros(N), event_dim=0 - ) + if self.use_global_dais_params: + eta0 = numpyro.param( + "{}_eta0".format(self.prefix), + self.eta_init, + constraint=constraints.interval(0, self.eta_max), + ) + eta0 = jnp.broadcast_to(eta0, idx.shape) + eta_coeff = numpyro.param( + "{}_eta_coeff".format(self.prefix), + 0.0, + ) + eta_coeff = jnp.broadcast_to(eta_coeff, idx.shape) + gamma = numpyro.param( + "{}_gamma".format(self.prefix), + 0.9, + constraint=constraints.interval(0, 1), + ) + gamma = jnp.broadcast_to(gamma, idx.shape) + betas = numpyro.param( + "{}_beta_increments".format(self.prefix), + jnp.ones(K), + constraint=constraints.positive, + ) + betas = jnp.broadcast_to(betas, idx.shape + (K,)) + mass_matrix = numpyro.param( + "{}_mass_matrix".format(self.prefix), + jnp.ones(D), + constraint=constraints.positive, + ) + mass_matrix = jnp.broadcast_to(mass_matrix, idx.shape + (D,)) + else: + eta0 = numpyro.param( + "{}_eta0".format(self.prefix), + jnp.ones(N) * self.eta_init, + constraint=constraints.interval(0, self.eta_max), + event_dim=0, + ) + eta_coeff = numpyro.param( + "{}_eta_coeff".format(self.prefix), jnp.zeros(N), event_dim=0 + ) + gamma = numpyro.param( + "{}_gamma".format(self.prefix), + jnp.ones(N) * 0.9, + constraint=constraints.interval(0, 1), + event_dim=0, + ) + betas = numpyro.param( + "{}_beta_increments".format(self.prefix), + jnp.ones((N, K)), + constraint=constraints.positive, + event_dim=1, + ) + mass_matrix = numpyro.param( + "{}_mass_matrix".format(self.prefix), + jnp.ones((N, D)), + constraint=constraints.positive, + event_dim=1, + ) - gamma = numpyro.param( - "{}_gamma".format(self.prefix), - jnp.ones(N) * 0.9, - constraint=constraints.interval(0, 1), - event_dim=0, - ) - betas = numpyro.param( - "{}_beta_increments".format(self.prefix), - jnp.ones((N, K)), - constraint=constraints.positive, - event_dim=1, - ) betas = jnp.cumsum(betas, axis=-1) betas = betas / betas[..., -1:] - mass_matrix = numpyro.param( - "{}_mass_matrix".format(self.prefix), - jnp.ones((N, D)), - constraint=constraints.positive, - event_dim=1, - ) inv_mass_matrix = 0.5 / mass_matrix assert inv_mass_matrix.shape == (subsample_size, D) diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index be3238c49..4a658d584 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -1227,3 +1227,27 @@ def model(): ) assert guide_samples["x"].shape == sample_shape + shape assert guide_samples["x2"].shape == sample_shape + shape + + +@pytest.mark.parametrize("use_global_dais_params", [True, False]) +def test_dais_vae(use_global_dais_params): + def model(): + with numpyro.plate("N", 10): + numpyro.sample("x", dist.Normal(jnp.arange(-5, 5), 2)) + + guide = AutoSemiDAIS( + model, model, subsample_plate="N", use_global_dais_params=use_global_dais_params + ) + svi = SVI(model, guide, optax.adam(0.02), Trace_ELBO()) + svi_results = svi.run(random.PRNGKey(0), 3000) + samples = guide.sample_posterior( + random.PRNGKey(1), svi_results.params, sample_shape=(1000,) + ) + if use_global_dais_params: + assert_allclose( + samples["x"].mean(), jnp.arange(-5, 5).mean(), atol=0.1, rtol=0.1 + ) + else: + assert_allclose( + samples["x"].mean(axis=0), jnp.arange(-5, 5), atol=0.2, rtol=0.1 + )