From 5f1723e935bff7ee5a67530a181e6896b69d63dd Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 4 Dec 2019 22:10:25 -0800 Subject: [PATCH] Release patch for python 3.7 incompatibility (#490) * Release patch for python 3.7 incompatibility * limit size of cache --- examples/baseball.py | 2 +- examples/bnn.py | 2 +- examples/covtype.py | 2 +- examples/funnel.py | 2 +- examples/gp.py | 2 +- examples/hmm.py | 2 +- examples/minipyro.py | 2 +- examples/neutra.py | 2 +- examples/sparse_regression.py | 2 +- examples/stochastic_volatility.py | 2 +- examples/ucbadmit.py | 2 +- examples/vae.py | 2 +- notebooks/source/bayesian_regression.ipynb | 2 +- notebooks/source/logistic_regression.ipynb | 2 +- notebooks/source/time_series_forecasting.ipynb | 2 +- numpyro/infer/mcmc.py | 2 +- numpyro/util.py | 16 ++++++++++++---- numpyro/version.py | 2 +- 18 files changed, 29 insertions(+), 21 deletions(-) diff --git a/examples/baseball.py b/examples/baseball.py index b0478600e..6f4711f98 100644 --- a/examples/baseball.py +++ b/examples/baseball.py @@ -182,7 +182,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.2.2') + assert numpyro.__version__.startswith('0.2.3') parser = argparse.ArgumentParser(description="Baseball batting average using HMC") parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int) parser.add_argument("--num-warmup", nargs='?', default=1500, type=int) diff --git a/examples/bnn.py b/examples/bnn.py index 731c78d74..56a4f8d96 100644 --- a/examples/bnn.py +++ b/examples/bnn.py @@ -133,7 +133,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.2.2') + assert numpyro.__version__.startswith('0.2.3') parser = argparse.ArgumentParser(description="Bayesian neural network example") parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int) parser.add_argument("--num-warmup", nargs='?', default=1000, type=int) diff --git a/examples/covtype.py b/examples/covtype.py index de1eba4f4..332a59d91 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -57,7 +57,7 @@ def main(args): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.2.2') + assert numpyro.__version__.startswith('0.2.3') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-samples', default=100, type=int, help='number of samples') parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")') diff --git a/examples/funnel.py b/examples/funnel.py index 796e510f3..a06514bd7 100644 --- a/examples/funnel.py +++ b/examples/funnel.py @@ -84,7 +84,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.2.2') + assert numpyro.__version__.startswith('0.2.3') parser = argparse.ArgumentParser(description="Non-centered reparameterization example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs='?', default=1000, type=int) diff --git a/examples/gp.py b/examples/gp.py index d2c55fe72..4171945f7 100644 --- a/examples/gp.py +++ b/examples/gp.py @@ -126,7 +126,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.2.2') + assert numpyro.__version__.startswith('0.2.3') parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs='?', default=1000, type=int) diff --git a/examples/hmm.py b/examples/hmm.py index 023b8c020..96450c47d 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -180,7 +180,7 @@ def main(args): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.2.2') + assert numpyro.__version__.startswith('0.2.3') parser = argparse.ArgumentParser(description='Semi-supervised Hidden Markov Model') parser.add_argument('--num-categories', default=3, type=int) parser.add_argument('--num-words', default=10, type=int) diff --git a/examples/minipyro.py b/examples/minipyro.py index 70fd61b66..37411cdd3 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -55,7 +55,7 @@ def body_fn(i, val): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.2.2') + assert numpyro.__version__.startswith('0.2.3') parser = argparse.ArgumentParser(description="Mini Pyro demo") parser.add_argument("-f", "--full-pyro", action="store_true", default=False) parser.add_argument("-n", "--num-steps", default=1001, type=int) diff --git a/examples/neutra.py b/examples/neutra.py index 388fc650e..0c529a2c8 100644 --- a/examples/neutra.py +++ b/examples/neutra.py @@ -151,7 +151,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.2.2') + assert numpyro.__version__.startswith('0.2.3') parser = argparse.ArgumentParser(description="NeuTra HMC") parser.add_argument('-n', '--num-samples', nargs='?', default=10000, type=int) parser.add_argument('--num-warmup', nargs='?', default=1000, type=int) diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index 895660edc..35be77bb8 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -319,7 +319,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.2.2') + assert numpyro.__version__.startswith('0.2.3') parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs='?', default=500, type=int) diff --git a/examples/stochastic_volatility.py b/examples/stochastic_volatility.py index f93a21131..da647fbdb 100644 --- a/examples/stochastic_volatility.py +++ b/examples/stochastic_volatility.py @@ -108,7 +108,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.2.2') + assert numpyro.__version__.startswith('0.2.3') parser = argparse.ArgumentParser(description="Stochastic Volatility Model") parser.add_argument('-n', '--num-samples', nargs='?', default=600, type=int) parser.add_argument('--num-warmup', nargs='?', default=600, type=int) diff --git a/examples/ucbadmit.py b/examples/ucbadmit.py index 839371b9c..569073c4d 100644 --- a/examples/ucbadmit.py +++ b/examples/ucbadmit.py @@ -126,7 +126,7 @@ def main(args): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.2.2') + assert numpyro.__version__.startswith('0.2.3') parser = argparse.ArgumentParser(description='UCBadmit gender discrimination using HMC') parser.add_argument('-n', '--num-samples', nargs='?', default=2000, type=int) parser.add_argument('--num-warmup', nargs='?', default=500, type=int) diff --git a/examples/vae.py b/examples/vae.py index 3e20f57e6..76a82377c 100644 --- a/examples/vae.py +++ b/examples/vae.py @@ -128,7 +128,7 @@ def reconstruct_img(epoch, rng_key): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.2.2') + assert numpyro.__version__.startswith('0.2.3') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-epochs', default=20, type=int, help='number of training epochs') parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float, help='learning rate') diff --git a/notebooks/source/bayesian_regression.ipynb b/notebooks/source/bayesian_regression.ipynb index 97e3f213f..e28848654 100644 --- a/notebooks/source/bayesian_regression.ipynb +++ b/notebooks/source/bayesian_regression.ipynb @@ -66,7 +66,7 @@ "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats('svg')\n", "\n", - "assert numpyro.__version__.startswith('0.2.2')" + "assert numpyro.__version__.startswith('0.2.3')" ] }, { diff --git a/notebooks/source/logistic_regression.ipynb b/notebooks/source/logistic_regression.ipynb index 081341b57..67640e3d0 100644 --- a/notebooks/source/logistic_regression.ipynb +++ b/notebooks/source/logistic_regression.ipynb @@ -31,7 +31,7 @@ "import numpyro.distributions as dist\n", "from numpyro.examples.datasets import COVTYPE, load_dataset\n", "from numpyro.infer import HMC, MCMC, NUTS\n", - "assert numpyro.__version__.startswith('0.2.2')\n", + "assert numpyro.__version__.startswith('0.2.3')\n", "\n", "# NB: replace gpu by cpu to run this notebook in cpu\n", "numpyro.set_platform(\"gpu\")" diff --git a/notebooks/source/time_series_forecasting.ipynb b/notebooks/source/time_series_forecasting.ipynb index d24461f5d..3ccd28624 100644 --- a/notebooks/source/time_series_forecasting.ipynb +++ b/notebooks/source/time_series_forecasting.ipynb @@ -39,7 +39,7 @@ "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats('svg')\n", "\n", - "assert numpyro.__version__.startswith('0.2.2')" + "assert numpyro.__version__.startswith('0.2.3')" ] }, { diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 2a9572c45..93d4ab921 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -774,7 +774,7 @@ def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, col return_last_val=True, collection_size=self._collection_params["collection_size"], progbar_desc=functools.partial(get_progbar_desc_str, - num_warmup=lower_idx), + lower_idx), diagnostics_fn=diagnostics) states, last_val = collect_vals # Get first argument of type `HMCState` diff --git a/numpyro/util.py b/numpyro/util.py index 4e506ae73..696f1854c 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -1,4 +1,4 @@ -from collections import namedtuple +from collections import namedtuple, OrderedDict from contextlib import contextmanager import os import random @@ -143,13 +143,21 @@ def identity(x): def cached_by(outer_fn, *keys): - outer_fn._cache = getattr(outer_fn, '_cache', {}) + # Restrict cache size to prevent ref cycles. + max_size = 8 + outer_fn._cache = getattr(outer_fn, '_cache', OrderedDict()) def _wrapped(fn): fn_cache = outer_fn._cache if keys in fn_cache: - return fn_cache[keys] - fn_cache[keys] = fn + fn = fn_cache[keys] + # update position + del fn_cache[keys] + fn_cache[keys] = fn + else: + fn_cache[keys] = fn + if len(fn_cache) > max_size: + fn_cache.popitem(last=False) return fn return _wrapped diff --git a/numpyro/version.py b/numpyro/version.py index 020ed73d7..d93b5b242 100644 --- a/numpyro/version.py +++ b/numpyro/version.py @@ -1 +1 @@ -__version__ = '0.2.2' +__version__ = '0.2.3'