Skip to content

Commit

Permalink
Add types for contrib.stochastic_support and improve the HSGP module (#…
Browse files Browse the repository at this point in the history
…1907)

* init

* better return hints

* type improvements

* better hints for jax arrays

* initializaze ell_

* undo change

* fix condition

* OMG stupid typo #facepalm!
  • Loading branch information
juanitorduz authored Nov 15, 2024
1 parent 0a2f6fe commit 66921bd
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 142 deletions.
50 changes: 25 additions & 25 deletions numpyro/contrib/hsgp/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from __future__ import annotations

from jaxlib.xla_extension import ArrayImpl

from jax import Array
import jax.numpy as jnp
from jax.typing import ArrayLike

import numpyro
from numpyro.contrib.hsgp.laplacian import eigenfunctions, eigenfunctions_periodic
Expand All @@ -22,26 +22,26 @@


def _non_centered_approximation(
phi: ArrayImpl, spd: ArrayImpl, m: int | list[int]
) -> ArrayImpl:
phi: ArrayLike, spd: ArrayLike, m: int | list[int]
) -> Array:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))

return phi @ (spd * beta)


def _centered_approximation(
phi: ArrayImpl, spd: ArrayImpl, m: int | list[int]
) -> ArrayImpl:
phi: ArrayLike, spd: ArrayLike, m: int | list[int]
) -> Array:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))

return phi @ beta


def linear_approximation(
phi: ArrayImpl, spd: ArrayImpl, m: int | list[int], non_centered: bool = True
) -> ArrayImpl:
phi: ArrayLike, spd: ArrayLike, m: int | list[int], non_centered: bool = True
) -> Array:
"""
Linear approximation formula of the Hilbert space Gaussian process.
Expand All @@ -52,27 +52,27 @@ def linear_approximation(
1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
:param ArrayImpl phi: laplacian eigenfunctions
:param ArrayImpl spd: square root of the diagonal of the spectral density evaluated at square
:param ArrayLike phi: laplacian eigenfunctions
:param ArrayLike spd: square root of the diagonal of the spectral density evaluated at square
root of the first `m` eigenvalues.
:param int | list[int] m: number of eigenfunctions in the approximation
:param bool non_centered: whether to use a non-centered parameterization
:return: The low-rank approximation linear model
:rtype: ArrayImpl
:rtype: Array
"""
if non_centered:
return _non_centered_approximation(phi, spd, m)
return _centered_approximation(phi, spd, m)


def hsgp_squared_exponential(
x: ArrayImpl,
x: ArrayLike,
alpha: float,
length: float,
ell: float | int | list[float | int],
m: int | list[int],
non_centered: bool = True,
) -> ArrayImpl:
) -> Array:
"""
Hilbert space Gaussian process approximation using the squared exponential kernel.
Expand All @@ -88,7 +88,7 @@ def hsgp_squared_exponential(
2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
:param ArrayImpl x: input data
:param ArrayLike x: input data
:param float alpha: amplitude of the squared exponential kernel
:param float length: length scale of the squared exponential kernel
:param float | int | list[float | int] ell: positive value that parametrizes the length of the D-dimensional box so
Expand All @@ -99,9 +99,9 @@ def hsgp_squared_exponential(
If an integer, the same number of eigenvalues is computed in each dimension.
:param bool non_centered: whether to use a non-centered parameterization. By default, it is set to True
:return: the low-rank approximation linear model
:rtype: ArrayImpl
:rtype: Array
"""
dim = x.shape[-1] if x.ndim > 1 else 1
dim = jnp.shape(x)[-1] if jnp.ndim(x) > 1 else 1
phi = eigenfunctions(x=x, ell=ell, m=m)
spd = jnp.sqrt(
diag_spectral_density_squared_exponential(
Expand All @@ -114,14 +114,14 @@ def hsgp_squared_exponential(


def hsgp_matern(
x: ArrayImpl,
x: ArrayLike,
nu: float,
alpha: float,
length: float,
ell: float | int | list[float | int],
m: int | list[int],
non_centered: bool = True,
):
) -> Array:
"""
Hilbert space Gaussian process approximation using the Matérn kernel.
Expand All @@ -137,7 +137,7 @@ def hsgp_matern(
2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
:param ArrayImpl x: input data
:param ArrayLike x: input data
:param float nu: smoothness parameter
:param float alpha: amplitude of the squared exponential kernel
:param float length: length scale of the squared exponential kernel
Expand All @@ -149,9 +149,9 @@ def hsgp_matern(
If an integer, the same number of eigenvalues is computed in each dimension.
:param bool non_centered: whether to use a non-centered parameterization. By default, it is set to True.
:return: the low-rank approximation linear model
:rtype: ArrayImpl
:rtype: Array
"""
dim = x.shape[-1] if x.ndim > 1 else 1
dim = jnp.shape(x)[-1] if jnp.ndim(x) > 1 else 1
phi = eigenfunctions(x=x, ell=ell, m=m)
spd = jnp.sqrt(
diag_spectral_density_matern(
Expand All @@ -164,8 +164,8 @@ def hsgp_matern(


def hsgp_periodic_non_centered(
x: ArrayImpl, alpha: float, length: float, w0: float, m: int
) -> ArrayImpl:
x: ArrayLike, alpha: float, length: float, w0: float, m: int
) -> Array:
"""
Low rank approximation for the periodic squared exponential kernel in the non-centered parametrization.
Expand All @@ -176,13 +176,13 @@ def hsgp_periodic_non_centered(
1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
:param ArrayImpl x: input data
:param ArrayLike x: input data
:param float alpha: amplitude
:param float length: length scale
:param float w0: frequency of the periodic kernel
:param int m: number of eigenvalues to compute and include in the approximation
:return: the low-rank approximation linear model
:rtype: ArrayImpl
:rtype: Array
"""
q2 = diag_spectral_density_periodic(alpha=alpha, length=length, m=m)
cosines, sines = eigenfunctions_periodic(x=x, w0=w0, m=m)
Expand Down
65 changes: 31 additions & 34 deletions numpyro/contrib/hsgp/laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,14 @@

from __future__ import annotations

from typing import get_args

from jaxlib.xla_extension import ArrayImpl
import numpy as np

from jax import Array
import jax.numpy as jnp

from numpyro.contrib.hsgp.util import ARRAY_TYPE
from jax.typing import ArrayLike


def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl:
def eigenindices(m: list[int] | int, dim: int) -> Array:
"""Returns the indices of the first :math:`D \\times m^\\star` eigenvalues of the laplacian operator.
.. math::
Expand All @@ -35,7 +33,7 @@ def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl:
:param int dim: The dimension of the space.
:returns: An array of the indices of the first :math:`D \\times m^\\star` eigenvalues.
:rtype: ArrayImpl
:rtype: Array
**Examples:**
Expand Down Expand Up @@ -78,8 +76,8 @@ def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl:


def sqrt_eigenvalues(
ell: int | float | list[int | float], m: list[int] | int, dim: int
) -> ArrayImpl:
ell: ArrayLike | list[int | float], m: list[int] | int, dim: int
) -> Array:
"""
The first :math:`m^\\star \\times D` square root of eigenvalues of the laplacian operator in
:math:`[-L_1, L_1] \\times ... \\times [-L_D, L_D]`. See Eq. (56) in [1].
Expand All @@ -96,16 +94,14 @@ def sqrt_eigenvalues(
:param int dim: The dimension of the space.
:returns: An array of the first :math:`m^\\star \\times D` square root of eigenvalues.
:rtype: ArrayImpl
:rtype: Array
"""
ell_ = _convert_ell(ell, dim)
S = eigenindices(m, dim)
return S * jnp.pi / 2 / ell_ # dim x prod(m) array of eigenvalues


def eigenfunctions(
x: ArrayImpl, ell: float | list[float], m: int | list[int]
) -> ArrayImpl:
def eigenfunctions(x: ArrayLike, ell: float | list[float], m: int | list[int]) -> Array:
"""
The first :math:`m^\\star` eigenfunctions of the laplacian operator in
:math:`[-L_1, L_1] \\times ... \\times [-L_D, L_D]`
Expand Down Expand Up @@ -141,7 +137,7 @@ def eigenfunctions(
1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
Stat Comput 30, 419-446 (2020)
:param ArrayImpl x: The points at which to evaluate the eigenfunctions.
:param ArrayLike x: The points at which to evaluate the eigenfunctions.
If `x` is 1D the problem is assumed unidimensional.
Otherwise, the dimension of the input space is inferred as the last dimension of `x`.
Other dimensions are treated as batch dimensions.
Expand All @@ -150,25 +146,27 @@ def eigenfunctions(
:param int | list[int] m: The number of eigenvalues to compute in each dimension.
If an integer, the same number of eigenvalues is computed in each dimension.
:returns: An array of the first :math:`m^\\star \\times D` eigenfunctions evaluated at `x`.
:rtype: ArrayImpl
:rtype: Array
"""
if x.ndim == 1:
x_ = x[..., None]
if jnp.ndim(x) == 1:
x_ = jnp.expand_dims(x, axis=-1)
else:
x_ = x
dim = x_.shape[-1] # others assumed batch dims
n_batch_dims = x_.ndim - 1
x_ = jnp.array(x)
dim = jnp.shape(x_)[-1] # others assumed batch dims
n_batch_dims = jnp.ndim(x_) - 1
ell_ = _convert_ell(ell, dim)
a = jnp.expand_dims(ell_, tuple(range(n_batch_dims)))
b = jnp.expand_dims(sqrt_eigenvalues(ell_, m, dim), tuple(range(n_batch_dims)))
return jnp.prod(jnp.sqrt(1 / a) * jnp.sin(b * (x_[..., None] + a)), axis=-2)
return jnp.prod(
jnp.sqrt(1 / a) * jnp.sin(b * (jnp.expand_dims(x_, axis=-1) + a)), axis=-2
)


def eigenfunctions_periodic(x: ArrayImpl, w0: float, m: int):
def eigenfunctions_periodic(x: ArrayLike, w0: float, m: int) -> tuple[Array, Array]:
"""
Basis functions for the approximation of the periodic kernel.
:param ArrayImpl x: The points at which to evaluate the eigenfunctions.
:param ArrayLike x: The points at which to evaluate the eigenfunctions.
:param float w0: The frequency of the periodic kernel.
:param int m: The number of eigenfunctions to compute.
Expand All @@ -178,43 +176,42 @@ def eigenfunctions_periodic(x: ArrayImpl, w0: float, m: int):
.. warning::
Multidimensional inputs are not supported.
"""
if x.ndim > 1:
if jnp.ndim(x) > 1:
raise ValueError(
"Multidimensional inputs are not supported by the periodic kernel."
)
m1 = jnp.tile(w0 * x[:, None], m)
m1 = jnp.tile(w0 * jnp.expand_dims(x, axis=-1), m)
m2 = jnp.diag(jnp.arange(m, dtype=jnp.float32))
mw0x = m1 @ m2
cosines = jnp.cos(mw0x)
sines = jnp.sin(mw0x)
return cosines, sines


def _convert_ell(
ell: float | int | list[float | int] | ArrayImpl, dim: int
) -> ArrayImpl:
def _convert_ell(ell: float | int | list[float | int] | ArrayLike, dim: int) -> Array:
"""
Process the half-length of the approximation interval and return a `D \\times 1` array.
If `ell` is a scalar, it is converted to a list of length dim, then transformed into an Array.
:param float | int | list[float | int] | ArrayImpl ell: The length of the interval in each dimension divided by 2.
:param float | int | list[float | int] | ArrayLike ell: The length of the interval in each dimension divided by 2.
If a float or int, the same length is used in each dimension.
:param int dim: The dimension of the space.
:returns: A `D \\times 1` array of the half-lengths of the approximation interval.
:rtype: ArrayImpl
:rtype: Array
"""
ell_ = jnp.empty((dim, 1))
if isinstance(ell, float) | isinstance(ell, int):
ell = [ell] * dim
ell = jnp.array([ell] * dim)[..., None]
if isinstance(ell, list):
if len(ell) != dim:
raise ValueError(
"The length of ell must be equal to the dimension of the space."
)
ell_ = jnp.array(ell)[..., None] # dim x 1 array
elif isinstance(ell, get_args(ARRAY_TYPE)):
ell_ = ell
if ell_.shape != (dim, 1):
elif isinstance(ell, Array) | isinstance(ell, np.ndarray):
ell_ = jnp.array(ell)
if jnp.shape(ell_) != (dim, 1):
raise ValueError("ell must be a scalar or a list of length `dim`.")
return ell_
Loading

0 comments on commit 66921bd

Please sign in to comment.