Skip to content

Commit

Permalink
Release patch for python 3.7 incompatibility (#490)
Browse files Browse the repository at this point in the history
* Release patch for python 3.7 incompatibility

* limit size of cache
  • Loading branch information
neerajprad authored and fehiepsi committed Dec 5, 2019
1 parent 63d68b2 commit 5f1723e
Show file tree
Hide file tree
Showing 18 changed files with 29 additions and 21 deletions.
2 changes: 1 addition & 1 deletion examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.2')
assert numpyro.__version__.startswith('0.2.3')
parser = argparse.ArgumentParser(description="Baseball batting average using HMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.2')
assert numpyro.__version__.startswith('0.2.3')
parser = argparse.ArgumentParser(description="Bayesian neural network example")
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/covtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.2.2')
assert numpyro.__version__.startswith('0.2.3')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-samples', default=100, type=int, help='number of samples')
parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")')
Expand Down
2 changes: 1 addition & 1 deletion examples/funnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.2')
assert numpyro.__version__.startswith('0.2.3')
parser = argparse.ArgumentParser(description="Non-centered reparameterization example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.2')
assert numpyro.__version__.startswith('0.2.3')
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.2.2')
assert numpyro.__version__.startswith('0.2.3')
parser = argparse.ArgumentParser(description='Semi-supervised Hidden Markov Model')
parser.add_argument('--num-categories', default=3, type=int)
parser.add_argument('--num-words', default=10, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def body_fn(i, val):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.2')
assert numpyro.__version__.startswith('0.2.3')
parser = argparse.ArgumentParser(description="Mini Pyro demo")
parser.add_argument("-f", "--full-pyro", action="store_true", default=False)
parser.add_argument("-n", "--num-steps", default=1001, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.2')
assert numpyro.__version__.startswith('0.2.3')
parser = argparse.ArgumentParser(description="NeuTra HMC")
parser.add_argument('-n', '--num-samples', nargs='?', default=10000, type=int)
parser.add_argument('--num-warmup', nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/sparse_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.2')
assert numpyro.__version__.startswith('0.2.3')
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/stochastic_volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.2')
assert numpyro.__version__.startswith('0.2.3')
parser = argparse.ArgumentParser(description="Stochastic Volatility Model")
parser.add_argument('-n', '--num-samples', nargs='?', default=600, type=int)
parser.add_argument('--num-warmup', nargs='?', default=600, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ucbadmit.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.2.2')
assert numpyro.__version__.startswith('0.2.3')
parser = argparse.ArgumentParser(description='UCBadmit gender discrimination using HMC')
parser.add_argument('-n', '--num-samples', nargs='?', default=2000, type=int)
parser.add_argument('--num-warmup', nargs='?', default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def reconstruct_img(epoch, rng_key):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.2.2')
assert numpyro.__version__.startswith('0.2.3')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-epochs', default=20, type=int, help='number of training epochs')
parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float, help='learning rate')
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/bayesian_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
"if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n",
" set_matplotlib_formats('svg')\n",
"\n",
"assert numpyro.__version__.startswith('0.2.2')"
"assert numpyro.__version__.startswith('0.2.3')"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/logistic_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"import numpyro.distributions as dist\n",
"from numpyro.examples.datasets import COVTYPE, load_dataset\n",
"from numpyro.infer import HMC, MCMC, NUTS\n",
"assert numpyro.__version__.startswith('0.2.2')\n",
"assert numpyro.__version__.startswith('0.2.3')\n",
"\n",
"# NB: replace gpu by cpu to run this notebook in cpu\n",
"numpyro.set_platform(\"gpu\")"
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/time_series_forecasting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
"if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n",
" set_matplotlib_formats('svg')\n",
"\n",
"assert numpyro.__version__.startswith('0.2.2')"
"assert numpyro.__version__.startswith('0.2.3')"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, col
return_last_val=True,
collection_size=self._collection_params["collection_size"],
progbar_desc=functools.partial(get_progbar_desc_str,
num_warmup=lower_idx),
lower_idx),
diagnostics_fn=diagnostics)
states, last_val = collect_vals
# Get first argument of type `HMCState`
Expand Down
16 changes: 12 additions & 4 deletions numpyro/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import namedtuple
from collections import namedtuple, OrderedDict
from contextlib import contextmanager
import os
import random
Expand Down Expand Up @@ -143,13 +143,21 @@ def identity(x):


def cached_by(outer_fn, *keys):
outer_fn._cache = getattr(outer_fn, '_cache', {})
# Restrict cache size to prevent ref cycles.
max_size = 8
outer_fn._cache = getattr(outer_fn, '_cache', OrderedDict())

def _wrapped(fn):
fn_cache = outer_fn._cache
if keys in fn_cache:
return fn_cache[keys]
fn_cache[keys] = fn
fn = fn_cache[keys]
# update position
del fn_cache[keys]
fn_cache[keys] = fn
else:
fn_cache[keys] = fn
if len(fn_cache) > max_size:
fn_cache.popitem(last=False)
return fn

return _wrapped
Expand Down
2 changes: 1 addition & 1 deletion numpyro/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.2.2'
__version__ = '0.2.3'

0 comments on commit 5f1723e

Please sign in to comment.