Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Gaussian state space model distribution. #1904

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,14 @@ GaussianRandomWalk
:show-inheritance:
:member-order: bysource

GaussianStateSpace
^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.GaussianStateSpace
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Gompertz
^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.Gompertz
Expand Down
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Exponential,
Gamma,
GaussianRandomWalk,
GaussianStateSpace,
Gompertz,
Gumbel,
HalfCauchy,
Expand Down Expand Up @@ -145,6 +146,7 @@
"GaussianCopula",
"GaussianCopulaBeta",
"GaussianRandomWalk",
"GaussianStateSpace",
"Geometric",
"GeometricLogits",
"GeometricProbs",
Expand Down
104 changes: 104 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
CorrMatrixCholeskyTransform,
ExpTransform,
PowerTransform,
RecursiveLinearTransform,
SigmoidTransform,
ZeroSumTransform,
)
Expand Down Expand Up @@ -550,6 +551,109 @@ def __init__(self, df, *, validate_args=None):
super(Chi2, self).__init__(0.5 * df, 0.5, validate_args=validate_args)


class GaussianStateSpace(TransformedDistribution):
    r"""
    Gaussian state space model.

    .. math::
        \mathbf{z}_{t} &= \mathbf{A} \mathbf{z}_{t - 1} + \boldsymbol{\epsilon}_t\\
        &= \sum_{k=1}^{t} \mathbf{A}^{t-k} \boldsymbol{\epsilon}_k,

    where :math:`\mathbf{z}_t` is the state vector at step :math:`t`, :math:`\mathbf{A}`
    is the transition matrix, and :math:`\boldsymbol\epsilon_k` is the innovation noise
    at step :math:`k`. The second line holds because the series starts from
    :math:`\mathbf{z}_0 = 0`.

    The innovation noise is parameterized by whichever of ``covariance_matrix``,
    ``precision_matrix`` and ``scale_tril`` is not ``None``; these are forwarded to
    :class:`MultivariateNormal` after being repeated along a new step axis.

    :param num_steps: Number of steps. Must be a positive Python ``int``.
    :param transition_matrix: State space transition matrix :math:`\mathbf{A}`. Must
        be a single (unbatched) square matrix.
    :param covariance_matrix: Covariance of the innovation noise
        :math:`\boldsymbol\epsilon`.
    :param precision_matrix: Precision matrix of the innovation noise
        :math:`\boldsymbol\epsilon`.
    :param scale_tril: Scale matrix of the innovation noise
        :math:`\boldsymbol\epsilon`.
    """

    arg_constraints = {
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
        "transition_matrix": constraints.real_matrix,
    }
    support = constraints.real_matrix
    pytree_aux_fields = ("num_steps",)

    def __init__(
        self,
        num_steps,
        transition_matrix,
        covariance_matrix=None,
        precision_matrix=None,
        scale_tril=None,
        *,
        validate_args=None,
    ):
        assert (
            isinstance(num_steps, int) and num_steps > 0
        ), "`num_steps` argument should be a positive integer."
        self.num_steps = num_steps
        # Check squareness as well as rank so the message matches what is verified.
        assert (
            transition_matrix.ndim == 2
            and transition_matrix.shape[-1] == transition_matrix.shape[-2]
        ), "`transition_matrix` argument should be a square matrix"
        self.transition_matrix = transition_matrix
        # Expand the covariance/precision/scale matrices to the right number of steps
        # by inserting a step axis in front of the two matrix axes.
        args = {
            "covariance_matrix": covariance_matrix,
            "precision_matrix": precision_matrix,
            "scale_tril": scale_tril,
        }
        args = {
            key: jnp.expand_dims(value, axis=-3).repeat(num_steps, axis=-3)
            for key, value in args.items()
            if value is not None
        }
        base_distribution = MultivariateNormal(**args)
        # Every step shares the same innovation scale, so keep the first slice along
        # the step axis as the canonical per-step parameter.
        self.scale_tril = base_distribution.scale_tril[..., 0, :, :]
        # Fold the step axis into the event so that one draw is a whole trajectory.
        base_distribution = base_distribution.to_event(1)
        transform = RecursiveLinearTransform(transition_matrix)
        super().__init__(base_distribution, transform, validate_args=validate_args)

    @property
    def mean(self):
        # The mean of the base distribution is zero and it has the right shape.
        return self.base_dist.mean

    @property
    def variance(self):
        # Given z_t = \sum_{k=1}^t A^{t-k} \epsilon_k, the covariance of the state
        # vector at step t is E[z_t transpose(z_t)] = \sum_{k,k'}^t A^{t-k}
        # E[\epsilon_k transpose(\epsilon_{k'})] transpose(A^{t-k'}). We only have
        # contributions for k = k' because innovations at different steps are
        # independent such that E[z_t transpose(z_t)] = \sum_k^t A^{t-k} @
        # covariance_matrix @ transpose(A^{t-k}). Using `scan` is an easy way to
        # evaluate this expression: it stacks A^j @ scale_tril for
        # j = 0, ..., num_steps - 1 along a new leading axis.
        _, scale_tril = scan(
            lambda carry, _: (self.transition_matrix @ carry, carry),
            self.scale_tril,
            jnp.arange(self.num_steps),
        )
        # Diagonal of each term, accumulated over the leading step axis.
        cumulative = jnp.diagonal(
            scale_tril @ scale_tril.mT, axis1=-1, axis2=-2
        ).cumsum(axis=0)
        # Move (rather than swap) the step axis to the second-to-last position so
        # the order of the batch axes is preserved for any number of batch
        # dimensions; swapping would permute them when there are two or more.
        return jnp.moveaxis(cumulative, 0, -2)

    @lazy_property
    def covariance_matrix(self):
        # Per-step innovation covariance reconstructed from its Cholesky factor.
        return self.scale_tril @ self.scale_tril.mT

    @lazy_property
    def precision_matrix(self):
        # Solve (L @ L^T) X = I using the lower Cholesky factor L = scale_tril.
        identity = jnp.broadcast_to(
            jnp.eye(self.scale_tril.shape[-1]), self.scale_tril.shape
        )
        return cho_solve((self.scale_tril, True), identity)


class GaussianRandomWalk(Distribution):
arg_constraints = {"scale": constraints.positive}
support = constraints.real_vector
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,7 +1345,7 @@ class RecursiveLinearTransform(Transform):
are vectors and :math:`A` is a square transition matrix. The series is initialized
by :math:`y_0 = 0`.

:param transition_matrix: Squared transition matrix :math:`A` for successive states
:param transition_matrix: Square transition matrix :math:`A` for successive states
or a batch of transition matrices.

**Example:**
Expand Down
35 changes: 34 additions & 1 deletion test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,18 @@ def get_sp_dist(jax_dist):
T(dist.Gamma, np.array([0.5, 1.3]), np.array([[1.0], [3.0]])),
T(dist.GaussianRandomWalk, 0.1, 10),
T(dist.GaussianRandomWalk, np.array([0.1, 0.3, 0.25]), 10),
T(
dist.GaussianStateSpace,
10,
np.array([[0.8, 0.2], [-0.1, 1.1]]),
np.array([[0.8, 0.2], [0.2, 0.7]]),
),
T(
dist.GaussianStateSpace,
5,
np.array([[0.8, 0.2], [-0.1, 1.1]]),
np.array([0.1, 0.3, 0.25])[:, None, None] * np.array([[0.8, 0.2], [0.2, 0.7]]),
),
T(
dist.GaussianCopulaBeta,
np.array([7.0, 2.0]),
Expand Down Expand Up @@ -1426,6 +1438,7 @@ def test_jit_log_likelihood(jax_dist, sp_dist, params):
if jax_dist.__name__ in (
"EulerMaruyama",
"GaussianRandomWalk",
"GaussianStateSpace",
"_ImproperWrapper",
"LKJ",
"LKJCholesky",
Expand Down Expand Up @@ -2093,7 +2106,10 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
and dist_args[i] == "base_dist"
):
continue
if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps":
if (
issubclass(jax_dist, (dist.GaussianRandomWalk, dist.GaussianStateSpace))
and dist_args[i] == "num_steps"
):
continue
if jax_dist is dist.ZeroSumNormal and dist_args[i] == "event_shape":
continue
Expand Down Expand Up @@ -3477,3 +3493,20 @@ def test_sine_bivariate_von_mises_norm(conc):
jnp.exp(dist.log_prob(mesh)) * (2 * jnp.pi) ** 2 / num_samples**2
).sum()
assert jnp.allclose(integral_torus, 1.0, rtol=1e-2)


def test_gaussian_random_walk_state_space_equivalence():
    # A Gaussian random walk is the special case of a state space model with a
    # single state and a unit transition matrix, so both distributions must agree
    # on variance, samples, and log density.
    step_scale = 0.3
    horizon = 4
    random_walk = dist.GaussianRandomWalk(step_scale, horizon)
    state_space = dist.GaussianStateSpace(
        horizon,
        jnp.eye(1),
        scale_tril=step_scale * jnp.eye(1),
    )

    # The state space model carries a trailing state axis of size one.
    state_space_variance = jnp.squeeze(state_space.variance, axis=-1)
    assert jnp.allclose(random_walk.variance, state_space_variance)

    # Reuse the same key so both distributions consume identical randomness.
    rng_key = jax.random.key(18)
    sample_shape = (3,)
    walk_draws = random_walk.sample(rng_key, sample_shape)
    space_draws = state_space.sample(rng_key, sample_shape)
    assert jnp.allclose(walk_draws, jnp.squeeze(space_draws, axis=-1))

    walk_log_prob = random_walk.log_prob(walk_draws)
    space_log_prob = state_space.log_prob(space_draws)
    assert jnp.allclose(walk_log_prob, space_log_prob)
Loading