Skip to content

Commit

Permalink
Bump up version to 0.2.1 (#435)
Browse files Browse the repository at this point in the history
* Bump up version to 0.2.1

* address comments

* use tuple instead of set
  • Loading branch information
neerajprad authored and fehiepsi committed Nov 1, 2019
1 parent 57424f2 commit cd53186
Show file tree
Hide file tree
Showing 20 changed files with 52 additions and 36 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ dist: xenial

install:
- pip install -U pip
# Keep track of pyro-api master branch
- pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
- pip install jaxlib
- pip install jax
- pip install .[examples,test]
Expand Down
46 changes: 30 additions & 16 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.0')
assert numpyro.__version__.startswith('0.2.1')
parser = argparse.ArgumentParser(description="Baseball batting average using HMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.0')
assert numpyro.__version__.startswith('0.2.1')
parser = argparse.ArgumentParser(description="Bayesian neural network example")
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/covtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.2.0')
assert numpyro.__version__.startswith('0.2.1')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-samples', default=100, type=int, help='number of samples')
parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")')
Expand Down
2 changes: 1 addition & 1 deletion examples/funnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.0')
assert numpyro.__version__.startswith('0.2.1')
parser = argparse.ArgumentParser(description="Non-centered reparameterization example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.0')
assert numpyro.__version__.startswith('0.2.1')
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.2.0')
assert numpyro.__version__.startswith('0.2.1')
parser = argparse.ArgumentParser(description='Semi-supervised Hidden Markov Model')
parser.add_argument('--num-categories', default=3, type=int)
parser.add_argument('--num-words', default=10, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def body_fn(i, val):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.0')
assert numpyro.__version__.startswith('0.2.1')
parser = argparse.ArgumentParser(description="Mini Pyro demo")
parser.add_argument("-f", "--full-pyro", action="store_true", default=False)
parser.add_argument("-n", "--num-steps", default=1001, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.0')
assert numpyro.__version__.startswith('0.2.1')
parser = argparse.ArgumentParser(description="NeuTra HMC")
parser.add_argument('-n', '--num-samples', nargs='?', default=10000, type=int)
parser.add_argument('--num-warmup', nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/sparse_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.0')
assert numpyro.__version__.startswith('0.2.1')
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/stochastic_volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.0')
assert numpyro.__version__.startswith('0.2.1')
parser = argparse.ArgumentParser(description="Stochastic Volatility Model")
parser.add_argument('-n', '--num-samples', nargs='?', default=3000, type=int)
parser.add_argument('--num-warmup', nargs='?', default=1500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ucbadmit.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.2.0')
assert numpyro.__version__.startswith('0.2.1')
parser = argparse.ArgumentParser(description='UCBadmit gender discrimination using HMC')
parser.add_argument('-n', '--num-samples', nargs='?', default=2000, type=int)
parser.add_argument('--num-warmup', nargs='?', default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def reconstruct_img(epoch, rng_key):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.2.0')
assert numpyro.__version__.startswith('0.2.1')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-epochs', default=20, type=int, help='number of training epochs')
parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float, help='learning rate')
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/bayesian_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
"\n",
"plt.style.use('bmh')\n",
"\n",
"assert numpyro.__version__.startswith('0.2.0')"
"assert numpyro.__version__.startswith('0.2.1')"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/logistic_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"import numpyro.distributions as dist\n",
"from numpyro.examples.datasets import COVTYPE, load_dataset\n",
"from numpyro.infer import HMC, MCMC, NUTS\n",
"assert numpyro.__version__.startswith('0.2.0')\n",
"assert numpyro.__version__.startswith('0.2.1')\n",
"\n",
"# NB: replace gpu by cpu to run this notebook in cpu\n",
"numpyro.set_platform(\"gpu\")"
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/time_series_forecasting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'svg'\n",
"\n",
"assert numpyro.__version__.startswith('0.2.0')"
"assert numpyro.__version__.startswith('0.2.1')"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,14 +601,14 @@ def _single_chain_mcmc(self, init, collect_fields=('z',), collect_warmup=False,
states['z'] = vmap(constrain_fn)(states['z']) if len(tree_flatten(states['z'])[0]) > 0 else states['z']
return states

def run(self, rng_key, *args, extra_fields=('diverging',), collect_warmup=False, init_params=None, **kwargs):
def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs):
"""
Run the MCMC samplers and collect samples.
:param random.PRNGKey rng_key: Random number generator key to be used for the sampling.
:param args: Arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init` method.
These are typically the arguments needed by the `model`.
:param extra_fields: Extra fields (aside from `z`) from :data:`numpyro.infer.mcmc.HMCState`
:param extra_fields: Extra fields (aside from `z`, `diverging`) from :data:`numpyro.infer.mcmc.HMCState`
to collect during the MCMC run.
:type extra_fields: tuple or list
:param bool collect_warmup: Whether to collect samples from the warmup phase. Defaults
Expand All @@ -635,7 +635,7 @@ def run(self, rng_key, *args, extra_fields=('diverging',), collect_warmup=False,
raise ValueError('`init_params` must have the same leading dimension'
' as `num_chains`.')
assert isinstance(extra_fields, (tuple, list))
collect_fields = ('z',) + tuple(extra_fields) if 'z' not in extra_fields else extra_fields
collect_fields = tuple(set(('z', 'diverging') + tuple(extra_fields)))
if self.num_chains == 1:
states_flat = self._single_chain_mcmc((rng_key, init_params), collect_fields, collect_warmup,
args, kwargs)
Expand Down
2 changes: 1 addition & 1 deletion numpyro/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.2.0'
__version__ = '0.2.1'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
'test': [
'flake8',
'pytest>=4.1',
'pyro-api@https://api.github.com/repos/pyro-ppl/pyro-api/tarball/master'
'pyro-api>=0.1.1'
],
'dev': ['ipython'],
'examples': ['matplotlib'],
Expand Down

0 comments on commit cd53186

Please sign in to comment.