Skip to content

Commit

Permalink
Make possible to fix the number of LF steps and tune the step size (#…
Browse files Browse the repository at this point in the history
…1698)

* make possible to fix the number of LF steps per sample and tune the step size

* fix typo

* fix version jaxns==2.2.6

* warn only when called from HMC

* add test

* lint
  • Loading branch information
yayami3 authored Dec 17, 2023
1 parent b16741c commit c914c53
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ funsor
ipython<=8.6.0 # strict the version for https://github.com/ipython/ipython/issues/13845
jax
jaxlib
jaxns>=2.0.1
jaxns==2.2.6
Jinja2<3.1
matplotlib
multipledispatch
Expand Down
29 changes: 26 additions & 3 deletions numpyro/infer/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import partial
import math
import os
import warnings

from jax import device_put, lax, random, vmap
from jax.flatten_util import ravel_pytree
Expand All @@ -19,7 +20,12 @@
warmup_adapter,
)
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model
from numpyro.infer.util import (
ParamInfo,
find_stack_level,
init_to_uniform,
initialize_model,
)
from numpyro.util import cond, fori_loop, identity, is_prng_key

HMCState = namedtuple(
Expand Down Expand Up @@ -54,6 +60,7 @@
- **trajectory_length** - The amount of time to run HMC dynamics in each sampling step.
This field is not used in NUTS.
- **num_steps** - Number of steps in the Hamiltonian trajectory (for diagnostics).
In HMC sampler, `trajectory_length` should be None for step_size to be adapted.
In NUTS sampler, the tree depth of a trajectory can be computed from this field
with `tree_depth = np.log2(num_steps).astype(int) + 1`.
- **accept_prob** - Acceptance probability of the proposal. Note that ``z``
Expand Down Expand Up @@ -362,8 +369,10 @@ def _hmc_next(
num_steps = 1
else:
num_steps = _get_num_steps(step_size, trajectory_length)
# makes sure trajectory length is constant, rather than step_size * num_steps
step_size = trajectory_length / num_steps

if trajectory_length is not None:
# makes sure trajectory length is constant, rather than step_size * num_steps
step_size = trajectory_length / num_steps
vv_state_new = fori_loop(
0,
num_steps,
Expand Down Expand Up @@ -618,6 +627,20 @@ def __init__(
):
if not (model is None) ^ (potential_fn is None):
raise ValueError("Only one of `model` or `potential_fn` must be specified.")
if type(self) is HMC:
if (num_steps is None) & (trajectory_length is None):
raise ValueError(
"At least one of `num_steps` or `trajectory_length` must be specified."
)
if (
adapt_step_size
& (num_steps is not None)
& (trajectory_length is not None)
):
warnings.warn(
"If both `num_steps` and `trajectory_length` are specified step size can't be adapted",
stacklevel=find_stack_level(),
)
self._model = model
self._potential_fn = potential_fn
self._kinetic_fn = (
Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def warmup(
**kwargs,
):
"""
Run the MCMC warmup adaptation phase. After this call, `self.warmup_state` will be set
Run the MCMC warmup adaptation phase. After this call, `self.post_warmup_state` will be set
and the :meth:`run` method will skip the warmup adaptation phase. To run `warmup` again
for the new data, it is required to run :meth:`warmup` again.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"flax",
"funsor>=0.4.1",
"graphviz",
"jaxns>=2.0.1",
"jaxns==2.2.6",
"matplotlib",
"optax>=0.0.6",
"pylab-sdk", # jaxns dependency
Expand Down
35 changes: 34 additions & 1 deletion test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,8 @@ def model(data):
with numpyro.plate("size", np.size(data["y"])):
numpyro.sample("obs", dist.Normal(w * data["x"] + b, sigma), obs=data["y"])

hmc_kernel = HMC(model, num_steps=5)
with pytest.warns(UserWarning, match="If both"):
hmc_kernel = HMC(model, num_steps=5)
mcmc = MCMC(
hmc_kernel,
num_samples=1000,
Expand All @@ -1090,3 +1091,35 @@ def model(data):
mcmc.run(rng_key, data, extra_fields=("num_steps",))
num_steps_list = np.array(mcmc.get_extra_fields()["num_steps"])
assert all(step == 5 for step in num_steps_list)


@pytest.mark.parametrize("num_steps", [None, 10])
def test_none_trajectory_length(num_steps):
data = dict()
data["x"] = np.random.rand(10)
data["y"] = data["x"] + np.random.rand(10) * 0.1

def model(data):
w = numpyro.sample("w", dist.Normal(10, 1))
b = numpyro.sample("b", dist.Normal(1, 1))
sigma = numpyro.sample("sigma", dist.Gamma(1, 2))
with numpyro.plate("size", np.size(data["y"])):
numpyro.sample("obs", dist.Normal(w * data["x"] + b, sigma), obs=data["y"])

if num_steps is None:
with pytest.raises(ValueError, match="At least one of"):
hmc_kernel = HMC(model, num_steps=num_steps, trajectory_length=None)
else:
hmc_kernel = HMC(model, num_steps=num_steps, trajectory_length=None)

mcmc = MCMC(
hmc_kernel,
num_samples=1000,
num_warmup=1000,
num_chains=1,
)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, data)
mcmc.run(rng_key, data, extra_fields=("num_steps",))
num_steps_list = np.array(mcmc.get_extra_fields()["num_steps"])
assert all(step == num_steps for step in num_steps_list)

0 comments on commit c914c53

Please sign in to comment.