Skip to content

Commit

Permalink
Make JAX dependencies compatible with new array equality behavior
Browse files Browse the repository at this point in the history
This makes libraries compatbile with the changes in jax-ml/jax#11234

PiperOrigin-RevId: 457663471
Change-Id: Ib4a169e013a3b04339cc6d6f4edf5b960a07087e
  • Loading branch information
Jake VanderPlas authored and jg8610 committed Aug 12, 2022
1 parent 8b0536c commit b8cfa9b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions jraph/_src/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,13 +826,13 @@ def test_fully_connected_graph_order_edges(self, add_self_edges):
add_self_edges=add_self_edges)

if add_self_edges:
self.assertSequenceEqual(
np.testing.assert_array_equal(
graph_batch.senders, [0, 1, 2] * 3)
self.assertSequenceEqual(
np.testing.assert_array_equal(
graph_batch.receivers, [0] * 3 + [1] * 3 + [2] * 3)
else:
self.assertSequenceEqual(graph_batch.senders, [1, 2, 2, 0, 0, 1])
self.assertSequenceEqual(graph_batch.receivers, [0, 0, 1, 1, 2, 2])
np.testing.assert_array_equal(graph_batch.senders, [1, 2, 2, 0, 0, 1])
np.testing.assert_array_equal(graph_batch.receivers, [0, 0, 1, 1, 2, 2])


class ConcatenatedArgsWrapperTest(parameterized.TestCase):
Expand Down

0 comments on commit b8cfa9b

Please sign in to comment.