Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast atan and atan2 functions. #8388

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
cc3ef99
Fast vectorizable atan and atan2 functions.
mcourteaux Aug 10, 2024
56cab26
Default to not using fast atan versions if on CUDA.
mcourteaux Aug 10, 2024
59e6d35
Finished fast atan/atan2 functions and tests.
mcourteaux Aug 10, 2024
7c64aa2
Correct attribution.
mcourteaux Aug 10, 2024
020e966
Clang-format
mcourteaux Aug 10, 2024
cde21ca
Weird WebAssembly limits...
mcourteaux Aug 11, 2024
0881fdd
Small improvements to the optimization script.
mcourteaux Aug 11, 2024
73816c6
Polynomial optimization for log, exp, sin, cos with correct ranges.
mcourteaux Aug 11, 2024
a75d68b
Improve fast atan performance tests for GPU.
mcourteaux Aug 12, 2024
1c0f794
Bugfix fast_atan approximation. Fix correctness test to exceed the ra…
mcourteaux Aug 12, 2024
5d32551
Cleanup
mcourteaux Aug 12, 2024
11c5209
Enum class instead of enum for ApproximationPrecision.
mcourteaux Aug 12, 2024
f3b9d8f
Weird Metal limits. There should be a better way...
mcourteaux Aug 12, 2024
1a92308
Skip test for WebGPU.
mcourteaux Aug 12, 2024
775061a
Fast atan/atan2 polynomials reoptimized. New optimization strategy: ULP.
mcourteaux Aug 13, 2024
6cb4fac
Feedback Steven.
mcourteaux Aug 13, 2024
e9823c1
More comments and test mantissa error.
mcourteaux Aug 14, 2024
3ced523
Do not error when testing arctan performance on Metal / WebGPU.
mcourteaux Aug 14, 2024
b35f7fa
Rework precision specification. Generalize towards using this for oth…
mcourteaux Nov 11, 2024
c004f72
Clang-format.
mcourteaux Nov 11, 2024
34a5ff9
Fix makefile and clang-tidy.
mcourteaux Nov 11, 2024
9408c93
Fix incorrect approximation selection when required precision is not …
mcourteaux Nov 12, 2024
8efe869
Feedback from Steven.
mcourteaux Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ SOURCE_FILES = \
AlignLoads.cpp \
AllocationBoundsInference.cpp \
ApplySplit.cpp \
ApproximationTables.cpp \
Argument.cpp \
AssociativeOpsTable.cpp \
Associativity.cpp \
Expand Down
133 changes: 133 additions & 0 deletions src/ApproximationTables.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
#include "ApproximationTables.h"

namespace Halide {
namespace Internal {

namespace {

using OO = ApproximationPrecision::OptimizationObjective;

// clang-format off
// Generate this table with:
// python3 src/polynomial_optimizer.py atan --order 1 2 3 4 5 6 7 8 --loss mse mae mulpe mulpe_mae --no-gui --format table
mcourteaux marked this conversation as resolved.
Show resolved Hide resolved
//
// Note that the maximal errors are computed with numpy with double precision.
// The real errors are a bit larger with single-precision floats (see correctness/fast_arctan.cpp).
// Also note that ULP distances which are not units are bogus, but this is because this error
// was again measured with double precision, so the actual reconstruction had more bits of
// precision than the actual float32 target value. So in practice the MaxULP Error
// will be close to round(MaxUlpE).
const std::vector<Approximation> table_atan = {
{OO::MSE, 9.249650e-04, 7.078984e-02, 2.411e+06, {+8.56188008e-01}},
{OO::MSE, 1.026356e-05, 9.214909e-03, 3.985e+05, {+9.76213454e-01, -2.00030200e-01}},
{OO::MSE, 1.577588e-07, 1.323851e-03, 6.724e+04, {+9.95982073e-01, -2.92278128e-01, +8.30180680e-02}},
{OO::MSE, 2.849011e-09, 1.992218e-04, 1.142e+04, {+9.99316541e-01, -3.22286501e-01, +1.49032461e-01, -4.08635592e-02}},
{OO::MSE, 5.667504e-11, 3.080100e-05, 1.945e+03, {+9.99883373e-01, -3.30599535e-01, +1.81451316e-01, -8.71733830e-02, +2.18671936e-02}},
{OO::MSE, 1.202662e-12, 4.846916e-06, 3.318e+02, {+9.99980065e-01, -3.32694393e-01, +1.94019697e-01, -1.17694732e-01, +5.40822080e-02, -1.22995279e-02}},
{OO::MSE, 2.672889e-14, 7.722732e-07, 5.664e+01, {+9.99996589e-01, -3.33190090e-01, +1.98232868e-01, -1.32941469e-01, +8.07623712e-02, -3.46124853e-02, +7.15115276e-03}},
{OO::MSE, 6.147315e-16, 1.245768e-07, 9.764e+00, {+9.99999416e-01, -3.33302229e-01, +1.99511173e-01, -1.39332647e-01, +9.70944891e-02, -5.68823386e-02, +2.25679012e-02, -4.25772648e-03}},

{OO::MAE, 1.097847e-03, 4.801638e-02, 2.793e+06, {+8.33414544e-01}},
{OO::MAE, 1.209593e-05, 4.968992e-03, 4.623e+05, {+9.72410454e-01, -1.91981283e-01}},
{OO::MAE, 1.839382e-07, 6.107084e-04, 7.766e+04, {+9.95360080e-01, -2.88702052e-01, +7.93508437e-02}},
{OO::MAE, 3.296902e-09, 8.164167e-05, 1.313e+04, {+9.99214108e-01, -3.21178073e-01, +1.46272006e-01, -3.89915187e-02}},
{OO::MAE, 6.523525e-11, 1.147459e-05, 2.229e+03, {+9.99866373e-01, -3.30305517e-01, +1.80162434e-01, -8.51611537e-02, +2.08475020e-02}},
{OO::MAE, 1.378842e-12, 1.667328e-06, 3.792e+02, {+9.99977226e-01, -3.32622991e-01, +1.93541452e-01, -1.16429278e-01, +5.26504600e-02, -1.17203722e-02}},
{OO::MAE, 3.055131e-14, 2.480947e-07, 6.457e+01, {+9.99996113e-01, -3.33173716e-01, +1.98078484e-01, -1.32334692e-01, +7.96260166e-02, -3.36062649e-02, +6.81247117e-03}},
{OO::MAE, 7.013215e-16, 3.757868e-08, 1.102e+01, {+9.99999336e-01, -3.33298615e-01, +1.99465749e-01, -1.39086791e-01, +9.64233077e-02, -5.59142254e-02, +2.18643190e-02, -4.05495427e-03}},

{OO::MULPE, 1.355602e-03, 1.067325e-01, 1.808e+06, {+8.92130617e-01}},
{OO::MULPE, 2.100588e-05, 1.075508e-02, 1.822e+05, {+9.89111122e-01, -2.14468039e-01}},
{OO::MULPE, 3.573985e-07, 1.316370e-03, 2.227e+04, {+9.98665077e-01, -3.02990987e-01, +9.10404434e-02}},
{OO::MULPE, 6.474958e-09, 1.548508e-04, 2.619e+03, {+9.99842198e-01, -3.26272641e-01, +1.56294460e-01, -4.46207045e-02}},
{OO::MULPE, 1.313474e-10, 2.533532e-05, 4.294e+02, {+9.99974110e-01, -3.31823782e-01, +1.85886095e-01, -9.30024008e-02, +2.43894760e-02}},
{OO::MULPE, 3.007880e-12, 3.530685e-06, 5.983e+01, {+9.99996388e-01, -3.33036463e-01, +1.95959706e-01, -1.22068745e-01, +5.83403647e-02, -1.37966171e-02}},
{OO::MULPE, 6.348880e-14, 4.882649e-07, 8.276e+00, {+9.99999499e-01, -3.33273408e-01, +1.98895454e-01, -1.35153794e-01, +8.43185278e-02, -3.73434598e-02, +7.95583230e-03}},
{OO::MULPE, 1.369569e-15, 7.585036e-08, 1.284e+00, {+9.99999922e-01, -3.33320840e-01, +1.99708563e-01, -1.40257063e-01, +9.93094012e-02, -5.97138046e-02, +2.44056181e-02, -4.73371006e-03}},

{OO::MULPE_MAE, 9.548909e-04, 6.131488e-02, 2.570e+06, {+8.46713042e-01}},
{OO::MULPE_MAE, 1.159917e-05, 6.746680e-03, 3.778e+05, {+9.77449762e-01, -1.98798279e-01}},
{OO::MULPE_MAE, 1.783646e-07, 8.575388e-04, 6.042e+04, {+9.96388826e-01, -2.92591679e-01, +8.24585555e-02}},
{OO::MULPE_MAE, 3.265269e-09, 1.190548e-04, 9.505e+03, {+9.99430906e-01, -3.22774535e-01, +1.49370817e-01, -4.07480795e-02}},
{OO::MULPE_MAE, 6.574962e-11, 1.684690e-05, 1.515e+03, {+9.99909079e-01, -3.30795737e-01, +1.81810037e-01, -8.72860225e-02, +2.17776539e-02}},
{OO::MULPE_MAE, 1.380489e-12, 2.497538e-06, 2.510e+02, {+9.99984893e-01, -3.32748885e-01, +1.94193211e-01, -1.17865932e-01, +5.40633775e-02, -1.22309990e-02}},
{OO::MULPE_MAE, 3.053218e-14, 3.784868e-07, 4.181e+01, {+9.99997480e-01, -3.33205127e-01, +1.98309644e-01, -1.33094430e-01, +8.08643094e-02, -3.45859503e-02, +7.11261604e-03}},
{OO::MULPE_MAE, 7.018877e-16, 5.862915e-08, 6.942e+00, {+9.99999581e-01, -3.33306326e-01, +1.99542180e-01, -1.39433369e-01, +9.72462857e-02, -5.69734398e-02, +2.25639390e-02, -4.24074590e-03}},
};
// clang-format on
} // namespace

const Approximation *find_best_approximation(const std::vector<Approximation> &table,
ApproximationPrecision precision) {
#define DEBUG_APPROXIMATION_SEARCH 0
const Approximation *best = nullptr;
constexpr int term_cost = 20;
constexpr int extra_term_cost = 200;
double best_score = 0;
#if DEBUG_APPROXIMATION_SEARCH
std::printf("Looking for min_terms=%d, max_absolute_error=%f\n",
precision.constraint_min_poly_terms, precision.constraint_max_absolute_error);
#endif
for (size_t i = 0; i < table.size(); ++i) {
const Approximation &e = table[i];

double penalty = 0.0;

int obj_score = e.objective == precision.optimized_for ? 100 * term_cost : 0;
if (precision.optimized_for == ApproximationPrecision::MULPE_MAE &&
e.objective == ApproximationPrecision::MULPE) {
obj_score = 50 * term_cost; // When MULPE_MAE is not available, prefer MULPE.
}

int num_terms = int(e.coefficients.size());
int term_count_score = (12 - num_terms) * term_cost;
if (num_terms < precision.constraint_min_poly_terms) {
penalty += (precision.constraint_min_poly_terms - num_terms) * extra_term_cost;
}

double precision_score = 0;
// If we don't care about the maximum number of terms, we maximize precision.
switch (precision.optimized_for) {
case ApproximationPrecision::MSE:
precision_score = -std::log(e.mse);
break;
case ApproximationPrecision::MAE:
precision_score = -std::log(e.mae);
break;
case ApproximationPrecision::MULPE:
precision_score = -std::log(e.mulpe);
break;
case ApproximationPrecision::MULPE_MAE:
precision_score = -0.5 * std::log(e.mulpe * e.mae);
break;
}

if (precision.constraint_max_absolute_error > 0.0 &&
precision.constraint_max_absolute_error < e.mae) {
float error_ratio = e.mae / precision.constraint_max_absolute_error;
penalty += 20 * error_ratio * extra_term_cost; // penalty for not getting the required precision.
}

double score = obj_score + term_count_score + precision_score - penalty;
#if DEBUG_APPROXIMATION_SEARCH
std::printf("Score for %zu (%zu terms): %f = %d + %d + %f - penalty %f\n",
i, e.coefficients.size(), score, obj_score, term_count_score,
precision_score, penalty);
#endif
if (score > best_score || best == nullptr) {
best = &e;
best_score = score;
}
}
#if DEBUG_APPROXIMATION_SEARCH
std::printf("Best score: %f\n", best_score);
#endif
return best;
}

const Approximation *best_atan_approximation(Halide::ApproximationPrecision precision) {
return find_best_approximation(table_atan, precision);
}

} // namespace Internal
} // namespace Halide
24 changes: 24 additions & 0 deletions src/ApproximationTables.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef HALIDE_APPROXIMATION_TABLES_H
#define HALIDE_APPROXIMATION_TABLES_H

#include <vector>

#include "IROperator.h"

namespace Halide {
namespace Internal {

struct Approximation {
ApproximationPrecision::OptimizationObjective objective;
double mse;
double mae;
double mulpe;
std::vector<double> coefficients;
};

const Approximation *best_atan_approximation(Halide::ApproximationPrecision precision);

} // namespace Internal
} // namespace Halide

#endif
4 changes: 2 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,7 @@ target_sources(
WrapCalls.h
)

# The sources that go into libHalide. For the sake of IDE support, headers that
# exist in src/ but are not public should be included here.
# The sources that go into libHalide.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you alter the comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because there are no headers in that list. That comments is clearly outdated. Unless I'm wildly misunderstanding something.

target_sources(
Halide
PRIVATE
Expand All @@ -232,6 +231,7 @@ target_sources(
AlignLoads.cpp
AllocationBoundsInference.cpp
ApplySplit.cpp
ApproximationTables.cpp
Argument.cpp
AssociativeOpsTable.cpp
Associativity.cpp
Expand Down
58 changes: 57 additions & 1 deletion src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <sstream>
#include <utility>

#include "ApproximationTables.h"
#include "CSE.h"
#include "ConstantBounds.h"
#include "Debug.h"
Expand Down Expand Up @@ -1388,7 +1389,7 @@ Expr fast_sin_cos(const Expr &x_full, bool is_sin) {
Expr sin_usecos = is_sin ? ((k_mod4 == 1) || (k_mod4 == 3)) : ((k_mod4 == 0) || (k_mod4 == 2));
Expr flip_sign = is_sin ? (k_mod4 > 1) : ((k_mod4 == 1) || (k_mod4 == 2));

// Reduce the angle modulo pi/2.
// Reduce the angle modulo pi/2: i.e., to the angle within the quadrant.
Expr x = x_full - k_real * pi_over_two;

const float sin_c2 = -0.16666667163372039794921875f;
Expand Down Expand Up @@ -1425,6 +1426,61 @@ Expr fast_cos(const Expr &x_full) {
return fast_sin_cos(x_full, false);
}

// A vectorizable atan and atan2 implementation.
// Based on the ideas presented in https://mazzo.li/posts/vectorized-atan2.html.
Expr fast_atan_approximation(const Expr &x_full, ApproximationPrecision precision, bool between_m1_and_p1) {
const float pi_over_two = 1.57079632679489661923f;
Expr x;
// if x > 1 -> atan(x) = Pi/2 - atan(1/x)
Expr x_gt_1 = abs(x_full) > 1.0f;
if (between_m1_and_p1) {
x = x_full;
} else {
x = select(x_gt_1, 1.0f / x_full, x_full);
}
const Internal::Approximation *approx = Internal::best_atan_approximation(precision);
const std::vector<double> &c = approx->coefficients;
Expr x2 = x * x;
Expr result = float(c.back());
for (size_t i = 1; i < c.size(); ++i) {
result = x2 * result + float(c[c.size() - i - 1]);
}
result *= x;

if (!between_m1_and_p1) {
result = select(x_gt_1, select(x_full < 0, -pi_over_two, pi_over_two) - result, result);
}
return common_subexpression_elimination(result);
}

Expr fast_atan(const Expr &x_full, ApproximationPrecision precision) {
mcourteaux marked this conversation as resolved.
Show resolved Hide resolved
return fast_atan_approximation(x_full, precision, false);
}

Expr fast_atan2(const Expr &y, const Expr &x, ApproximationPrecision precision) {
const float pi = 3.14159265358979323846f;
const float pi_over_two = 1.57079632679489661923f;
// Making sure we take the ratio of the biggest number by the smallest number (in absolute value)
// will always give us a number between -1 and +1, which is the range over which the approximation
// works well. We can therefore also skip the inversion logic in the fast_atan_approximation function
// by passing true for "between_m1_and_p1". This increases both speed (1 division instead of 2) and
// numerical precision.
Expr swap = abs(y) > abs(x);
Expr atan_input = select(swap, x, y) / select(swap, y, x);
Expr ati = fast_atan_approximation(atan_input, precision, true);
Expr at = select(swap, select(atan_input >= 0.0f, pi_over_two, -pi_over_two) - ati, ati);
// This select statement is literally taken over from the definition on Wikipedia.
// There might be optimizations to be done here, but I haven't tried that yet. -- Martijn
Expr result = select(
x > 0.0f, at,
x < 0.0f && y >= 0.0f, at + pi,
x < 0.0f && y < 0.0f, at - pi,
x == 0.0f && y > 0.0f, pi_over_two,
x == 0.0f && y < 0.0f, -pi_over_two,
0.0f);
return common_subexpression_elimination(result);
}

Expr fast_exp(const Expr &x_full) {
user_assert(x_full.type() == Float(32)) << "fast_exp only works for Float(32)";

Expand Down
49 changes: 49 additions & 0 deletions src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,55 @@ Expr fast_sin(const Expr &x);
Expr fast_cos(const Expr &x);
// @}

/** Struct that allows the user to specify several requirements for functions
* that are approximated by polynomial expansions. These polynomials can be
* optimized for four different metrics: Mean Squared Error, Maximum Absolute Error,
* Maximum Units in Last Place (ULP) Error, or a 50%/50% blend of MAE and MULPE.
*
* Orthogonally to the optimization objective, these polynomials can vary
* in degree. Higher degree polynomials will give more precise results.
* Note that instead of specifying the degree, the number of terms is used instead.
* E.g., even (i.e., symmetric) functions may be implemented using only even powers,
* for which a number of terms of 4 would actually mean that terms
* in [1, x^2, x^4, x^6] are used, which is degree 6.
*
* Additionally, if you don't care about number of terms in the polynomial
* and you do care about the maximal absolute error the approximation may have
* over the domain, you may specify values and the implementation
* will decide the appropriate polynomial degree that achieves this precision.
*/
struct ApproximationPrecision {
enum OptimizationObjective {
MSE, //< Mean Squared Error Optimized.
MAE, //< Optimized for Max Absolute Error.
MULPE, //< Optimized for Max ULP Error. ULP is "Units in Last Place", measured in IEEE 32-bit floats.
MULPE_MAE, //< Optimized for simultaneously Max ULP Error, and Max Absolute Error, each with a weight of 50%.
} optimized_for;
int constraint_min_poly_terms{0}; //< Number of terms in polynomial (zero for no constraint).
float constraint_max_absolute_error{0.0f}; //< Max absolute error (zero for no constraint).
};

/** Fast vectorizable approximations for arctan and arctan2 for Float(32).
*
* Desired precision can be specified as either a maximum absolute error (MAE) or
* the number of terms in the polynomial approximation (see the ApproximationPrecision enum) which
* are optimized for either:
* - MSE (Mean Squared Error)
* - MAE (Maximum Absolute Error)
* - MULPE (Maximum Units in Last Place Error).
*
* The default (Max ULP Error Polynomial of 6 terms) has a MAE of 3.53e-6.
* For more info on the available approximations and their precisions, see the table in ApproximationTables.cpp.
*
* Note: the polynomial uses odd powers, so the number of terms is not the degree of the polynomial.
* Note: the polynomial with 8 terms is only useful to increase precision for fast_atan, and not for fast_atan2.
* Note: the performance of this functions seem to be not reliably faster on WebGPU (for now, August 2024).
*/
// @{
Expr fast_atan(const Expr &x, ApproximationPrecision precision = {ApproximationPrecision::MULPE, 6});
Expr fast_atan2(const Expr &y, const Expr &x, ApproximationPrecision = {ApproximationPrecision::MULPE, 6});
// @}

/** Fast approximate cleanly vectorizable log for Float(32). Returns
* nonsense for x <= 0.0f. Accurate up to the last 5 bits of the
* mantissa. Vectorizes cleanly. */
Expand Down
Loading