Skip to content

Commit

Permalink
seed fixing
Browse files Browse the repository at this point in the history
  • Loading branch information
Eelco Hoogendoorn committed Nov 3, 2023
1 parent e8f695c commit f7b71b4
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions numga/backend/jax/test/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,22 @@ def test_performance():
print(time()-t)


def check_inverse(x, i):
def check_inverse(x, i, atol=1e-9):
print(x.subspace, i.subspace)
assert np.allclose((x * i - 1).values, 0, atol=1e-9)
assert np.allclose((i * x - 1).values, 0, atol=1e-9)
assert np.allclose((x * i - 1).values, 0, atol=atol)
assert np.allclose((i * x - 1).values, 0, atol=atol)


def test_inverse():
"""test some inversion in 6 dimensions"""
np.random.seed(0)
import time
ga = JaxContext(Algebra.from_pqr(6, 0, 0), dtype=jnp.float64)
N = 100
V = ga.subspace.vector()
x = random_subspace(ga, V, (N,))
t = time.time()
check_inverse(x, x.inverse_la())
check_inverse(x, x.inverse_la(), atol=1e-5)
print('la', time.time() - t)
t = time.time()
check_inverse(x, x.inverse_shirokov())
Expand All @@ -114,7 +115,7 @@ def test_inverse():
V = ga.subspace.even_grade()
x = random_subspace(ga, V, (N,))
t = time.time()
check_inverse(x, x.inverse_la())
check_inverse(x, x.inverse_la(), atol=1e-5)
print('la', time.time() - t)
t = time.time()
check_inverse(x, x.inverse_shirokov())
Expand All @@ -125,7 +126,7 @@ def test_inverse():
foo = (lambda x: x.inverse_la())
foo(x)
t = time.time()
check_inverse(x, foo(x))
check_inverse(x, foo(x), atol=1e-5)
print('la', time.time() - t)
foo = jax.jit(lambda x: x.inverse_shirokov())
foo(x)
Expand Down

0 comments on commit f7b71b4

Please sign in to comment.