diff --git a/numpyro/contrib/hsgp/approximation.py b/numpyro/contrib/hsgp/approximation.py index 6cb904f50..9d27f7a3f 100644 --- a/numpyro/contrib/hsgp/approximation.py +++ b/numpyro/contrib/hsgp/approximation.py @@ -7,8 +7,6 @@ from __future__ import annotations -from typing import cast - from jax import Array import jax.numpy as jnp from jax.typing import ArrayLike @@ -27,14 +25,14 @@ def _non_centered_approximation(phi: Array, spd: Array, m: int) -> Array: with numpyro.plate("basis", m): beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0)) - return cast(Array, phi @ (spd * beta)) + return jnp.asarray(phi @ (spd * beta)) def _centered_approximation(phi: Array, spd: Array, m: int) -> Array: with numpyro.plate("basis", m): beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd)) - return cast(Array, phi @ beta) + return jnp.asarray(phi @ beta) def linear_approximation(