From d26ec57db3aaf4916518b9c0316cd9f002c602d7 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 15 Nov 2024 11:22:30 +0100 Subject: [PATCH] fix condition --- numpyro/contrib/hsgp/laplacian.py | 6 ++++-- test/contrib/hsgp/test_laplacian.py | 8 ++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/numpyro/contrib/hsgp/laplacian.py b/numpyro/contrib/hsgp/laplacian.py index 2dc756316..011c828aa 100644 --- a/numpyro/contrib/hsgp/laplacian.py +++ b/numpyro/contrib/hsgp/laplacian.py @@ -7,6 +7,8 @@ from __future__ import annotations +import numpy as np + from jax import Array import jax.numpy as jnp from jax.typing import ArrayLike @@ -208,8 +210,8 @@ def _convert_ell(ell: float | int | list[float | int] | ArrayLike, dim: int) -> "The length of ell must be equal to the dimension of the space." ) ell_ = jnp.array(ell)[..., None] # dim x 1 array - elif isinstance(ell, Array): - ell_ = ell + elif isinstance(ell, Array) | isinstance(ell, np.ndarray): + ell_ = jnp.array(ell) if jnp.shape(ell_) != (dim, 1): raise ValueError("ell must be a scalar or a list of length `dim`.") return ell_ diff --git a/test/contrib/hsgp/test_laplacian.py b/test/contrib/hsgp/test_laplacian.py index b50ff6ea0..978947a43 100644 --- a/test/contrib/hsgp/test_laplacian.py +++ b/test/contrib/hsgp/test_laplacian.py @@ -131,8 +131,12 @@ def test_eigenfunctions(x: ArrayLike, ell: float | int, m: int | list[int]): (1, 2, False), ([1, 1], 2, False), (np.array([1, 1])[..., None], 2, False), + (jnp.array([1, 1])[..., None], 2, False), (np.array([1, 1]), 2, True), + (jnp.array([1, 1]), 2, True), ([1, 1], 1, True), + (np.array([1, 1]), 1, True), + (jnp.array([1, 1]), 1, True), ], ids=[ "ell-float", @@ -140,8 +144,12 @@ def test_eigenfunctions(x: ArrayLike, ell: float | int, m: int | list[int]): "ell-int-multdim", "ell-list", "ell-array", + "ell-jax-array", "ell-array-fail", + "ell-jax-array-fail", "dim-fail", + "dim-fail-array", + "dim-fail-jax", ], ) def test_convert_ell(ell, dim, xfail):