Skip to content

Commit

Permalink
Bump to 0.9.0 (#1310)
Browse files Browse the repository at this point in the history
* Add loose strategy for MCMC

* merge svi and mcmc plate warning strategies

* fix failing tests

* validate model accross ELBOs

* update vae example

* fix typos

* Bump to version 0.9.0

* Fix failing tests

* Fix warnings in tests/examples

* relax funsor requirement

* Move optax_to_numpyro to optim

* skip prodlda test on CI
  • Loading branch information
fehiepsi authored Jan 29, 2022
1 parent 5d9e033 commit 7084aaa
Show file tree
Hide file tree
Showing 42 changed files with 178 additions and 241 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ jobs:
run: |
sudo apt install -y pandoc gsfonts
python -m pip install --upgrade pip
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install jaxlib
pip install jax
pip install .[doc,test]
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install -r docs/requirements.txt
pip freeze
- name: Lint with flake8
Expand Down Expand Up @@ -64,10 +64,10 @@ jobs:
python -m pip install --upgrade pip
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install jaxlib
pip install jax
pip install .[dev,test]
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip freeze
- name: Test with pytest
run: |
Expand All @@ -93,10 +93,10 @@ jobs:
python -m pip install --upgrade pip
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install jaxlib
pip install jax
pip install .[dev,test]
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip freeze
- name: Test with pytest
run: |
Expand Down Expand Up @@ -129,10 +129,10 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install jaxlib
pip install jax
pip install .[dev,examples,test]
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip freeze
- name: Test with pytest
run: |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/optimizers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ SM3

Optax support
-------------
.. autofunction:: numpyro.contrib.optim.optax_to_numpyro
.. autofunction:: numpyro.optim.optax_to_numpyro
2 changes: 1 addition & 1 deletion examples/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Bayesian Models of Annotation")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ar2.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="AR2 example")
parser.add_argument("--num-data", nargs="?", default=142, type=int)
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Baseball batting average using MCMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Bayesian neural network example")
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
30 changes: 25 additions & 5 deletions examples/capture_recapture.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ def transition_fn(carry, y):
with handlers.mask(mask=first_capture_mask):
mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)
# NumPyro exactly sums out the discrete states z_t.
z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
z = numpyro.sample(
"z",
dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
infer={"enumerate": "parallel"},
)
mu_y_t = rho * z
numpyro.sample(
"y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
Expand Down Expand Up @@ -112,7 +116,11 @@ def transition_fn(carry, y):
with handlers.mask(mask=first_capture_mask):
mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
# NumPyro exactly sums out the discrete states z_t.
z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
z = numpyro.sample(
"z",
dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
infer={"enumerate": "parallel"},
)
mu_y_t = rho * z
numpyro.sample(
"y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
Expand Down Expand Up @@ -160,7 +168,11 @@ def transition_fn(carry, y):
with handlers.mask(mask=first_capture_mask):
mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
# NumPyro exactly sums out the discrete states z_t.
z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
z = numpyro.sample(
"z",
dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
infer={"enumerate": "parallel"},
)
mu_y_t = rho * z
numpyro.sample(
"y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
Expand Down Expand Up @@ -202,7 +214,11 @@ def transition_fn(carry, y):
with handlers.mask(mask=first_capture_mask):
mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)
# NumPyro exactly sums out the discrete states z_t.
z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
z = numpyro.sample(
"z",
dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
infer={"enumerate": "parallel"},
)
mu_y_t = rho * z
numpyro.sample(
"y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
Expand Down Expand Up @@ -249,7 +265,11 @@ def transition_fn(carry, y):
with handlers.mask(mask=first_capture_mask):
mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
# NumPyro exactly sums out the discrete states z_t.
z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
z = numpyro.sample(
"z",
dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
infer={"enumerate": "parallel"},
)
mu_y_t = rho * z
numpyro.sample(
"y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
Expand Down
2 changes: 1 addition & 1 deletion examples/covtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument(
"-n", "--num-samples", default=1000, type=int, help="number of samples"
Expand Down
2 changes: 1 addition & 1 deletion examples/funnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(
description="Non-centered reparameterization example"
)
Expand Down
4 changes: 2 additions & 2 deletions examples/gaussian_shells.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def run_inference(args, data):
num_warmup=args.num_warmup,
num_samples=args.num_samples,
)
mcmc.run(random.PRNGKey(2), **data)
mcmc.run(random.PRNGKey(2), **data, enum=args.enum)
mcmc.print_summary()
mcmc_samples = mcmc.get_samples()

Expand Down Expand Up @@ -123,7 +123,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Nested sampler for Gaussian shells")
parser.add_argument("-n", "--num-samples", nargs="?", default=10000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Semi-supervised Hidden Markov Model")
parser.add_argument("--num-categories", default=3, type=int)
parser.add_argument("--num-words", default=10, type=int)
Expand Down
38 changes: 31 additions & 7 deletions examples/hmm_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def transition_fn(carry, y):
x_prev, t = carry
with numpyro.plate("sequences", num_sequences, dim=-2):
with mask(mask=(t < lengths)[..., None]):
x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
x = numpyro.sample(
"x",
dist.Categorical(probs_x[x_prev]),
infer={"enumerate": "parallel"},
)
with numpyro.plate("tones", data_dim, dim=-1):
numpyro.sample("y", dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=y)
return (x, t + 1), None
Expand Down Expand Up @@ -127,7 +131,11 @@ def transition_fn(carry, y):
x_prev, y_prev, t = carry
with numpyro.plate("sequences", num_sequences, dim=-2):
with mask(mask=(t < lengths)[..., None]):
x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
x = numpyro.sample(
"x",
dist.Categorical(probs_x[x_prev]),
infer={"enumerate": "parallel"},
)
# Note the broadcasting tricks here: to index probs_y on tensors x and y,
# we also need a final tensor for the tones dimension. This is conveniently
# provided by the plate associated with that dimension.
Expand Down Expand Up @@ -175,8 +183,16 @@ def transition_fn(carry, y):
w_prev, x_prev, t = carry
with numpyro.plate("sequences", num_sequences, dim=-2):
with mask(mask=(t < lengths)[..., None]):
w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
w = numpyro.sample(
"w",
dist.Categorical(probs_w[w_prev]),
infer={"enumerate": "parallel"},
)
x = numpyro.sample(
"x",
dist.Categorical(probs_x[x_prev]),
infer={"enumerate": "parallel"},
)
# Note the broadcasting tricks here: to index probs_y on tensors x and y,
# we also need a final tensor for the tones dimension. This is conveniently
# provided by the plate associated with that dimension.
Expand Down Expand Up @@ -224,8 +240,16 @@ def transition_fn(carry, y):
w_prev, x_prev, t = carry
with numpyro.plate("sequences", num_sequences, dim=-2):
with mask(mask=(t < lengths)[..., None]):
w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
x = numpyro.sample("x", dist.Categorical(Vindex(probs_x)[w, x_prev]))
w = numpyro.sample(
"w",
dist.Categorical(probs_w[w_prev]),
infer={"enumerate": "parallel"},
)
x = numpyro.sample(
"x",
dist.Categorical(Vindex(probs_x)[w, x_prev]),
infer={"enumerate": "parallel"},
)
with numpyro.plate("tones", data_dim, dim=-1) as tones:
numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tones]), obs=y)
return (w, x, t + 1), None
Expand Down Expand Up @@ -275,7 +299,7 @@ def transition_fn(carry, y):
with mask(mask=(t < lengths)[..., None]):
probs_x_t = Vindex(probs_x)[x_prev, x_curr]
x_prev, x_curr = x_curr, numpyro.sample(
"x", dist.Categorical(probs_x_t)
"x", dist.Categorical(probs_x_t), infer={"enumerate": "parallel"}
)
with numpyro.plate("tones", data_dim, dim=-1):
probs_y_t = probs_y[x_curr.squeeze(-1)]
Expand Down
2 changes: 1 addition & 1 deletion examples/holt_winters.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Holt-Winters")
parser.add_argument("--T", nargs="?", default=6, type=int)
parser.add_argument("--future", nargs="?", default=1, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/horseshoe_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Horseshoe regression example")
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/hsgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_thanksgiving_days(dates):

def get_floating_days_indicators(dates):
def encode(x):
return jnp.array(x.values, dtype=jnp.int64)
return jnp.array(x.values, dtype=jnp.result_type(int))

return {
"labour_days_indicator": encode(get_labour_days(dates)),
Expand Down
2 changes: 1 addition & 1 deletion examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def body_fn(i, val):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Mini Pyro demo")
parser.add_argument("-f", "--full-pyro", action="store_true", default=False)
parser.add_argument("-n", "--num-steps", default=1001, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="NeuTra HMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=4000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Predator-Prey Model")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/prodlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(
description="Probabilistic topic modelling with Flax and Haiku"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/proportion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Testing whether ")
parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/sparse_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/stochastic_volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Stochastic Volatility Model")
parser.add_argument("-n", "--num-samples", nargs="?", default=600, type=int)
parser.add_argument("--num-warmup", nargs="?", default=600, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/thompson_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="Thompson sampling example")
parser.add_argument(
"--num-random", nargs="?", default=2, type=int, help="number of random draws"
Expand Down
2 changes: 1 addition & 1 deletion examples/ucbadmit.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(
description="UCBadmit gender discrimination using HMC"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def reconstruct_img(epoch, rng_key):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
assert numpyro.__version__.startswith("0.9.0")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument(
"-n", "--num-epochs", default=15, type=int, help="number of training epochs"
Expand Down
Loading

0 comments on commit 7084aaa

Please sign in to comment.