Skip to content

Commit

Permalink
add diverging to extra_fields by default (#433)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored and neerajprad committed Nov 1, 2019
1 parent a2e5990 commit 57424f2
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 16 deletions.
6 changes: 1 addition & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Let us infer the values of the unknown parameters in our model by running MCMC u
We can print the summary of the MCMC run, and examine if we observed any divergences during inference:

```python
mcmc.print_summary()
>>> mcmc.print_summary()

mean std median 5.0% 95.0% n_eff r_hat
mu 3.94 2.81 3.16 0.03 9.28 114.51 1.06
Expand All @@ -69,8 +69,6 @@ mcmc.print_summary()
theta[6] 5.74 4.67 4.34 -1.92 13.25 58.42 1.05
theta[7] 4.29 4.63 3.23 -2.14 12.37 342.50 1.02

>>> print("Number of divergences: {}".format(sum(mcmc.get_extra_fields()['diverging'])))

Number of divergences: 139
```

Expand Down Expand Up @@ -104,8 +102,6 @@ The values above 1 for the split Gelman Rubin diagnostic (`r_hat`) indicates tha
theta[5] 3.92 4.43 4.06 -2.41 11.09 1179.74 1.00
theta[6] 5.88 4.84 5.34 -1.45 13.11 881.38 1.00
theta[7] 4.63 4.86 4.64 -3.57 11.80 1065.27 1.00

>>> print("Number of divergences: {}".format(sum(mcmc.get_extra_fields()['diverging'])))

Number of divergences: 0
```
Expand Down
2 changes: 1 addition & 1 deletion docs/source/diagnostics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ HPDI

Summary
-------
.. autofunction:: numpyro.diagnostics.summary
.. autofunction:: numpyro.diagnostics.print_summary
4 changes: 2 additions & 2 deletions examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpyro
from numpyro import optim
from numpyro.contrib.autoguide import AutoContinuousELBO, AutoIAFNormal
from numpyro.diagnostics import summary
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, SVI
Expand Down Expand Up @@ -90,7 +90,7 @@ def main(args):
zs = mcmc.get_samples()
print("Transform samples into unwarped space...")
samples = vmap(transformed_constrain_fn)(zs)
summary(tree_map(lambda x: x[None, ...], samples))
print_summary(tree_map(lambda x: x[None, ...], samples))
samples = samples['x'].copy()

# make plots
Expand Down
8 changes: 4 additions & 4 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
'gelman_rubin',
'hpdi',
'split_gelman_rubin',
'summary',
'print_summary',
]


Expand Down Expand Up @@ -212,7 +212,7 @@ def hpdi(x, prob=0.90, axis=0):
return onp.concatenate([hpd_left, hpd_right], axis=axis)


def summary(samples, prob=0.90, group_by_chain=True):
def print_summary(samples, prob=0.90, group_by_chain=True):
"""
Prints a summary table displaying diagnostics of ``samples`` from the
posterior. The diagnostics displayed are mean, standard deviation, median,
Expand Down Expand Up @@ -241,7 +241,7 @@ def summary(samples, prob=0.90, group_by_chain=True):
header_format = name_format + ' {:>9} {:>9} {:>9} {:>9} {:>9} {:>9} {:>9}'
columns = ['', 'mean', 'std', 'median', '{:.1f}%'.format(50 * (1 - prob)),
'{:.1f}%'.format(50 * (1 + prob)), 'n_eff', 'r_hat']
print('\n')
print()
print(header_format.format(*columns))

# XXX: consider to expose digits, depending on user requests
Expand All @@ -264,4 +264,4 @@ def summary(samples, prob=0.90, group_by_chain=True):
idx_str = '[{}]'.format(','.join(map(str, idx)))
print(row_format.format(name + idx_str, mean[idx], sd[idx], median[idx],
hpd[0][idx], hpd[1][idx], n_eff[idx], r_hat[idx]))
print('\n')
print()
11 changes: 7 additions & 4 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from jax.random import PRNGKey
from jax.tree_util import tree_flatten, tree_map, tree_multimap

from numpyro.diagnostics import summary
from numpyro.diagnostics import print_summary
from numpyro.infer.hmc_util import (
IntegratorState,
build_tree,
Expand Down Expand Up @@ -598,10 +598,10 @@ def _single_chain_mcmc(self, init, collect_fields=('z',), collect_warmup=False,
if len(collect_fields) == 1:
states = (states,)
states = dict(zip(collect_fields, states))
states['z'] = vmap(constrain_fn)(states['z']) if len(tree_flatten(states)[0]) > 0 else states['z']
states['z'] = vmap(constrain_fn)(states['z']) if len(tree_flatten(states['z'])[0]) > 0 else states['z']
return states

def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs):
def run(self, rng_key, *args, extra_fields=('diverging',), collect_warmup=False, init_params=None, **kwargs):
"""
Run the MCMC samplers and collect samples.
Expand Down Expand Up @@ -693,4 +693,7 @@ def get_extra_fields(self, group_by_chain=False):
return {k: v for k, v in states.items() if k != 'z'}

def print_summary(self, prob=0.9):
summary(self._states['z'], prob=prob)
print_summary(self._states['z'], prob=prob)
extra_fields = self.get_extra_fields()
if 'diverging' in extra_fields:
print("Number of divergences: {}".format(np.sum(extra_fields['diverging'])))

0 comments on commit 57424f2

Please sign in to comment.