Adapted jax_ops.py to handle the lowering interfaces that changed from JAX v0.4.16 onwards. #36
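The PR title does not say which lowering calls changed. The following is therefore only a minimal sketch of one common way to support code on both sides of such a break: gate the code path on the installed JAX version. The names `parse_version`, `uses_new_lowering_interface`, and `NEW_LOWERING_MIN_VERSION` are hypothetical and do not come from jax_ops.py. The old and new lowering branches are deliberately left out.

```python
import re

# Hypothetical threshold: the release named in the PR title as the point
# where JAX's lowering interfaces changed.
NEW_LOWERING_MIN_VERSION = (0, 4, 16)


def parse_version(version: str) -> tuple[int, ...]:
    """Return the leading numeric release components of a version string.

    Examples: '0.4.16' -> (0, 4, 16), '0.4.16.dev20230901' -> (0, 4, 16),
    '0.4.20rc1' -> (0, 4, 20).
    """
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # non-numeric segment such as 'dev20230901'
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # stop at a suffix such as 'rc1' in '20rc1'
    return tuple(parts)


def uses_new_lowering_interface(jax_version: str) -> bool:
    """True if this (hypothetical) check should select the post-0.4.16 lowering path."""
    return parse_version(jax_version) >= NEW_LOWERING_MIN_VERSION
```

In a module like jax_ops.py, this check would presumably be called once at import time as `uses_new_lowering_interface(jax.__version__)` to pick the matching lowering path. Comparing integer tuples avoids a string-comparison pitfall: compared as strings, "0.4.9" sorts after "0.4.16".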