Skip to content

Commit

Permalink
Fast atan/atan2 polynomials reoptimized. New optimization strategy: ULP.
Browse files Browse the repository at this point in the history
  • Loading branch information
mcourteaux committed Aug 13, 2024
1 parent 8581d7a commit 8ff6701
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 146 deletions.
125 changes: 72 additions & 53 deletions src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1437,59 +1437,78 @@ Expr fast_atan_approximation(const Expr &x_full, ApproximationPrecision precisio
// Coefficients obtained using src/polynomial_optimizer.py
// Note that the maximal errors are computed with numpy with double precision.
// The real errors are a bit larger with single-precision floats (see correctness/fast_arctan.cpp).

// The table is huge, so let's put clang-format off and handle the layout manually:
// clang-format off
std::vector<float> c;
if (precision == ApproximationPrecision::MAE_1e_2 || precision == ApproximationPrecision::Poly2) {
// Coefficients with max error: 4.9977e-03
c.push_back(9.724422672912e-01f);
c.push_back(-1.920418089970e-01f);
} else if (precision == ApproximationPrecision::MAE_1e_3 || precision == ApproximationPrecision::Poly3) {
// Coefficients with max error: 6.1317e-04
c.push_back(9.953639222909e-01f);
c.push_back(-2.887227485229e-01f);
c.push_back(7.937016196576e-02f);
} else if (precision == ApproximationPrecision::MAE_1e_4 || precision == ApproximationPrecision::Poly4) {
// Coefficients with max error: 8.1862e-05
c.push_back(9.992146660828e-01f);
c.push_back(-3.211839266848e-01f);
c.push_back(1.462857116754e-01f);
c.push_back(-3.900014954510e-02f);
} else if (precision == ApproximationPrecision::Poly5) {
// Coefficients with max error: 1.1527e-05
c.push_back(9.998664595623e-01f);
c.push_back(-3.303069921053e-01f);
c.push_back(1.801687249421e-01f);
c.push_back(-8.517067470591e-02f);
c.push_back(2.085217296632e-02f);
} else if (precision == ApproximationPrecision::MAE_1e_5 || precision == ApproximationPrecision::Poly6) {
// Coefficients with max error: 1.6869e-06
c.push_back(9.999772493111e-01f);
c.push_back(-3.326235741278e-01f);
c.push_back(1.935452881570e-01f);
c.push_back(-1.164392687560e-01f);
c.push_back(5.266159827071e-02f);
c.push_back(-1.172481633666e-02f);
} else if (precision == ApproximationPrecision::MAE_1e_6 || precision == ApproximationPrecision::Poly7) {
// Coefficients with max error: 2.4856e-07
c.push_back(9.999961151054e-01f);
c.push_back(-3.331738028802e-01f);
c.push_back(1.980792937100e-01f);
c.push_back(-1.323378013498e-01f);
c.push_back(7.963167170570e-02f);
c.push_back(-3.361110979599e-02f);
c.push_back(6.814044980872e-03f);
} else if (precision == ApproximationPrecision::Poly8) {
// Coefficients with max error: 3.8005e-08
c.push_back(9.999993363468e-01f);
c.push_back(-3.332986419645e-01f);
c.push_back(1.994660800256e-01f);
c.push_back(-1.390885586782e-01f);
c.push_back(9.642807440478e-02f);
c.push_back(-5.592101944058e-02f);
c.push_back(2.186920026077e-02f);
c.push_back(-4.056345562152e-03f);
} else {
user_error << "Invalid precision specified to fast_atan";
}
switch (precision) {
// == MSE Optimized == //
case ApproximationPrecision::MSE_Poly2: // (MSE=1.0264e-05, MAE=9.2149e-03, MaxUlpE=3.9855e+05)
c = {+9.762134539879e-01f, -2.000301999499e-01f}; break;
case ApproximationPrecision::MSE_Poly3: // (MSE=1.5776e-07, MAE=1.3239e-03, MaxUlpE=6.7246e+04)
c = {+9.959820734941e-01f, -2.922781275652e-01f, +8.301806798764e-02f}; break;
case ApproximationPrecision::MSE_Poly4: // (MSE=2.8490e-09, MAE=1.9922e-04, MaxUlpE=1.1422e+04)
c = {+9.993165406918e-01f, -3.222865011143e-01f, +1.490324612527e-01f, -4.086355921512e-02f}; break;
case ApproximationPrecision::MSE_Poly5: // (MSE=5.6675e-11, MAE=3.0801e-05, MaxUlpE=1.9456e+03)
c = {+9.998833730470e-01f, -3.305995351168e-01f, +1.814513158372e-01f, -8.717338298570e-02f,
+2.186719361787e-02f}; break;
case ApproximationPrecision::MSE_Poly6: // (MSE=1.2027e-12, MAE=4.8469e-06, MaxUlpE=3.3187e+02)
c = {+9.999800646964e-01f, -3.326943930673e-01f, +1.940196968486e-01f, -1.176947321238e-01f,
+5.408220801540e-02f, -1.229952788751e-02f}; break;
case ApproximationPrecision::MSE_Poly7: // (MSE=2.6729e-14, MAE=7.7227e-07, MaxUlpE=5.6646e+01)
c = {+9.999965889517e-01f, -3.331900904961e-01f, +1.982328680483e-01f, -1.329414694644e-01f,
+8.076237117606e-02f, -3.461248530394e-02f, +7.151152759080e-03f}; break;
case ApproximationPrecision::MSE_Poly8: // (MSE=6.1506e-16, MAE=1.2419e-07, MaxUlpE=9.6914e+00)
c = {+9.999994159669e-01f, -3.333022219271e-01f, +1.995110884308e-01f, -1.393321817395e-01f,
+9.709319573480e-02f, -5.688043380309e-02f, +2.256648487698e-02f, -4.257308331872e-03f}; break;

// == MAE Optimized == //
case ApproximationPrecision::MAE_1e_2:
case ApproximationPrecision::MAE_Poly2: // (MSE=1.2096e-05, MAE=4.9690e-03, MaxUlpE=4.6233e+05)
c = {+9.724104536788e-01f, -1.919812827495e-01f}; break;
case ApproximationPrecision::MAE_1e_3:
case ApproximationPrecision::MAE_Poly3: // (MSE=1.8394e-07, MAE=6.1071e-04, MaxUlpE=7.7667e+04)
c = {+9.953600796593e-01f, -2.887020515559e-01f, +7.935084373856e-02f}; break;
case ApproximationPrecision::MAE_1e_4:
case ApproximationPrecision::MAE_Poly4: // (MSE=3.2969e-09, MAE=8.1642e-05, MaxUlpE=1.3136e+04)
c = {+9.992141075707e-01f, -3.211780734117e-01f, +1.462720063085e-01f, -3.899151874271e-02f}; break;
case ApproximationPrecision::MAE_Poly5: // (MSE=6.5235e-11, MAE=1.1475e-05, MaxUlpE=2.2296e+03)
c = {+9.998663727249e-01f, -3.303055171903e-01f, +1.801624340886e-01f, -8.516115366058e-02f,
+2.084750202717e-02f}; break;
case ApproximationPrecision::MAE_1e_5:
case ApproximationPrecision::MAE_Poly6: // (MSE=1.3788e-12, MAE=1.6673e-06, MaxUlpE=3.7921e+02)
c = {+9.999772256973e-01f, -3.326229914097e-01f, +1.935414518077e-01f, -1.164292778405e-01f,
+5.265046001895e-02f, -1.172037220425e-02f}; break;
case ApproximationPrecision::MAE_1e_6:
case ApproximationPrecision::MAE_Poly7: // (MSE=3.0551e-14, MAE=2.4809e-07, MaxUlpE=6.4572e+01)
c = {+9.999961125922e-01f, -3.331737159104e-01f, +1.980784841430e-01f, -1.323346922675e-01f,
+7.962601662878e-02f, -3.360626486524e-02f, +6.812471171209e-03f}; break;
case ApproximationPrecision::MAE_Poly8: // (MSE=7.0132e-16, MAE=3.7579e-08, MaxUlpE=1.1023e+01)
c = {+9.999993357462e-01f, -3.332986153129e-01f, +1.994657492754e-01f, -1.390867909988e-01f,
+9.642330770840e-02f, -5.591422536378e-02f, +2.186431903729e-02f, -4.054954273090e-03f}; break;


// == Max ULP Optimized == //
case ApproximationPrecision::MULPE_Poly2: // (MSE=2.1006e-05, MAE=1.0755e-02, MaxUlpE=1.8221e+05)
c = {+9.891111216318e-01f, -2.144680385336e-01f}; break;
case ApproximationPrecision::MULPE_Poly3: // (MSE=3.5740e-07, MAE=1.3164e-03, MaxUlpE=2.2273e+04)
c = {+9.986650768126e-01f, -3.029909865833e-01f, +9.104044335898e-02f}; break;
case ApproximationPrecision::MULPE_Poly4: // (MSE=6.4750e-09, MAE=1.5485e-04, MaxUlpE=2.6199e+03)
c = {+9.998421981586e-01f, -3.262726405770e-01f, +1.562944595469e-01f, -4.462070448745e-02f}; break;
case ApproximationPrecision::MULPE_Poly5: // (MSE=1.3135e-10, MAE=2.5335e-05, MaxUlpE=4.2948e+02)
c = {+9.999741103798e-01f, -3.318237821017e-01f, +1.858860952571e-01f, -9.300240079057e-02f,
+2.438947597681e-02f}; break;
case ApproximationPrecision::MULPE_Poly6: // (MSE=3.0079e-12, MAE=3.5307e-06, MaxUlpE=5.9838e+01)
c = {+9.999963876702e-01f, -3.330364633925e-01f, +1.959597060284e-01f, -1.220687452250e-01f,
+5.834036471395e-02f, -1.379661708254e-02f}; break;
case ApproximationPrecision::MULPE_Poly7: // (MSE=6.3489e-14, MAE=4.8826e-07, MaxUlpE=8.2764e+00)
c = {+9.999994992400e-01f, -3.332734078379e-01f, +1.988954540598e-01f, -1.351537940907e-01f,
+8.431852775558e-02f, -3.734345976535e-02f, +7.955832300869e-03f}; break;
case ApproximationPrecision::MULPE_Poly8: // (MSE=1.3696e-15, MAE=7.5850e-08, MaxUlpE=1.2850e+00)
c = {+9.999999220612e-01f, -3.333208398432e-01f, +1.997085632112e-01f, -1.402570625577e-01f,
+9.930940122930e-02f, -5.971380457112e-02f, +2.440561807586e-02f, -4.733710058459e-03f}; break;
}
// clang-format on

Expr x2 = x * x;
Expr result = c.back();
Expand All @@ -1508,7 +1527,7 @@ Expr fast_atan(const Expr &x_full, ApproximationPrecision precision) {
}

Expr fast_atan2(const Expr &y, const Expr &x, ApproximationPrecision precision) {
const float pi(3.14159265358979323846f);
const float pi = 3.14159265358979323846f;
const float pi_over_two = 1.57079632679489661923f;
// Making sure we take the ratio of the biggest number by the smallest number (in absolute value)
// will always give us a number between -1 and +1, which is the range over which the approximation
Expand Down
63 changes: 48 additions & 15 deletions src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -973,32 +973,65 @@ Expr fast_cos(const Expr &x);
// @}

enum class ApproximationPrecision {
// Maximum Absolute error
/** Mean Squared Error Optimized. */
// @{
MSE_Poly2,
MSE_Poly3,
MSE_Poly4,
MSE_Poly5,
MSE_Poly6,
MSE_Poly7,
MSE_Poly8,
// @}

/* Maximum Absolute Error Optimized. */
// @{
MAE_1e_2,
MAE_1e_3,
MAE_1e_4,
MAE_1e_5,
MAE_1e_6,

// Number of terms in polynomial
Poly2,
Poly3,
Poly4,
Poly5,
Poly6,
Poly7,
Poly8
// @}

/** Number of terms in polynomial -- Optimized for Max Absolute Error. */
// @{
MAE_Poly2,
MAE_Poly3,
MAE_Poly4,
MAE_Poly5,
MAE_Poly6,
MAE_Poly7,
MAE_Poly8,
// @}

/** Number of terms in polynomial -- Optimized for Max ULP Error.
* ULP is "Units in Last Place", measured in IEEE 32-bit floats. */
// @{
MULPE_Poly2,
MULPE_Poly3,
MULPE_Poly4,
MULPE_Poly5,
MULPE_Poly6,
MULPE_Poly7,
MULPE_Poly8,
// @}
};
/** Fast vectorizable approximations for arctan for Float(32).
/** Fast vectorizable approximations for arctan and arctan2 for Float(32).
* Desired precision can be specified as either a maximum absolute error (MAE) or
* the number of terms in the polynomial approximation (see the ApproximationPrecision enum).
* the number of terms in the polynomial approximation (see the ApproximationPrecision enum) which
* are optimized for either:
* - MSE (Mean Squared Error)
* - MAE (Maximum Absolute Error)
* - MULPE (Maximum Units in Last Place Error).
* The default (Max ULP Error Polynomial 6) has a MAE of 3.53e-6. For more info on the precision,
* see the table in IROperator.cpp.
*
* Note: the polynomial uses odd powers, so the number of terms is not the degree of the polynomial.
* Note: Poly8 is only useful to increase precision for atan, and not for atan2.
* Note: The performance of this functions seem to be not reliably faster on WebGPU (for now, August 2024).
*/
// @{
Expr fast_atan(const Expr &x, ApproximationPrecision precision = ApproximationPrecision::MAE_1e_5);
Expr fast_atan2(const Expr &y, const Expr &x, ApproximationPrecision = ApproximationPrecision::MAE_1e_5);
Expr fast_atan(const Expr &x, ApproximationPrecision precision = ApproximationPrecision::MULPE_Poly6);
Expr fast_atan2(const Expr &y, const Expr &x, ApproximationPrecision = ApproximationPrecision::MULPE_Poly6);
// @}

/** Fast approximate cleanly vectorizable log for Float(32). Returns
Expand Down
Loading

0 comments on commit 8ff6701

Please sign in to comment.