-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fast atan and atan2 functions. #8388
base: main
Are you sure you want to change the base?
Conversation
|
GPU performance test was severely memory bandwidth limited. This has been worked around by computing many (1024) arctans per output and summing them. Now --at least on my system-- they are faster. See updated performance reports. |
Okay, this is ready for review. Vulkan is slow, but that is apparently known well... |
Oh dear... I don't even know what WebGPU is... @steven-johnson Is this supposed to be an actual platform that is fast, and where performance metrics make sense? I can treat it like Vulkan, where it's just "meh, at least some are faster..."? |
https://en.wikipedia.org/wiki/WebGPU |
I don't think Vulkan is necessarily slow ... I think the benchmark loop is including initialization overhead. See my follow up here: #7202 |
Very cool! I have some concerns with the error metric though. Decimal digits of error isn't a great metric. E.g. having a value of 0.0001 when it's supposed to be zero is much much worse than having a value of 0.3701 when it's supposed to be 0.37. Relative error isn't great either, due to the singularity at zero. A better metric is ULPs, which is the maximum number of distinct floating point values in between the answer and the correct answer. There are also cases where you want a hard constraint as opposed to a minimization. exp(0) should be exactly one, and I guess I decided its derivative should be exactly one too, which explains the different in coefficients. |
@abadams I improved the optimization script a lot. I added support for ULP optimization: it optimizes very nicely for maximal bit error. When instead optimizing for MAE, we see the max ULP distance increase: I changed the default to the ULP-optimized one, but to keep the maximal absolute error under 1e-5, I had to choose the higher-degree polynomial. Overall still good. @derek-gerstmann Thanks a lot for investigating the performance issue! I now also get very fast Vulkan performance. I wonder why the overhead is so huge in Vulkan, and not there in other backends? Vulkan:
CUDA:
Vulkan is now even faster than CUDA! 🤯 |
@steven-johnson The build just broke on something LLVM related it seems... There seems to be no related commit to Halide. Does LLVM constantly update with every build? Edit: I found the commit: llvm/llvm-project@75c7bca Fix separately PR'd in #8391 |
We rebuild LLVM once a day, about 2AM Pacific time. |
@abadams I added the check that counts number of wrong mantissa bits:
Pay attention to the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cut polynomial + merge it + later take care of other transcendentals. |
…nge (-1, 1) to test (-4, 4). Cleanup code/comments. Test performance for all approximations.
9bcb9b7
to
b35f7fa
Compare
@abadams I updated the PR, and believe this is a nice compromise of options. It is in line with your initial thoughts on just specifying the precision yourself. I have made a table of approximations and their precisions. Then a new auxiliary function selects an approximation from that table that satisfies your requirements. This clears out the header (no more one million enum options), and clears out the source file, by not having the table sitting inside of the fast_atan function. |
Looks like this is ready for final review... ? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Style nits
src/ApproximationTables.cpp
Outdated
} | ||
|
||
double score = obj_score + term_count_score + precision_score - penalty; | ||
// std::printf("Score for %zu (%zu terms): %f = %d + %d + %f - penalty %f\n", i, e.coefficients.size(), score, obj_score, term_count_score, precision_score, penalty); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove commented-out code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wrapped in an #if
src/ApproximationTables.h
Outdated
@@ -0,0 +1,21 @@ | |||
#pragma once |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Halide doesn't use #pragma once
; instead, wrap in
#ifndef HALIDE_APPROXIMATION_TABLES_H_
#define HALIDE_APPROXIMATION_TABLES_H_
...
#endif
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed that, but without the trailing _
, as that seemed to be the style, looking at other files.
@@ -219,8 +219,7 @@ target_sources( | |||
WrapCalls.h | |||
) | |||
|
|||
# The sources that go into libHalide. For the sake of IDE support, headers that | |||
# exist in src/ but are not public should be included here. | |||
# The sources that go into libHalide. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you alter the comment?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because there are no headers in that list. That comments is clearly outdated. Unless I'm wildly misunderstanding something.
Is this ready to land (pending review comments)? |
Oh thanks @steven-johnson! I hadn't noticed your review! Will make work of this tomorrow! Ready to land if you ask me. @abadams wanted to see some changes, but he hasn't commented on what I did yet. |
Addresses #8243. Uses a polynomial approximation with odd powers: this way, it's immediately symmetrical around 0. Coefficients are optimized using my script which does iterative weight-adjusted least-squared-error (also included in PR; see below).
Added API
I designed this new
ApproximationPrecision
such that it can be used for other vectorizable functions at a later point as well, such as forfast_sin
andfast_cos
if we want that at some point. Note that I chose forMAE_1e_5
style of notation, instead of5Decimals
because 5 decimals suggests that there will be 5 decimals correct, which is technically less correct than saying that the maximal absolute error will be below1e-5
.Performance difference:
Linux/CPU (with precision
MAE_1e_5
):On Linux/CUDA, it's slightly faster than the default LLVM implementation (there is no atan instruction in PTX):
On Linux/OpenCL, it is also slightly faster:
Precision tests:
Optimizer
This PR includes a Python optimization script to find the coefficients of the polynomials:
While I didn't do anything very scientific or looked at research papers, I get a hunch that the results from this script are really good (and may actually converge to optimal).
If my optimization makes sense, then I have some funny observation: I get different coefficients for all of the fast approximations we have. See below.
Better coefficients for
exp()
?My result:
versus current Halide code:
Halide/src/IROperator.cpp
Lines 1432 to 1439 in 3cdeb53
Better coefficients for
sin()
?Notice that my optimization gives maximal error of 1.35e-11, instead of the promised 1e-5, with degree 6.
Versus:
Halide/src/IROperator.cpp
Lines 1390 to 1394 in 3cdeb53
If this is true (I don't see a reason why it wouldn't), that would mean we can remove a few terms to get faster version that still provides the promised precision.
Better coefficients for
cos()
?versus:
Halide/src/IROperator.cpp
Lines 1396 to 1400 in 3cdeb53
Better coefficients for
log()
?versus:
Halide/src/IROperator.cpp
Lines 1357 to 1365 in 3cdeb53