Skip to content

Commit

Permalink
Bump to 0.6.0 (#959)
Browse files Browse the repository at this point in the history
* sketch the plan

* add a running implementation (not working yet

* remove unnecessary changes

* sequential update

* temp save

* add a working implementation

* add 24d example

* fix lint

* test for the order

* merge master

* fix bug at reset momentum

* add various mh functions

* add various discrete gibbs function method

* change stay_prob to modified to avoid confusing users

* expose more information for mixed hmc

* sketch an implementation

* temp save

* temp save

* finish the implementation

* keep kinetic energy

* add temperature experiment

* add dual averaging

* add various debug statements

* fix bugs

* clean up and separating out clock adapter; but target distribution is wrong due to a bug somewhere

* clean up

* add comments and an example

* make sure forward mode work

* add docs for new HMC fields

* add tests for mixedhmc

* fix step_size bug

* use modified=False

* tests pass with the fix

* skip print summary

* adjust trajectory length

* port update_version script from Pyro

* pin jax/jaxlib versions

* run isort

* fix some issues during collection notes

* use result_type instead of canonicalize_dtype

* fix lint

* change get_dtype  to jnp.result_type

* add print summary

* fix compiling issue for mcmc

* also try to avoid compiling issue in other samplers

* also fix compiling issue in barkermh

* convert init params types to strong types

* address comments

* fix wrong docs

* run isort
  • Loading branch information
fehiepsi authored Mar 16, 2021
1 parent af06eda commit ecd6255
Show file tree
Hide file tree
Showing 41 changed files with 109 additions and 79 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ Pyro users will note that the API for model specification and inference is large

## Installation

> **Limited Windows Support:** Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this [JAX issue](https://github.com/google/jax/issues/438) for more details. Alternatively, you can install [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/) and use NumPyro on it as on a Linux system. See also [CUDA on Windows Subsystem for Linux](https://developer.nvidia.com/cuda/wsl) if you want to use GPUs on Windows.
> **Limited Windows Support:** Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this [JAX issue](https://github.com/google/jax/issues/438) for more details. Alternatively, you can install [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/) and use NumPyro on it as on a Linux system. See also [CUDA on Windows Subsystem for Linux](https://developer.nvidia.com/cuda/wsl) and [this forum post](https://forum.pyro.ai/t/numpyro-with-gpu-works-on-windows/2690) if you want to use GPUs on Windows.
To install NumPyro with a CPU version of JAX, you can use pip:

Expand Down
8 changes: 8 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,14 @@ real_vector
-----------
.. autodata:: numpyro.distributions.constraints.real_vector

softplus_positive
-----------------
.. autodata:: numpyro.distributions.constraints.softplus_positive

softplus_lower_cholesky
-----------------------
.. autodata:: numpyro.distributions.constraints.softplus_lower_cholesky

simplex
-------
.. autodata:: numpyro.distributions.constraints.simplex
Expand Down
2 changes: 2 additions & 0 deletions docs/source/mcmc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ MCMC Kernels

.. autofunction:: numpyro.infer.hmc.hmc.sample_kernel

.. autofunction:: numpyro.infer.hmc_gibbs.taylor_proxy

.. autodata:: numpyro.infer.barker.BarkerMHState

.. autodata:: numpyro.infer.hmc.HMCState
Expand Down
2 changes: 1 addition & 1 deletion examples/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.5.0")
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description="Bayesian Models of Annotation")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description="Baseball batting average using MCMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description="Bayesian neural network example")
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/covtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-samples', default=1000, type=int, help='number of samples')
parser.add_argument('--num-warmup', default=1000, type=int, help='number of warmup steps')
Expand Down
2 changes: 1 addition & 1 deletion examples/funnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description="Non-centered reparameterization example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description='Semi-supervised Hidden Markov Model')
parser.add_argument('--num-categories', default=3, type=int)
parser.add_argument('--num-words', default=10, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def body_fn(i, val):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description="Mini Pyro demo")
parser.add_argument("-f", "--full-pyro", action="store_true", default=False)
parser.add_argument("-n", "--num-steps", default=1001, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description="NeuTra HMC")
parser.add_argument('-n', '--num-samples', nargs='?', default=4000, type=int)
parser.add_argument('--num-warmup', nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description='Predator-Prey Model')
parser.add_argument('-n', '--num-samples', nargs='?', default=1000, type=int)
parser.add_argument('--num-warmup', nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/proportion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description='Testing whether ')
parser.add_argument('-n', '--num-samples', nargs='?', default=500, type=int)
parser.add_argument('--num-warmup', nargs='?', default=1500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/sparse_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/stochastic_volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description="Stochastic Volatility Model")
parser.add_argument('-n', '--num-samples', nargs='?', default=600, type=int)
parser.add_argument('--num-warmup', nargs='?', default=600, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ucbadmit.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description='UCBadmit gender discrimination using HMC')
parser.add_argument('-n', '--num-samples', nargs='?', default=2000, type=int)
parser.add_argument('--num-warmup', nargs='?', default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def reconstruct_img(epoch, rng_key):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.5.0')
assert numpyro.__version__.startswith('0.6.0')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-epochs', default=15, type=int, help='number of training epochs')
parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float, help='learning rate')
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/bayesian_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
"if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n",
" set_matplotlib_formats('svg')\n",
"\n",
"assert numpyro.__version__.startswith('0.5.0')"
"assert numpyro.__version__.startswith('0.6.0')"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/logistic_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"import numpyro.distributions as dist\n",
"from numpyro.examples.datasets import COVTYPE, load_dataset\n",
"from numpyro.infer import HMC, MCMC, NUTS\n",
"assert numpyro.__version__.startswith('0.5.0')\n",
"assert numpyro.__version__.startswith('0.6.0')\n",
"\n",
"# NB: replace gpu by cpu to run this notebook in cpu\n",
"numpyro.set_platform(\"gpu\")"
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/ordinal_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"from numpyro.infer import MCMC, NUTS\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"assert numpyro.__version__.startswith('0.5.0')"
"assert numpyro.__version__.startswith('0.6.0')"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion numpyro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
plate_stack,
prng_key,
sample,
subsample,
subsample
)
from numpyro.util import enable_x64, set_host_device_count, set_platform
from numpyro.version import __version__
Expand Down
3 changes: 1 addition & 2 deletions numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np

from jax.dtypes import canonicalize_dtype
import jax.numpy as jnp
from tensorflow_probability.substrates.jax import bijectors as tfb
from tensorflow_probability.substrates.jax import distributions as tfd
Expand Down Expand Up @@ -162,7 +161,7 @@ class OneHotCategorical(tfd.OneHotCategorical, TFPDistributionMixin):

def enumerate_support(self, expand=True):
n = self.event_shape[-1]
values = jnp.identity(n, dtype=canonicalize_dtype(self.dtype))
values = jnp.identity(n, dtype=jnp.result_type(self.dtype))
values = values.reshape((n,) + (1,) * len(self.batch_shape) + (n,))
if expand:
values = jnp.broadcast_to(values, (n,) + self.batch_shape + (n,))
Expand Down
15 changes: 7 additions & 8 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
binomial,
categorical,
clamp_probs,
get_dtype,
is_prng_key,
lazy_property,
multinomial,
Expand All @@ -66,7 +65,7 @@ def _to_probs_multinom(logits):


def _to_logits_multinom(probs):
minval = jnp.finfo(get_dtype(probs)).min
minval = jnp.finfo(jnp.result_type(probs)).min
return jnp.clip(jnp.log(probs), a_min=minval)


Expand Down Expand Up @@ -292,11 +291,11 @@ def logits(self):

@property
def mean(self):
return jnp.full(self.batch_shape, jnp.nan, dtype=get_dtype(self.probs))
return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.probs))

@property
def variance(self):
return jnp.full(self.batch_shape, jnp.nan, dtype=get_dtype(self.probs))
return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.probs))

@constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self):
Expand Down Expand Up @@ -340,11 +339,11 @@ def probs(self):

@property
def mean(self):
return jnp.full(self.batch_shape, jnp.nan, dtype=get_dtype(self.logits))
return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.logits))

@property
def variance(self):
return jnp.full(self.batch_shape, jnp.nan, dtype=get_dtype(self.logits))
return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.logits))

@constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self):
Expand Down Expand Up @@ -609,7 +608,7 @@ def __init__(self, probs, validate_args=None):
def sample(self, key, sample_shape=()):
assert is_prng_key(key)
probs = self.probs
dtype = get_dtype(probs)
dtype = jnp.result_type(probs)
shape = sample_shape + self.batch_shape
u = random.uniform(key, shape, dtype)
return jnp.floor(jnp.log1p(-u) / jnp.log1p(-probs))
Expand Down Expand Up @@ -649,7 +648,7 @@ def probs(self):
def sample(self, key, sample_shape=()):
assert is_prng_key(key)
logits = self.logits
dtype = get_dtype(logits)
dtype = jnp.result_type(logits)
shape = sample_shape + self.batch_shape
u = random.uniform(key, shape, dtype)
return jnp.floor(jnp.log1p(-u) / -softplus(logits))
Expand Down
13 changes: 3 additions & 10 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,14 @@
import numpy as np

from jax import lax, ops, tree_flatten, tree_map, vmap
from jax.dtypes import canonicalize_dtype
from jax.flatten_util import ravel_pytree
from jax.nn import softplus
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import expit, logit

from numpyro.distributions import constraints
from numpyro.distributions.util import (
get_dtype,
matrix_to_tril_vec,
signed_stick_breaking_tril,
sum_rightmost,
vec_to_tril_matrix
)
from numpyro.distributions.util import matrix_to_tril_vec, signed_stick_breaking_tril, sum_rightmost, vec_to_tril_matrix
from numpyro.util import not_jax_tracer

__all__ = [
Expand Down Expand Up @@ -51,7 +44,7 @@


def _clipped_expit(x):
finfo = jnp.finfo(get_dtype(x))
finfo = jnp.finfo(jnp.result_type(x))
return jnp.clip(expit(x), a_min=finfo.tiny, a_max=1. - finfo.eps)


Expand Down Expand Up @@ -654,7 +647,7 @@ def __call__(self, x):

def _inverse(self, y):
size = self.permutation.size
permutation_inv = ops.index_update(jnp.zeros(size, dtype=canonicalize_dtype(jnp.int64)),
permutation_inv = ops.index_update(jnp.zeros(size, dtype=jnp.result_type(int)),
self.permutation,
jnp.arange(size))
return y[..., permutation_inv]
Expand Down
9 changes: 2 additions & 7 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np

from jax import jit, lax, random, vmap
from jax.dtypes import canonicalize_dtype
from jax.lib import xla_bridge
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
Expand Down Expand Up @@ -253,10 +252,6 @@ def promote_shapes(*args, shape=()):
if len(s) < num_dims else arg for arg, s in zip(args, shapes)]


def get_dtype(x):
return canonicalize_dtype(lax.dtype(x))


def sum_rightmost(x, dim):
"""
Sum out ``dim`` many rightmost dimensions of a given tensor.
Expand Down Expand Up @@ -351,7 +346,7 @@ def logmatmulexp(x, y):


def clamp_probs(probs):
finfo = jnp.finfo(get_dtype(probs))
finfo = jnp.finfo(jnp.result_type(probs))
return jnp.clip(probs, a_min=finfo.tiny, a_max=1. - finfo.eps)


Expand Down Expand Up @@ -392,7 +387,7 @@ def von_mises_centered(key, concentration, shape=(), dtype=jnp.float64):
:return: centered samples from von Mises
"""
shape = shape or jnp.shape(concentration)
dtype = canonicalize_dtype(dtype)
dtype = jnp.result_type(dtype)
concentration = lax.convert_element_type(concentration, dtype)
concentration = jnp.broadcast_to(concentration, shape)
return _von_mises_centered(key, concentration, shape, dtype)
Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from numpyro.infer.barker import BarkerMH
from numpyro.infer.elbo import ELBO, RenyiELBO, Trace_ELBO, TraceMeanField_ELBO
from numpyro.infer.hmc import HMC, NUTS
from numpyro.infer.hmc_gibbs import HMCECS, DiscreteHMCGibbs, HMCGibbs
Expand All @@ -11,7 +12,6 @@
init_to_uniform,
init_to_value
)
from numpyro.infer.barker import BarkerMH
from numpyro.infer.mcmc import MCMC
from numpyro.infer.mixed_hmc import MixedHMC
from numpyro.infer.sa import SA
Expand Down
Loading

0 comments on commit ecd6255

Please sign in to comment.