Skip to content

Commit

Permalink
Add Gaussian state space model distribution. (#1904)
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann authored Nov 12, 2024
1 parent c8a0990 commit a313a6e
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 2 deletions.
8 changes: 8 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,14 @@ GaussianRandomWalk
:show-inheritance:
:member-order: bysource

GaussianStateSpace
^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.GaussianStateSpace
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Gompertz
^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.Gompertz
Expand Down
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Exponential,
Gamma,
GaussianRandomWalk,
GaussianStateSpace,
Gompertz,
Gumbel,
HalfCauchy,
Expand Down Expand Up @@ -145,6 +146,7 @@
"GaussianCopula",
"GaussianCopulaBeta",
"GaussianRandomWalk",
"GaussianStateSpace",
"Geometric",
"GeometricLogits",
"GeometricProbs",
Expand Down
104 changes: 104 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
CorrMatrixCholeskyTransform,
ExpTransform,
PowerTransform,
RecursiveLinearTransform,
SigmoidTransform,
ZeroSumTransform,
)
Expand Down Expand Up @@ -550,6 +551,109 @@ def __init__(self, df, *, validate_args=None):
super(Chi2, self).__init__(0.5 * df, 0.5, validate_args=validate_args)


class GaussianStateSpace(TransformedDistribution):
r"""
Gaussian state space model.
.. math::
\mathbf{z}_{t} &= \mathbf{A} \mathbf{z}_{t - 1} + \boldsymbol{\epsilon}_t\\
&=\sum_{k=1} \mathbf{A}^{t-k} \boldsymbol{\epsilon}_t,
where :math:`\mathbf{z}_t` is the state vector at step :math:`t`, :math:`\mathbf{A}`
is the transition matrix, and :math:`\boldsymbol\epsilon` is the innovation noise.
:param num_steps: Number of steps.
:param transition_matrix: State space transition matrix :math:`\mathbf{A}`.
:param covariance_matrix: Covariance of the innovation noise
:math:`\boldsymbol\epsilon`.
:param precision_matrix: Precision matrix of the innovation noise
:math:`\boldsymbol\epsilon`.
:param scale_tril: Scale matrix of the innovation noise
:math:`\boldsymbol\epsilon`.
"""

arg_constraints = {
"covariance_matrix": constraints.positive_definite,
"precision_matrix": constraints.positive_definite,
"scale_tril": constraints.lower_cholesky,
"transition_matrix": constraints.real_matrix,
}
support = constraints.real_matrix
pytree_aux_fields = ("num_steps",)

def __init__(
self,
num_steps,
transition_matrix,
covariance_matrix=None,
precision_matrix=None,
scale_tril=None,
*,
validate_args=None,
):
assert (
isinstance(num_steps, int) and num_steps > 0
), "`num_steps` argument should be an positive integer."
self.num_steps = num_steps
assert (
transition_matrix.ndim == 2
), "`transition_matrix` argument should be a square matrix"
self.transition_matrix = transition_matrix
# Expand the covariance/presicion/scale matrices to the right number of steps.
args = {
"covariance_matrix": covariance_matrix,
"precision_matrix": precision_matrix,
"scale_tril": scale_tril,
}
args = {
key: jnp.expand_dims(value, axis=-3).repeat(num_steps, axis=-3)
for key, value in args.items()
if value is not None
}
base_distribution = MultivariateNormal(**args)
self.scale_tril = base_distribution.scale_tril[..., 0, :, :]
base_distribution = base_distribution.to_event(1)
transform = RecursiveLinearTransform(transition_matrix)
super().__init__(base_distribution, transform, validate_args=validate_args)

@property
def mean(self):
# The mean of the base distribution is zero and it has the right shape.
return self.base_dist.mean

@property
def variance(self):
# Given z_t = \sum_{k=1}^t A^{t-k} \epsilon_t, the covariance of the state
# vector at step t is E[z_t transpose(z_t)] = \sum_{k,k'}^t A^{t-k}
# E[\epsilon_k transpose(\epsilon_{k'})] transpose(A^{t-k'}). We only have
# contributions for k = k' because innovations at different steps are
# independent such that E[z_t transpose(z_t)] = \sum_k^t A^{t-k} @
# @ covariance_matrix @ transpose(A^{t-k}). Using `scan` is an easy way to
# evaluate this expression.
_, scale_tril = scan(
lambda carry, _: (self.transition_matrix @ carry, carry),
self.scale_tril,
jnp.arange(self.num_steps),
)
return (
jnp.diagonal(scale_tril @ scale_tril.mT, axis1=-1, axis2=-2)
.cumsum(axis=0)
.swapaxes(0, -2)
)

@lazy_property
def covariance_matrix(self):
return self.scale_tril @ self.scale_tril.mT

@lazy_property
def precision_matrix(self):
identity = jnp.broadcast_to(
jnp.eye(self.scale_tril.shape[-1]), self.scale_tril.shape
)
return cho_solve((self.scale_tril, True), identity)


class GaussianRandomWalk(Distribution):
arg_constraints = {"scale": constraints.positive}
support = constraints.real_vector
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,7 +1345,7 @@ class RecursiveLinearTransform(Transform):
are vectors and :math:`A` is a square transition matrix. The series is initialized
by :math:`y_0 = 0`.
:param transition_matrix: Squared transition matrix :math:`A` for successive states
:param transition_matrix: Square transition matrix :math:`A` for successive states
or a batch of transition matrices.
**Example:**
Expand Down
35 changes: 34 additions & 1 deletion test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,18 @@ def get_sp_dist(jax_dist):
T(dist.Gamma, np.array([0.5, 1.3]), np.array([[1.0], [3.0]])),
T(dist.GaussianRandomWalk, 0.1, 10),
T(dist.GaussianRandomWalk, np.array([0.1, 0.3, 0.25]), 10),
T(
dist.GaussianStateSpace,
10,
np.array([[0.8, 0.2], [-0.1, 1.1]]),
np.array([[0.8, 0.2], [0.2, 0.7]]),
),
T(
dist.GaussianStateSpace,
5,
np.array([[0.8, 0.2], [-0.1, 1.1]]),
np.array([0.1, 0.3, 0.25])[:, None, None] * np.array([[0.8, 0.2], [0.2, 0.7]]),
),
T(
dist.GaussianCopulaBeta,
np.array([7.0, 2.0]),
Expand Down Expand Up @@ -1426,6 +1438,7 @@ def test_jit_log_likelihood(jax_dist, sp_dist, params):
if jax_dist.__name__ in (
"EulerMaruyama",
"GaussianRandomWalk",
"GaussianStateSpace",
"_ImproperWrapper",
"LKJ",
"LKJCholesky",
Expand Down Expand Up @@ -2093,7 +2106,10 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
and dist_args[i] == "base_dist"
):
continue
if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps":
if (
issubclass(jax_dist, (dist.GaussianRandomWalk, dist.GaussianStateSpace))
and dist_args[i] == "num_steps"
):
continue
if jax_dist is dist.ZeroSumNormal and dist_args[i] == "event_shape":
continue
Expand Down Expand Up @@ -3477,3 +3493,20 @@ def test_sine_bivariate_von_mises_norm(conc):
jnp.exp(dist.log_prob(mesh)) * (2 * jnp.pi) ** 2 / num_samples**2
).sum()
assert jnp.allclose(integral_torus, 1.0, rtol=1e-2)


def test_gaussian_random_walk_state_space_equivalence():
# Gaussian random walks are state space models with one state and unit transition
# matrix. Here, we verify we get the expected results.
scale = 0.3
num_steps = 4
d1 = dist.GaussianRandomWalk(scale, num_steps)
d2 = dist.GaussianStateSpace(num_steps, jnp.eye(1), scale_tril=scale * jnp.eye(1))
assert jnp.allclose(d1.variance, jnp.squeeze(d2.variance, axis=-1))

key = jax.random.key(18)
x1 = d1.sample(key, (3,))
x2 = d2.sample(key, (3,))
assert jnp.allclose(x1, jnp.squeeze(x2, axis=-1))

assert jnp.allclose(d1.log_prob(x1), d2.log_prob(x2))

0 comments on commit a313a6e

Please sign in to comment.