Skip to content

Commit

Permalink
fix condition
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Nov 15, 2024
1 parent de05465 commit d26ec57
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
6 changes: 4 additions & 2 deletions numpyro/contrib/hsgp/laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from __future__ import annotations

import numpy as np

from jax import Array
import jax.numpy as jnp
from jax.typing import ArrayLike
Expand Down Expand Up @@ -208,8 +210,8 @@ def _convert_ell(ell: float | int | list[float | int] | ArrayLike, dim: int) ->
"The length of ell must be equal to the dimension of the space."
)
ell_ = jnp.array(ell)[..., None] # dim x 1 array
elif isinstance(ell, Array):
ell_ = ell
elif isinstance(ell, Array) | isinstance(ell, np.ndarray):
ell_ = jnp.array(ell)
if jnp.shape(ell_) != (dim, 1):
raise ValueError("ell must be a scalar or a list of length `dim`.")
return ell_
8 changes: 8 additions & 0 deletions test/contrib/hsgp/test_laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,25 @@ def test_eigenfunctions(x: ArrayLike, ell: float | int, m: int | list[int]):
(1, 2, False),
([1, 1], 2, False),
(np.array([1, 1])[..., None], 2, False),
(jnp.array([1, 1])[..., None], 2, False),
(np.array([1, 1]), 2, True),
(jnp.array([1, 1]), 2, True),
([1, 1], 1, True),
(np.array([1, 1]), 1, True),
(jnp.array([1, 1]), 1, True),
],
ids=[
"ell-float",
"ell-int",
"ell-int-multdim",
"ell-list",
"ell-array",
"ell-jax-array",
"ell-array-fail",
"ell-jax-array-fail",
"dim-fail",
"dim-fail-array",
"dim-fail-jax",
],
)
def test_convert_ell(ell, dim, xfail):
Expand Down

0 comments on commit d26ec57

Please sign in to comment.