Adapted jax_ops.py to handle the lowering interfaces that changed in JAX v0.4.16 and later. #33
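One common way to support both the old and the new lowering interfaces in a single module is to branch on the installed JAX version. The sketch below is a hypothetical illustration of that pattern, not the code from this PR. `NEW_LOWERING_MIN_VERSION`, `parse_version`, `uses_new_lowering` and `select_lowering` are made-up names, and the sketch assumes the switch happens exactly at 0.4.16, as the title states. In real code the version string would come from `jax.__version__`.

```python
import re
from typing import Callable, Tuple, TypeVar

T = TypeVar("T")

# Assumption: first JAX release that uses the new lowering interfaces (per the PR title).
NEW_LOWERING_MIN_VERSION: Tuple[int, ...] = (0, 4, 16)


def parse_version(version: str) -> Tuple[int, ...]:
    """Keep only the leading numeric components of a version string.

    "0.4.16" -> (0, 4, 16); "0.4.20.dev20231010" -> (0, 4, 20); "0.4.16rc1" -> (0, 4, 16).
    """
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # non-numeric component such as "dev20231010"
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # suffix attached to a number, e.g. "16rc1"
    return tuple(parts[:3])


def uses_new_lowering(version: str) -> bool:
    """True if this JAX version is at or above the assumed interface change."""
    return parse_version(version) >= NEW_LOWERING_MIN_VERSION


def select_lowering(version: str, old_impl: T, new_impl: T) -> T:
    """Pick the lowering implementation (hypothetical helpers) matching the JAX version."""
    return new_impl if uses_new_lowering(version) else old_impl
```

A module would then call something like `select_lowering(jax.__version__, _lower_old, _lower_new)` once at import time, where `_lower_old` and `_lower_new` are hypothetical functions written against each interface. Comparing integer tuples avoids the string-comparison pitfall where "0.4.9" would sort after "0.4.16".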