Skip to content

Commit

Permalink
Allow AutoSemiDAIS to work without global variable (#1665)
Browse files Browse the repository at this point in the history
* support vae for dais

* revise docstring

* use local params for base dist and fix kumaraswamy vmap tests
  • Loading branch information
fehiepsi authored Nov 12, 2023
1 parent f5bd186 commit 1528036
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 39 deletions.
13 changes: 2 additions & 11 deletions numpyro/distributions/batch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,17 +261,8 @@ def _vmap_over_kumaraswamy(dist: Kumaraswamy, concentration0=None, concentration
dist_axes = _default_vmap_over(
dist, concentration0=concentration0, concentration1=concentration1
)
if isinstance(dist.base_dist, Uniform):
dist_axes.base_dist = vmap_over(dist.base_dist, low=None, high=None)
else:
assert isinstance(dist.base_dist, ExpandedDistribution)
dist_axes.base_dist = vmap_over(dist.base_dist, base_dist=None)

dist_axes.transforms = [
vmap_over(dist.transforms[0], exponent=concentration0),
vmap_over(dist.transforms[1], loc=None, scale=None),
vmap_over(dist.transforms[2], exponent=concentration1),
]
dist_axes.concentration0 = concentration0
dist_axes.concentration1 = concentration1
return dist_axes


Expand Down
112 changes: 84 additions & 28 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,13 +1250,21 @@ def local_model(theta):
during partial momentum refreshments in HMC. Defaults to 0.9.
:param float init_scale: Initial scale for the standard deviation of the variational
distribution for each (unconstrained transformed) local latent variable. Defaults to 0.1.
:param str subsample_plate: Optional name of the subsample plate site. This is required
when the model has a subsample plate without `subsample_size` specified or
the model has a subsample plate with `subsample_size` equal to the plate size.
:param bool use_global_dais_params: Whether parameters controlling DAIS dynamic
(HMC step size, HMC mass matrix, etc.) should be global (i.e. common to all
data points in the subsample plate) or local (i.e. each data point in the
subsample plate has individual parameters). Note that we do not use global
parameters for the base distribution.
"""

def __init__(
self,
model,
local_model,
global_guide,
global_guide=None,
local_guide=None,
*,
prefix="auto",
Expand All @@ -1265,6 +1273,8 @@ def __init__(
eta_max=0.1,
gamma_init=0.9,
init_scale=0.1,
subsample_plate=None,
use_global_dais_params=False,
):
# init_loc_fn is only used to inspect the model.
super().__init__(model, prefix=prefix, init_loc_fn=init_to_uniform)
Expand All @@ -1289,6 +1299,8 @@ def __init__(
self.gamma_init = gamma_init
self.K = K
self.init_scale = init_scale
self.subsample_plate = subsample_plate
self.use_global_dais_params = use_global_dais_params

def _setup_prototype(self, *args, **kwargs):
super()._setup_prototype(*args, **kwargs)
Expand All @@ -1301,6 +1313,17 @@ def _setup_prototype(self, *args, **kwargs):
and isinstance(site["args"][1], int)
and site["args"][0] > site["args"][1]
}
if self.subsample_plate is not None:
subsample_plates[self.subsample_plate] = self.prototype_trace[
self.subsample_plate
]
elif not subsample_plates:
# Consider all plates as subsample plates.
subsample_plates = {
name: site
for name, site in self.prototype_trace.items()
if site["type"] == "plate"
}
num_plates = len(subsample_plates)
assert (
num_plates == 1
Expand Down Expand Up @@ -1344,6 +1367,8 @@ def _setup_prototype(self, *args, **kwargs):
UnpackTransform(unpack_latent), out_axes=subsample_axes
)
plate_full_size, plate_subsample_size = subsample_plates[plate_name]["args"]
if plate_subsample_size is None:
plate_subsample_size = plate_full_size
self._local_latent_dim = jnp.size(local_init_latent) // plate_subsample_size
self._local_plate = (plate_name, plate_full_size, plate_subsample_size)

Expand Down Expand Up @@ -1451,37 +1476,68 @@ def fn(x):
D, K = self._local_latent_dim, self.K

with numpyro.plate(plate_name, N, subsample_size=subsample_size) as idx:
eta0 = numpyro.param(
"{}_eta0".format(self.prefix),
jnp.ones(N) * self.eta_init,
constraint=constraints.interval(0, self.eta_max),
event_dim=0,
)
eta_coeff = numpyro.param(
"{}_eta_coeff".format(self.prefix), jnp.zeros(N), event_dim=0
)
if self.use_global_dais_params:
eta0 = numpyro.param(
"{}_eta0".format(self.prefix),
self.eta_init,
constraint=constraints.interval(0, self.eta_max),
)
eta0 = jnp.broadcast_to(eta0, idx.shape)
eta_coeff = numpyro.param(
"{}_eta_coeff".format(self.prefix),
0.0,
)
eta_coeff = jnp.broadcast_to(eta_coeff, idx.shape)
gamma = numpyro.param(
"{}_gamma".format(self.prefix),
0.9,
constraint=constraints.interval(0, 1),
)
gamma = jnp.broadcast_to(gamma, idx.shape)
betas = numpyro.param(
"{}_beta_increments".format(self.prefix),
jnp.ones(K),
constraint=constraints.positive,
)
betas = jnp.broadcast_to(betas, idx.shape + (K,))
mass_matrix = numpyro.param(
"{}_mass_matrix".format(self.prefix),
jnp.ones(D),
constraint=constraints.positive,
)
mass_matrix = jnp.broadcast_to(mass_matrix, idx.shape + (D,))
else:
eta0 = numpyro.param(
"{}_eta0".format(self.prefix),
jnp.ones(N) * self.eta_init,
constraint=constraints.interval(0, self.eta_max),
event_dim=0,
)
eta_coeff = numpyro.param(
"{}_eta_coeff".format(self.prefix), jnp.zeros(N), event_dim=0
)
gamma = numpyro.param(
"{}_gamma".format(self.prefix),
jnp.ones(N) * 0.9,
constraint=constraints.interval(0, 1),
event_dim=0,
)
betas = numpyro.param(
"{}_beta_increments".format(self.prefix),
jnp.ones((N, K)),
constraint=constraints.positive,
event_dim=1,
)
mass_matrix = numpyro.param(
"{}_mass_matrix".format(self.prefix),
jnp.ones((N, D)),
constraint=constraints.positive,
event_dim=1,
)

gamma = numpyro.param(
"{}_gamma".format(self.prefix),
jnp.ones(N) * 0.9,
constraint=constraints.interval(0, 1),
event_dim=0,
)
betas = numpyro.param(
"{}_beta_increments".format(self.prefix),
jnp.ones((N, K)),
constraint=constraints.positive,
event_dim=1,
)
betas = jnp.cumsum(betas, axis=-1)
betas = betas / betas[..., -1:]

mass_matrix = numpyro.param(
"{}_mass_matrix".format(self.prefix),
jnp.ones((N, D)),
constraint=constraints.positive,
event_dim=1,
)
inv_mass_matrix = 0.5 / mass_matrix
assert inv_mass_matrix.shape == (subsample_size, D)

Expand Down
24 changes: 24 additions & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,3 +1227,27 @@ def model():
)
assert guide_samples["x"].shape == sample_shape + shape
assert guide_samples["x2"].shape == sample_shape + shape


@pytest.mark.parametrize("use_global_dais_params", [True, False])
def test_dais_vae(use_global_dais_params):
def model():
with numpyro.plate("N", 10):
numpyro.sample("x", dist.Normal(jnp.arange(-5, 5), 2))

guide = AutoSemiDAIS(
model, model, subsample_plate="N", use_global_dais_params=use_global_dais_params
)
svi = SVI(model, guide, optax.adam(0.02), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(0), 3000)
samples = guide.sample_posterior(
random.PRNGKey(1), svi_results.params, sample_shape=(1000,)
)
if use_global_dais_params:
assert_allclose(
samples["x"].mean(), jnp.arange(-5, 5).mean(), atol=0.1, rtol=0.1
)
else:
assert_allclose(
samples["x"].mean(axis=0), jnp.arange(-5, 5), atol=0.2, rtol=0.1
)

0 comments on commit 1528036

Please sign in to comment.