Skip to content

Commit

Permalink
add missing docs and remove some TODOs (#529)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored and neerajprad committed Jan 23, 2020
1 parent e6b2027 commit 2e31678
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 5 deletions.
8 changes: 8 additions & 0 deletions docs/source/autoguide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ AutoContinuous
:show-inheritance:
:member-order: bysource

AutoBNAFNormal
--------------
.. autoclass:: numpyro.contrib.autoguide.AutoBNAFNormal
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

AutoDiagonalNormal
------------------
.. autoclass:: numpyro.contrib.autoguide.AutoDiagonalNormal
Expand Down
8 changes: 8 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,11 @@ InverseAutoregressiveTransform
:undoc-members:
:show-inheritance:
:member-order: bysource

BlockNeuralAutoregressiveTransform
----------------------------------
.. autoclass:: numpyro.distributions.flows.BlockNeuralAutoregressiveTransform
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
8 changes: 8 additions & 0 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ condition
:show-inheritance:
:member-order: bysource

mask
----
.. autoclass:: numpyro.handlers.mask
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

replay
------
.. autoclass:: numpyro.handlers.replay
Expand Down
4 changes: 4 additions & 0 deletions docs/source/primitives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ plate
-----
.. autoclass:: numpyro.primitives.plate

deterministic
-------------
.. autofunction:: numpyro.primitives.deterministic

factor
------
.. autofunction:: numpyro.primitives.factor
Expand Down
1 change: 0 additions & 1 deletion numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def categorical(key, p, shape=()):
return _categorical(key, p, shape)


# TODO: use this sampler in CategoricalLogits
# TODO: drop this for the next JAX release, see https://github.com/google/jax/pull/1855
def categorical_logits(key, logits, shape=()):
shape = shape or logits.shape[:-1]
Expand Down
3 changes: 3 additions & 0 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,9 @@ class SA(MCMCKernel):
subset of approximate posterior samples of size num_chains x num_samples
instead of num_chains x num_samples x adapt_state_size.
.. note:: We recommend to use this kernel with `progress_bar=False` in :class:`MCMC`
to reduce JAX's dispatch overhead.
**References:**
1. *Sample Adaptive MCMC* (https://papers.nips.cc/paper/9107-sample-adaptive-mcmc),
Expand Down
2 changes: 1 addition & 1 deletion numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def deterministic(name, value):
"""
Used to designate deterministic sites in the model. Note that most effect
handlers will not operate on deterministic sites (except
:function:`~numpyro.handlers.trace`), so deterministic sites should be
:func:`~numpyro.handlers.trace`), so deterministic sites should be
side-effect free. The use case for deterministic nodes is to record any
values in the model execution trace.
Expand Down
3 changes: 0 additions & 3 deletions test/test_hmc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,6 @@ def fn(vv_state):
assert tree.num_proposals > 10


# TODO: raise this warning issue upstream, the issue is at this line
# https://github.com/google/jax/blob/master/jax/numpy/lax_numpy.py#L2732
@pytest.mark.filterwarnings('ignore:Explicitly requested dtype float64')
@pytest.mark.parametrize('method', [consensus, parametric_draws])
@pytest.mark.parametrize('diagonal', [True, False])
def test_gaussian_subposterior(method, diagonal):
Expand Down

0 comments on commit 2e31678

Please sign in to comment.