Skip to content

Commit

Permalink
use as array
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Dec 20, 2024
1 parent 02e149a commit ebc510c
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions numpyro/contrib/hsgp/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

from __future__ import annotations

from typing import cast

from jax import Array
import jax.numpy as jnp
from jax.typing import ArrayLike
Expand All @@ -27,14 +25,14 @@ def _non_centered_approximation(phi: Array, spd: Array, m: int) -> Array:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))

return cast(Array, phi @ (spd * beta))
return jnp.asarray(phi @ (spd * beta))


def _centered_approximation(phi: Array, spd: Array, m: int) -> Array:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))

return cast(Array, phi @ beta)
return jnp.asarray(phi @ beta)


def linear_approximation(
Expand Down

0 comments on commit ebc510c

Please sign in to comment.