Ensure BlockNeuralAutoregressiveTransform maps real->real (#1897)
* fix bnaf

* Docs and test

* Docs: add missing param

* Docs update
danielward27 authored Dec 5, 2024
1 parent b74c0e9 commit 7824aa3
Showing 2 changed files with 83 additions and 6 deletions.
41 changes: 39 additions & 2 deletions numpyro/nn/block_neural_arn.py
@@ -114,6 +114,32 @@ def apply_fun(params, inputs, **kwargs):
return init_fun, apply_fun


def LeakyTanh(min_grad: float = 0.01):
"""
Leaky Tanh nonlinearity :math:`y=\text{tanh}(x) + cx` with its log-Jacobian.
This choice when used in ``BlockNeuralAutoregressiveNN`` ensures the image of the
transformation is the set of real values (unlike ``Tanh``).
:param float min_grad: The minimum gradient value (:math:`c` above). Defaults to 0.01.
:return: an (`init_fn`, `apply_fn`) pair.
"""

def init_fun(rng, input_shape):
return input_shape, ()

def apply_fun(params, inputs, **kwargs):
x, logdet = inputs
out = jnp.tanh(x) + min_grad * x  # ensure the gradient is at least min_grad
tanh_logdet = -2 * (x + softplus(-2 * x) - jnp.log(2.0))
act_logdet = jnp.logaddexp(tanh_logdet, jnp.log(min_grad))
# Reshape to match logdet shape
act_logdet = act_logdet.reshape(logdet.shape[:-2] + (1, logdet.shape[-1]))
return out, logdet + act_logdet

return init_fun, apply_fun


def FanInResidualNormal():
"""
Similar to stax.FanInSum but also keeps track of log determinant of Jacobian.
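
The closed-form term used in LeakyTanh above, logaddexp(tanh_logdet, log(min_grad)), is the log-derivative of y = tanh(x) + c*x. A quick sanity check (an illustrative sketch, not part of this commit) compares it against JAX autodiff:

# Illustrative check (not from this commit): the closed-form log-derivative used in
# LeakyTanh, logaddexp(log sech^2(x), log c), should match autodiff of tanh(x) + c*x.
import jax.numpy as jnp
from jax import grad, vmap
from jax.nn import softplus

c = 0.01

def leaky_tanh(x):
    return jnp.tanh(x) + c * x

x = jnp.linspace(-5.0, 5.0, 11)
autodiff_logdet = jnp.log(vmap(grad(leaky_tanh))(x))

tanh_logdet = -2 * (x + softplus(-2 * x) - jnp.log(2.0))  # log sech^2(x), i.e. log d/dx tanh(x)
closed_form_logdet = jnp.logaddexp(tanh_logdet, jnp.log(c))

assert jnp.allclose(autodiff_logdet, closed_form_logdet, atol=1e-5)
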
@@ -156,10 +182,19 @@ def apply_fun(params, inputs, **kwargs):
return init_fun, apply_fun


def BlockNeuralAutoregressiveNN(input_dim, hidden_factors=[8, 8], residual=None):
def BlockNeuralAutoregressiveNN(
input_dim,
hidden_factors=[8, 8],
residual=None,
activation=None,
):
"""
An implementation of Block Neural Autoregressive neural network.
In contrast to the original paper, by default, we use ``LeakyTanh`` as the
activation, defined as :math:`y=Tanh(x) + cx` with :math:`c` being a small constant
(default to 0.01), which ensures the transform maps from real -> real.
**References**
1. *Block Neural Autoregressive Flow*,
@@ -170,13 +205,15 @@ def BlockNeuralAutoregressiveNN(input_dim, hidden_factors=[8, 8], residual=None)
input dimension. This corresponds to both :math:`a` and :math:`b` in reference [1].
The elements of hidden_factors must be integers.
:param str residual: Type of residual connections to use. One of `None`, `"normal"`, `"gated"`.
:param tuple activation: An (`init_fn`, `apply_fn`) pair defining the activation (e.g. ``LeakyTanh()`` or ``Tanh()``). Defaults to ``LeakyTanh``.
:return: an (`init_fn`, `update_fn`) pair.
"""
layers = []
in_factor = 1
activation = LeakyTanh() if activation is None else activation
for hidden_factor in hidden_factors:
layers.append(BlockMaskedDense(input_dim, in_factor, hidden_factor))
layers.append(Tanh())
layers.append(activation)
in_factor = hidden_factor
layers.append(BlockMaskedDense(input_dim, in_factor, 1))
arn = stax.serial(*layers)
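
For context, the new activation argument can be exercised end to end roughly as follows (an illustrative sketch based on the API shown in this diff; the dimensions and hyperparameters are arbitrary):

from functools import partial

import jax.numpy as jnp
from jax import random

from numpyro.distributions.flows import BlockNeuralAutoregressiveTransform
from numpyro.nn import BlockNeuralAutoregressiveNN
from numpyro.nn.block_neural_arn import LeakyTanh

input_dim = 3
# Build the conditioner with the LeakyTanh activation (the new default).
init_fn, apply_fn = BlockNeuralAutoregressiveNN(
    input_dim, hidden_factors=[8, 8], activation=LeakyTanh(0.01)
)
_, params = init_fn(random.PRNGKey(0), (input_dim,))
transform = BlockNeuralAutoregressiveTransform(partial(apply_fn, params))

x = jnp.array([0.5, -1.2, 3.0])
y = transform(x)  # elementwise monotone map, now real -> real
log_det = transform.log_abs_det_jacobian(x, y)  # log |det J|, summed over the event dimension

With Tanh() as the activation instead, the image of the transform would be a bounded interval, which is the issue this commit addresses.
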
48 changes: 44 additions & 4 deletions test/test_flows.py
@@ -9,13 +9,17 @@

from jax import jacfwd, random
from jax.example_libraries import stax
import jax.numpy as jnp
import jax.random as jr

from numpyro.distributions import Normal, TransformedDistribution
from numpyro.distributions.flows import (
BlockNeuralAutoregressiveTransform,
InverseAutoregressiveTransform,
)
from numpyro.distributions.util import matrix_to_tril_vec
from numpyro.nn import AutoregressiveNN, BlockNeuralAutoregressiveNN
from numpyro.nn.block_neural_arn import LeakyTanh, Tanh


def _make_iaf_args(input_dim, hidden_dims):
@@ -34,8 +38,12 @@ def _make_iaf_args(input_dim, hidden_dims):
return (partial(arn, init_params),)


def _make_bnaf_args(input_dim, hidden_factors):
arn_init, arn = BlockNeuralAutoregressiveNN(input_dim, hidden_factors)
def _make_bnaf_args(input_dim, hidden_factors, activation):
arn_init, arn = BlockNeuralAutoregressiveNN(
input_dim,
hidden_factors,
activation=activation,
)
_, rng_key_perm = random.split(random.PRNGKey(0))
_, init_params = arn_init(random.PRNGKey(0), (input_dim,))
return (partial(arn, init_params),)
@@ -46,10 +54,24 @@ def _make_bnaf_args(input_dim, hidden_factors):
[
(InverseAutoregressiveTransform, _make_iaf_args(5, hidden_dims=[10]), 5),
(InverseAutoregressiveTransform, _make_iaf_args(7, hidden_dims=[8, 9]), 7),
(BlockNeuralAutoregressiveTransform, _make_bnaf_args(7, hidden_factors=[4]), 7),
(
BlockNeuralAutoregressiveTransform,
_make_bnaf_args(7, hidden_factors=[2, 3]),
_make_bnaf_args(7, hidden_factors=[4], activation=LeakyTanh()),
7,
),
(
BlockNeuralAutoregressiveTransform,
_make_bnaf_args(7, hidden_factors=[2, 3], activation=LeakyTanh()),
7,
),
(
BlockNeuralAutoregressiveTransform,
_make_bnaf_args(7, hidden_factors=[4], activation=Tanh()),
7,
),
(
BlockNeuralAutoregressiveTransform,
_make_bnaf_args(7, hidden_factors=[2, 3], activation=Tanh()),
7,
),
],
@@ -91,3 +113,21 @@ def test_flows(flow_class, flow_args, input_dim, batch_shape):

assert np.sum(np.abs(np.triu(jac, 1))) == 0.00
assert np.all(np.abs(matrix_to_tril_vec(jac)) > 0)


def test_bnaf_normalization():
dim = (1,)
x = jnp.linspace(-1000, 1000, 5000)[:, None]

init_fn, apply_fn = BlockNeuralAutoregressiveNN(dim[0], activation=LeakyTanh(0.1))
params = init_fn(jr.PRNGKey(0), (1,))[1]
arn = partial(apply_fn, params)
bnaf = BlockNeuralAutoregressiveTransform(arn)
dist = TransformedDistribution(Normal(jnp.zeros(dim), 0.5), bnaf.inv)
probs = jnp.exp(dist.log_prob(x))
probs, x = jnp.squeeze(probs), jnp.squeeze(x)

# Rough integral
integral = jnp.trapezoid(probs, x)
assert integral > 0.9
assert integral < 1.1
