[cake_eating_numerical] Numba time less than JAX #185

Open
mmcky opened this issue Jun 6, 2024 · 4 comments

Comments

@mmcky
Contributor

mmcky commented Jun 6, 2024

The lecture `cake_eating_numerical` was removed in #175 because the JAX version took longer to run than the Numba version. We need to review the implementations in this lecture as a sanity check.

A copy of the lecture is here:

cake_eating_numerical.md

and the timings are at the bottom of this preview:

https://6661340ed985e83faa0cb785--incomparable-parfait-2417f8.netlify.app/cake_eating_numerical

We would expect JAX to outperform Numba unless there is a good reason, which we should then explain.

@kp992, do you have time to look into this lecture?

TODO:

  1. review the implementations and confirm why Numba's execution time is lower than JAX's
  2. submit a PR updating and re-enabling this lecture

@kp992
Contributor

kp992 commented Jun 6, 2024

Sure, will take a look.

@kp992
Contributor

kp992 commented Jun 8, 2024

Hi @mmcky, I checked the difference in timings, and the main reason is the difference in the algorithms used. JAX itself is heavily optimized, but the JAX implementation finds the maximum by brute force, whereas the Numba implementation uses the `brent_max` function. `brent_max` is currently unavailable in the JAX implementation, so JAX simply evaluates the objective at every point on the grid and picks the best one.
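To make the algorithmic gap concrete, here is a minimal, self-contained sketch in plain Python (not the lecture's code) that compares a brute-force grid search with golden-section search. Golden-section search is a simpler relative of Brent's method, which combines golden-section steps with parabolic interpolation; it is swapped in here only to keep the example short. The objective `log(c) + beta * log(x - c)` and the names `grid_max`, `golden_max` and `objective` are hypothetical stand-ins for a single cake-eating maximization step.

```python
import math

def grid_max(f, a, b, n):
    """Brute force: evaluate f at n evenly spaced points on [a, b].
    Returns (argmax, max, number_of_evaluations)."""
    best_x, best_val = a, -math.inf
    for i in range(n):
        x = a + (b - a) * i / (n - 1)
        val = f(x)
        if val > best_val:
            best_x, best_val = x, val
    return best_x, best_val, n

def golden_max(f, a, b, tol=1e-8):
    """Golden-section search for the maximum of a unimodal f on [a, b]
    (a simplified stand-in for Brent's method).
    Returns (argmax, max, number_of_evaluations)."""
    r = (math.sqrt(5) - 1) / 2          # inverse golden ratio, ~0.618
    c, d = b - r * (b - a), a + r * (b - a)
    fc, fd = f(c), f(d)
    evals = 2
    while b - a > tol:
        if fc > fd:                     # maximum lies in [a, d]
            b, d, fd = d, c, fc
            c = b - r * (b - a)
            fc = f(c)
        else:                           # maximum lies in [c, b]
            a, c, fc = c, d, fd
            d = a + r * (b - a)
            fd = f(d)
        evals += 1                      # one new evaluation per iteration
    x = (a + b) / 2
    return x, f(x), evals + 1

# Hypothetical one-step problem: choose consumption c from a cake of size x
beta, x = 0.96, 1.0

def objective(c):
    return math.log(c) + beta * math.log(x - c)

eps = 1e-10                             # keep log() away from zero
c_grid, _, n_grid = grid_max(objective, eps, x - eps, 1000)
c_gold, _, n_gold = golden_max(objective, eps, x - eps)
c_exact = x / (1 + beta)                # analytical maximizer
print(n_grid, abs(c_grid - c_exact))
print(n_gold, abs(c_gold - c_exact))
```

The grid search spends 1000 evaluations and its accuracy is limited by the grid spacing, while the bracketing search reaches a tolerance of `1e-8` in a few dozen evaluations. If the choice grid has the same size `n` as the state grid, brute force costs on the order of `n**2` objective evaluations per Bellman iteration, versus roughly `n` times a few dozen for a Brent-style search, which is enough to outweigh a faster backend.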

@kp992
Contributor

kp992 commented Jun 8, 2024

If `brent_max` were available in JAX, we could likely beat Numba's timings.
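One possible direction, sketched under assumptions: a bounded scalar maximizer that avoids data-dependent Python control flow can be applied to every grid point at once, which is the style JAX needs. The sketch below is a vectorized golden-section search (not Brent's method, and not an existing JAX or QuantEcon function); the names `golden_max_vec`, `objective` and `x_grid` are hypothetical. It is written in NumPy so it runs anywhere. Because it uses a fixed iteration count and `np.where` instead of branching, a port to `jax.numpy` under `jax.jit` should be mechanical, but we have not benchmarked such a port against the Numba `brent_max` version.

```python
import numpy as np

def golden_max_vec(f, a, b, n_iter=50):
    """Vectorized golden-section search.

    Maximizes the elementwise function f over [a, b] independently for every
    element of the arrays a and b. A fixed number of iterations and np.where
    replace data-dependent branching, as JAX's tracing would require.
    """
    r = (np.sqrt(5.0) - 1.0) / 2.0      # inverse golden ratio
    for _ in range(n_iter):
        c = b - r * (b - a)
        d = a + r * (b - a)
        keep_left = f(c) > f(d)         # True where the maximum lies in [a, d]
        b = np.where(keep_left, d, b)
        a = np.where(keep_left, a, c)
    return (a + b) / 2

# Hypothetical problem: one consumption choice per cake size on a grid
beta = 0.96
x_grid = np.linspace(0.1, 2.0, 200)

def objective(c):
    return np.log(c) + beta * np.log(x_grid - c)

eps = 1e-10
c_star = golden_max_vec(objective, np.full_like(x_grid, eps), x_grid - eps)
c_exact = x_grid / (1 + beta)           # analytical maximizers
print(np.max(np.abs(c_star - c_exact)))
```

Each iteration evaluates the objective twice for the whole grid, so the total work is about `2 * n_iter * len(x_grid)` evaluations, compared with `len(x_grid)` times the size of the choice grid for brute force.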

@mmcky
Contributor Author

mmcky commented Jun 10, 2024

Thanks @kp992, that is really helpful. Algorithms matter :-).

@jstac: Smit has identified the cause of the timing difference here, and there is a good explanation for why the Numba version runs faster than the JAX version. It is a good example of how algorithms matter (just as much as technology). What do you think about making this point in the lecture?

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants