Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
yayami3 committed Dec 16, 2023
1 parent 5e76e38 commit b1eef4c
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,8 @@ def model(data):
with numpyro.plate("size", np.size(data["y"])):
numpyro.sample("obs", dist.Normal(w * data["x"] + b, sigma), obs=data["y"])

hmc_kernel = HMC(model, num_steps=5)
with pytest.warns(UserWarning, match="If both"):
hmc_kernel = HMC(model, num_steps=5)
mcmc = MCMC(
hmc_kernel,
num_samples=1000,
Expand All @@ -1090,3 +1091,35 @@ def model(data):
mcmc.run(rng_key, data, extra_fields=("num_steps",))
num_steps_list = np.array(mcmc.get_extra_fields()["num_steps"])
assert all(step == 5 for step in num_steps_list)


@pytest.mark.parametrize("num_steps", [None, 10])
def test_none_trajectory_length(num_steps):
data = dict()
data["x"] = np.random.rand(10)
data["y"] = data["x"] + np.random.rand(10) * 0.1

def model(data):
w = numpyro.sample("w", dist.Normal(10, 1))
b = numpyro.sample("b", dist.Normal(1, 1))
sigma = numpyro.sample("sigma", dist.Gamma(1, 2))
with numpyro.plate("size", np.size(data["y"])):
numpyro.sample("obs", dist.Normal(w * data["x"] + b, sigma), obs=data["y"])

if num_steps is None:
with pytest.raises(ValueError, match="At least one of"):
hmc_kernel = HMC(model, num_steps=num_steps, trajectory_length=None)
else:
hmc_kernel = HMC(model, num_steps=num_steps, trajectory_length=None)

mcmc = MCMC(
hmc_kernel,
num_samples=1000,
num_warmup=1000,
num_chains=1,
)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, data)
mcmc.run(rng_key, data, extra_fields=("num_steps",))
num_steps_list = np.array(mcmc.get_extra_fields()["num_steps"])
assert all(step == num_steps for step in num_steps_list)

0 comments on commit b1eef4c

Please sign in to comment.