Skip to content

Commit

Permalink
fixed haiku for ecs (#1750)
Browse files Browse the repository at this point in the history
  • Loading branch information
OlaRonning authored Feb 28, 2024
1 parent e6c187c commit f997da2
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
16 changes: 11 additions & 5 deletions numpyro/contrib/ecs_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,11 @@ def construct_proxy_fn(
num_blocks=1,
):
ref_params = {
name: biject_to(prototype_trace[name]["fn"].support).inv(value)
name: (
biject_to(prototype_trace[name]["fn"].support).inv(value)
if prototype_trace[name]["type"] == "sample"
else value
)
for name, value in reference_params.items()
}

Expand All @@ -131,7 +135,11 @@ def log_likelihood(params_flat, subsample_indices=None):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
params = {
name: biject_to(prototype_trace[name]["fn"].support)(value)
name: (
biject_to(prototype_trace[name]["fn"].support)(value)
if prototype_trace[name]["type"] == "sample"
else value
)
for name, value in params.items()
}
with (
Expand Down Expand Up @@ -167,9 +175,7 @@ def log_likelihood_sum(params_flat, subsample_indices=None):
elif 1:
TPState = TaylorOneProxyState
else:
raise ValueError(
"Taylor proxy only defined for first and second degree."
)
raise ValueError("Taylor proxy only defined for first and second degree.")

# those stats are dict keyed by subsample names
ref_sum_log_lik = log_likelihood_sum(ref_params_flat)
Expand Down
46 changes: 46 additions & 0 deletions test/contrib/test_esc_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@

from jax import numpy as jnp, random

from numpyro import plate, prng_key, sample
from numpyro.contrib.ecs_proxies import block_update
from numpyro.contrib.module import random_haiku_module
from numpyro.distributions import Cauchy, Normal
from numpyro.handlers import seed
from numpyro.infer import HMC, HMCECS, MCMC, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro.optim import Adam


@pytest.mark.parametrize("num_blocks", [1, 2, 50, 100])
Expand All @@ -26,3 +33,42 @@ def test_block_update_partitioning(num_blocks):
)

assert gibbs_state == new_gibbs_state


def test_haiku_compatiable():
try:
import haiku as hk # noqa: F401

data_points = 6
x_dim = 4

def model(x, y):
net = random_haiku_module(
"net",
hk.transform(lambda x: hk.Linear(1)(x)),
prior={"linear.b": Cauchy(), "linear.w": Normal()},
input_shape=(1, x_dim),
)

with plate("data", data_points, subsample_size=2) as idx:
yb = y[idx]
xb = x[idx]
sample("y", Normal(net(xb).squeeze()), obs=yb)

x = jnp.ones((data_points, x_dim))
y = jnp.array((data_points, 0))

with seed(rng_seed=0):
svi = SVI(model, AutoDelta(model), Adam(step_size=1e-3), Trace_ELBO())
svi_result = svi.run(prng_key(), 1, x, y)
ref_params = {
k.removesuffix("_auto_loc"): v for k, v in svi_result.params.items()
}

proxy = HMCECS.taylor_proxy(ref_params, degree=2)
kernel = HMCECS(HMC(model), num_blocks=2, proxy=proxy)

mcmc = MCMC(kernel, num_warmup=2, num_samples=2)
mcmc.run(prng_key(), x, y)
except ImportError:
pass

0 comments on commit f997da2

Please sign in to comment.