diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 9a601dcd993c..7d5cba81ec4d 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -1437,59 +1437,78 @@ Expr fast_atan_approximation(const Expr &x_full, ApproximationPrecision precisio // Coefficients obtained using src/polynomial_optimizer.py // Note that the maximal errors are computed with numpy with double precision. // The real errors are a bit larger with single-precision floats (see correctness/fast_arctan.cpp). + + // The table is huge, so let's put clang-format off and handle the layout manually: + // clang-format off std::vector c; - if (precision == ApproximationPrecision::MAE_1e_2 || precision == ApproximationPrecision::Poly2) { - // Coefficients with max error: 4.9977e-03 - c.push_back(9.724422672912e-01f); - c.push_back(-1.920418089970e-01f); - } else if (precision == ApproximationPrecision::MAE_1e_3 || precision == ApproximationPrecision::Poly3) { - // Coefficients with max error: 6.1317e-04 - c.push_back(9.953639222909e-01f); - c.push_back(-2.887227485229e-01f); - c.push_back(7.937016196576e-02f); - } else if (precision == ApproximationPrecision::MAE_1e_4 || precision == ApproximationPrecision::Poly4) { - // Coefficients with max error: 8.1862e-05 - c.push_back(9.992146660828e-01f); - c.push_back(-3.211839266848e-01f); - c.push_back(1.462857116754e-01f); - c.push_back(-3.900014954510e-02f); - } else if (precision == ApproximationPrecision::Poly5) { - // Coefficients with max error: 1.1527e-05 - c.push_back(9.998664595623e-01f); - c.push_back(-3.303069921053e-01f); - c.push_back(1.801687249421e-01f); - c.push_back(-8.517067470591e-02f); - c.push_back(2.085217296632e-02f); - } else if (precision == ApproximationPrecision::MAE_1e_5 || precision == ApproximationPrecision::Poly6) { - // Coefficients with max error: 1.6869e-06 - c.push_back(9.999772493111e-01f); - c.push_back(-3.326235741278e-01f); - c.push_back(1.935452881570e-01f); - c.push_back(-1.164392687560e-01f); - 
c.push_back(5.266159827071e-02f); - c.push_back(-1.172481633666e-02f); - } else if (precision == ApproximationPrecision::MAE_1e_6 || precision == ApproximationPrecision::Poly7) { - // Coefficients with max error: 2.4856e-07 - c.push_back(9.999961151054e-01f); - c.push_back(-3.331738028802e-01f); - c.push_back(1.980792937100e-01f); - c.push_back(-1.323378013498e-01f); - c.push_back(7.963167170570e-02f); - c.push_back(-3.361110979599e-02f); - c.push_back(6.814044980872e-03f); - } else if (precision == ApproximationPrecision::Poly8) { - // Coefficients with max error: 3.8005e-08 - c.push_back(9.999993363468e-01f); - c.push_back(-3.332986419645e-01f); - c.push_back(1.994660800256e-01f); - c.push_back(-1.390885586782e-01f); - c.push_back(9.642807440478e-02f); - c.push_back(-5.592101944058e-02f); - c.push_back(2.186920026077e-02f); - c.push_back(-4.056345562152e-03f); - } else { - user_error << "Invalid precision specified to fast_atan"; - } + switch (precision) { + // == MSE Optimized == // + case ApproximationPrecision::MSE_Poly2: // (MSE=1.0264e-05, MAE=9.2149e-03, MaxUlpE=3.9855e+05) + c = {+9.762134539879e-01f, -2.000301999499e-01f}; break; + case ApproximationPrecision::MSE_Poly3: // (MSE=1.5776e-07, MAE=1.3239e-03, MaxUlpE=6.7246e+04) + c = {+9.959820734941e-01f, -2.922781275652e-01f, +8.301806798764e-02f}; break; + case ApproximationPrecision::MSE_Poly4: // (MSE=2.8490e-09, MAE=1.9922e-04, MaxUlpE=1.1422e+04) + c = {+9.993165406918e-01f, -3.222865011143e-01f, +1.490324612527e-01f, -4.086355921512e-02f}; break; + case ApproximationPrecision::MSE_Poly5: // (MSE=5.6675e-11, MAE=3.0801e-05, MaxUlpE=1.9456e+03) + c = {+9.998833730470e-01f, -3.305995351168e-01f, +1.814513158372e-01f, -8.717338298570e-02f, + +2.186719361787e-02f}; break; + case ApproximationPrecision::MSE_Poly6: // (MSE=1.2027e-12, MAE=4.8469e-06, MaxUlpE=3.3187e+02) + c = {+9.999800646964e-01f, -3.326943930673e-01f, +1.940196968486e-01f, -1.176947321238e-01f, + +5.408220801540e-02f, 
-1.229952788751e-02f}; break; + case ApproximationPrecision::MSE_Poly7: // (MSE=2.6729e-14, MAE=7.7227e-07, MaxUlpE=5.6646e+01) + c = {+9.999965889517e-01f, -3.331900904961e-01f, +1.982328680483e-01f, -1.329414694644e-01f, + +8.076237117606e-02f, -3.461248530394e-02f, +7.151152759080e-03f}; break; + case ApproximationPrecision::MSE_Poly8: // (MSE=6.1506e-16, MAE=1.2419e-07, MaxUlpE=9.6914e+00) + c = {+9.999994159669e-01f, -3.333022219271e-01f, +1.995110884308e-01f, -1.393321817395e-01f, + +9.709319573480e-02f, -5.688043380309e-02f, +2.256648487698e-02f, -4.257308331872e-03f}; break; + + // == MAE Optimized == // + case ApproximationPrecision::MAE_1e_2: + case ApproximationPrecision::MAE_Poly2: // (MSE=1.2096e-05, MAE=4.9690e-03, MaxUlpE=4.6233e+05) + c = {+9.724104536788e-01f, -1.919812827495e-01f}; break; + case ApproximationPrecision::MAE_1e_3: + case ApproximationPrecision::MAE_Poly3: // (MSE=1.8394e-07, MAE=6.1071e-04, MaxUlpE=7.7667e+04) + c = {+9.953600796593e-01f, -2.887020515559e-01f, +7.935084373856e-02f}; break; + case ApproximationPrecision::MAE_1e_4: + case ApproximationPrecision::MAE_Poly4: // (MSE=3.2969e-09, MAE=8.1642e-05, MaxUlpE=1.3136e+04) + c = {+9.992141075707e-01f, -3.211780734117e-01f, +1.462720063085e-01f, -3.899151874271e-02f}; break; + case ApproximationPrecision::MAE_Poly5: // (MSE=6.5235e-11, MAE=1.1475e-05, MaxUlpE=2.2296e+03) + c = {+9.998663727249e-01f, -3.303055171903e-01f, +1.801624340886e-01f, -8.516115366058e-02f, + +2.084750202717e-02f}; break; + case ApproximationPrecision::MAE_1e_5: + case ApproximationPrecision::MAE_Poly6: // (MSE=1.3788e-12, MAE=1.6673e-06, MaxUlpE=3.7921e+02) + c = {+9.999772256973e-01f, -3.326229914097e-01f, +1.935414518077e-01f, -1.164292778405e-01f, + +5.265046001895e-02f, -1.172037220425e-02f}; break; + case ApproximationPrecision::MAE_1e_6: + case ApproximationPrecision::MAE_Poly7: // (MSE=3.0551e-14, MAE=2.4809e-07, MaxUlpE=6.4572e+01) + c = {+9.999961125922e-01f, -3.331737159104e-01f, 
+1.980784841430e-01f, -1.323346922675e-01f, + +7.962601662878e-02f, -3.360626486524e-02f, +6.812471171209e-03f}; break; + case ApproximationPrecision::MAE_Poly8: // (MSE=7.0132e-16, MAE=3.7579e-08, MaxUlpE=1.1023e+01) + c = {+9.999993357462e-01f, -3.332986153129e-01f, +1.994657492754e-01f, -1.390867909988e-01f, + +9.642330770840e-02f, -5.591422536378e-02f, +2.186431903729e-02f, -4.054954273090e-03f}; break; + + + // == Max ULP Optimized == // + case ApproximationPrecision::MULPE_Poly2: // (MSE=2.1006e-05, MAE=1.0755e-02, MaxUlpE=1.8221e+05) + c = {+9.891111216318e-01f, -2.144680385336e-01f}; break; + case ApproximationPrecision::MULPE_Poly3: // (MSE=3.5740e-07, MAE=1.3164e-03, MaxUlpE=2.2273e+04) + c = {+9.986650768126e-01f, -3.029909865833e-01f, +9.104044335898e-02f}; break; + case ApproximationPrecision::MULPE_Poly4: // (MSE=6.4750e-09, MAE=1.5485e-04, MaxUlpE=2.6199e+03) + c = {+9.998421981586e-01f, -3.262726405770e-01f, +1.562944595469e-01f, -4.462070448745e-02f}; break; + case ApproximationPrecision::MULPE_Poly5: // (MSE=1.3135e-10, MAE=2.5335e-05, MaxUlpE=4.2948e+02) + c = {+9.999741103798e-01f, -3.318237821017e-01f, +1.858860952571e-01f, -9.300240079057e-02f, + +2.438947597681e-02f}; break; + case ApproximationPrecision::MULPE_Poly6: // (MSE=3.0079e-12, MAE=3.5307e-06, MaxUlpE=5.9838e+01) + c = {+9.999963876702e-01f, -3.330364633925e-01f, +1.959597060284e-01f, -1.220687452250e-01f, + +5.834036471395e-02f, -1.379661708254e-02f}; break; + case ApproximationPrecision::MULPE_Poly7: // (MSE=6.3489e-14, MAE=4.8826e-07, MaxUlpE=8.2764e+00) + c = {+9.999994992400e-01f, -3.332734078379e-01f, +1.988954540598e-01f, -1.351537940907e-01f, + +8.431852775558e-02f, -3.734345976535e-02f, +7.955832300869e-03f}; break; + case ApproximationPrecision::MULPE_Poly8: // (MSE=1.3696e-15, MAE=7.5850e-08, MaxUlpE=1.2850e+00) + c = {+9.999999220612e-01f, -3.333208398432e-01f, +1.997085632112e-01f, -1.402570625577e-01f, + +9.930940122930e-02f, -5.971380457112e-02f, +2.440561807586e-02f, 
-4.733710058459e-03f}; break; + } + // clang-format on Expr x2 = x * x; Expr result = c.back(); @@ -1508,7 +1527,7 @@ Expr fast_atan(const Expr &x_full, ApproximationPrecision precision) { } Expr fast_atan2(const Expr &y, const Expr &x, ApproximationPrecision precision) { - const float pi(3.14159265358979323846f); + const float pi = 3.14159265358979323846f; const float pi_over_two = 1.57079632679489661923f; // Making sure we take the ratio of the biggest number by the smallest number (in absolute value) // will always give us a number between -1 and +1, which is the range over which the approximation diff --git a/src/IROperator.h b/src/IROperator.h index 2a6b859f8901..984d301c7272 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -973,32 +973,65 @@ Expr fast_cos(const Expr &x); // @} enum class ApproximationPrecision { - // Maximum Absolute error + /** Mean Squared Error Optimized. */ + // @{ + MSE_Poly2, + MSE_Poly3, + MSE_Poly4, + MSE_Poly5, + MSE_Poly6, + MSE_Poly7, + MSE_Poly8, + // @} + + /* Maximum Absolute Error Optimized. */ + // @{ MAE_1e_2, MAE_1e_3, MAE_1e_4, MAE_1e_5, MAE_1e_6, - - // Number of terms in polynomial - Poly2, - Poly3, - Poly4, - Poly5, - Poly6, - Poly7, - Poly8 + // @} + + /** Number of terms in polynomial -- Optimized for Max Absolute Error. */ + // @{ + MAE_Poly2, + MAE_Poly3, + MAE_Poly4, + MAE_Poly5, + MAE_Poly6, + MAE_Poly7, + MAE_Poly8, + // @} + + /** Number of terms in polynomial -- Optimized for Max ULP Error. + * ULP is "Units in Last Place", measured in IEEE 32-bit floats. */ + // @{ + MULPE_Poly2, + MULPE_Poly3, + MULPE_Poly4, + MULPE_Poly5, + MULPE_Poly6, + MULPE_Poly7, + MULPE_Poly8, + // @} }; -/** Fast vectorizable approximations for arctan for Float(32). +/** Fast vectorizable approximations for arctan and arctan2 for Float(32). * Desired precision can be specified as either a maximum absolute error (MAE) or - * the number of terms in the polynomial approximation (see the ApproximationPrecision enum). 
+ * the number of terms in the polynomial approximation (see the ApproximationPrecision enum) which + * are optimized for either: + * - MSE (Mean Squared Error) + * - MAE (Maximum Absolute Error) + * - MULPE (Maximum Units in Last Place Error). + * The default (Max ULP Error Polynomial 6) has a MAE of 3.53e-6. For more info on the precision, + * see the table in IROperator.cpp. + * * Note: the polynomial uses odd powers, so the number of terms is not the degree of the polynomial. * Note: Poly8 is only useful to increase precision for atan, and not for atan2. - * Note: The performance of this functions seem to be not reliably faster on WebGPU (for now, August 2024). */ // @{ -Expr fast_atan(const Expr &x, ApproximationPrecision precision = ApproximationPrecision::MAE_1e_5); -Expr fast_atan2(const Expr &y, const Expr &x, ApproximationPrecision = ApproximationPrecision::MAE_1e_5); +Expr fast_atan(const Expr &x, ApproximationPrecision precision = ApproximationPrecision::MULPE_Poly6); +Expr fast_atan2(const Expr &y, const Expr &x, ApproximationPrecision = ApproximationPrecision::MULPE_Poly6); // @} /** Fast approximate cleanly vectorizable log for Float(32). 
Returns diff --git a/src/polynomial_optimizer.py b/src/polynomial_optimizer.py index 51b9af78fd57..5b89d0825ff2 100644 --- a/src/polynomial_optimizer.py +++ b/src/polynomial_optimizer.py @@ -6,6 +6,11 @@ parser = argparse.ArgumentParser() parser.add_argument("func") parser.add_argument("order", type=int) +parser.add_argument("loss", nargs="?", choices=["mse", "mae", "mulpe", "mulpe_mae"], default="mulpe") +parser.add_argument("--no-gui", action='store_true') +parser.add_argument("--print", action='store_true') +parser.add_argument("--pbar", action='store_true') +parser.add_argument("--format", default="all", choices=["all", "switch", "array", "consts"]) args = parser.parse_args() order = args.order @@ -41,113 +46,187 @@ X = np.linspace(lower, upper, 2048 * 8) target = func(X) +target_spacing = np.spacing(np.abs(target).astype(np.float32)).astype(np.float64) # Precision (aka ULP) print("exponent:", exponents) coeffs = np.zeros(len(exponents)) powers = np.power(X[:,None], exponents) -loss_power = 120 +loss_power = 500 -lstsq_iterations = 15000 -loss_history = np.zeros((lstsq_iterations, 2)) +lstsq_iterations = loss_power * 10 # If the loss is MSE, then this is just a linear system we can solve for. # We will iteratively adjust the weights to put more focus on the parts where it goes wrong. 
weight = np.ones_like(target) +if args.loss == "mse": + lstsq_iterations = 1 + +loss_history = np.zeros((lstsq_iterations, 3)) + +iterator = range(lstsq_iterations) +if args.pbar: + import tqdm + iterator = tqdm.trange(lstsq_iterations) + try: - for i in range(lstsq_iterations): + for i in iterator: norm_weight = weight / np.mean(weight) coeffs, residuals, rank, s = np.linalg.lstsq(powers * norm_weight[:,None], target * norm_weight, rcond=None) - if i == 0: - init_coeffs = coeffs.copy() y_hat = np.sum((powers * coeffs)[:,::-1], axis=-1) diff = y_hat - target abs_diff = np.abs(diff) - max_abs_error = np.amax(np.abs(diff)) - if i % 10 == 0: - print("coefficients:", coeffs, f" MaxAE: {max_abs_error:20.17f} mean weight: {weight.mean():10.8f}") - norm_abs_diff = abs_diff / np.mean(abs_diff) - p = i / lstsq_iterations - p = min(np.sqrt(p) * 1.25, 1.0) - weight += np.power(norm_abs_diff, 2 + int(loss_power * p) // 2 * 2) - loss = np.power(diff, loss_power) - loss_history[i, 0] = np.mean(loss) + # MSE metric + mean_squared_error = np.mean(np.square(diff)) + # MAE metric + max_abs_error = np.amax(abs_diff) loss_history[i, 1] = max_abs_error + # MaxULP metric + ulp_error = diff / target_spacing + abs_ulp_error = np.abs(ulp_error) + max_ulp_error = np.amax(abs_ulp_error) + loss_history[i, 2] = max_ulp_error + + if args.print and i % 10 == 0: + print(f"[{((i+1) / lstsq_iterations * 100.0):3.0f}%] coefficients:", coeffs, + f" MaxAE: {max_abs_error:20.17f} MaxULPs: {max_ulp_error:20.0f} mean weight: {weight.mean():.4e}") + + if args.loss == "mae": + norm_error_metric = abs_diff / np.amax(abs_diff) + elif args.loss == "mulpe": + norm_error_metric = abs_ulp_error / max_ulp_error + elif args.loss == "mulpe_mae": + norm_error_metric = 0.5 * (abs_ulp_error / max_ulp_error + abs_diff / max_abs_error) + elif args.loss == "mse": + norm_error_metric = np.square(abs_diff) + + p = i / lstsq_iterations + p = min(p * 1.25, 1.0) + raised_error = np.power(norm_error_metric, 2 + loss_power * p) 
+ #weight += raised_error / np.mean(raised_error) + weight += raised_error + + mean_loss = np.mean(np.power(abs_diff, loss_power)) + loss_history[i, 0] = mean_loss + + if i == 0: + init_coeffs = coeffs.copy() + init_ulp_error = ulp_error.copy() + init_abs_ulp_error = abs_ulp_error.copy() + init_abs_error = abs_diff.copy() + init_y_hat = y_hat.copy() except KeyboardInterrupt: print("Interrupted") -print(coeffs) -y_hat = np.sum((powers * coeffs)[:,::-1], axis=-1) -y_hat_init = np.sum((powers * init_coeffs)[:,::-1], axis=-1) -diff = y_hat - target -loss = np.power(diff, loss_power) -mean_loss = np.mean(loss) -diff = y_hat - target -print(f"mse: {mean_loss:40.27f} max abs error: {max_abs_error:20.17f}") +print("Init coeffs:", init_coeffs) +print("Final coeffs:", coeffs) +print(f"mse: {mean_loss:40.27f} max abs error: {max_abs_error:20.17f} max ulp error: {max_ulp_error:e}") -print() -print(f"// Coefficients with max error: {max_abs_error:.4e}") -for i, (e, c) in enumerate(zip(exponents, coeffs)): - print(f"const float c_{e}({c:+.12e}f);") -print() +def print_comment(indent=""): + print(indent + "// " + + {"mae": "Max Absolute Error", "mse": "Mean Squared Error", "mulpe": "Max ULP Error", "mulpe_mae": "MaxUlpAE"}[args.loss] + + f" optimized (MSE={mean_squared_error:.4e}, MAE={max_abs_error:.4e}, MaxUlpE={max_ulp_error:.4e})") + + +if args.format in ["all", "consts"]: + print() + print_comment() + for i, (e, c) in enumerate(zip(exponents, coeffs)): + print(f"const float c_{e}({c:+.12e}f);") + print() + + +if args.format in ["all", "array"]: + print() + print_comment() + print("const float coef[] = {"); + for i, (e, c) in enumerate(reversed(list(zip(exponents, coeffs)))): + print(f" {c:+.12e}, // * x^{e}") + print("};\n") + +if args.format in ["all", "switch"]: + print() + print("case ApproximationPrecision::" + args.loss.upper() + "_Poly" + str(args.order) + ":" + + f" // (MSE={mean_squared_error:.4e}, MAE={max_abs_error:.4e}, MaxUlpE={max_ulp_error:.4e})") + print(" c = 
{" + (", ".join([f"{c:+.12e}f" for c in coeffs])) + "}; break;") + print() -print() -print(f"// Coefficients with max error: {max_abs_error:.4e}") -print("const float coef[] = {"); -for i, (e, c) in enumerate(reversed(list(zip(exponents, coeffs)))): - print(f" {c:+.12e}, // * x^{e}") -print("};\n") -print() -print(f"// Coefficients with max error: {max_abs_error:.4e}") -for i, (e, c) in enumerate(zip(exponents, coeffs)): - print(f"c.push_back({c:+.12e}f);") print() print("exponent:", exponents) +if args.no_gui: + exit() + import matplotlib.pyplot as plt -fig, ax = plt.subplots(5, figsize=(5.5, 8)) -ax[0].set_title("Comparison of exact and approximate " + args.func) +fig, ax = plt.subplots(2, 4, figsize=(12, 6)) +ax = ax.flatten() +ax[0].set_title("Comparison of exact\nand approximate " + args.func) ax[0].plot(X, target, label=args.func) ax[0].plot(X, y_hat, label='approx') ax[0].grid() ax[0].set_xlim(lower, upper) ax[0].legend() -ax[1].set_title("Absolute error in log-scale") -ax[1].semilogy(X, np.abs(y_hat_init - target), label='abs error (init)') -ax[1].semilogy(X, np.abs(diff), label='abs error (final)') -ax[1].axhline(np.amax(np.abs(y_hat_init - target)), linestyle=':', c='C0') -ax[1].axhline(np.amax(np.abs(diff)), linestyle=':', c='C1') +ax[1].set_title("Error") +ax[1].axhline(0, linestyle='-', c='k', linewidth=1) +ax[1].plot(X, init_y_hat - target, label='init') +ax[1].plot(X, y_hat - target, label='final') ax[1].grid() ax[1].set_xlim(lower, upper) ax[1].legend() -ax[2].set_title("Error") -ax[2].plot(X, y_hat_init - target, label='init diff') -ax[2].plot(X, y_hat - target, label='final diff') +ax[2].set_title("Absolute error\n(log-scale)") +ax[2].semilogy(X, init_abs_error, label='init') +ax[2].semilogy(X, abs_diff, label='final') +ax[2].axhline(np.amax(init_abs_error), linestyle=':', c='C0') +ax[2].axhline(np.amax(abs_diff), linestyle=':', c='C1') ax[2].grid() ax[2].set_xlim(lower, upper) ax[2].legend() -ax[3].set_title("LstSq Weight (log-scale)") 
-ax[3].semilogy(X, norm_weight, label='weight') +ax[3].set_title("Maximal Absolute Error\nprogression during\noptimization") +ax[3].semilogx(1 + np.arange(loss_history.shape[0]), loss_history[:,1]) +ax[3].set_xlim(1, loss_history.shape[0] + 1) +ax[3].axhline(y=loss_history[0,1], linestyle=':', color='k') ax[3].grid() -ax[3].set_xlim(lower, upper) -ax[3].legend() -ax[4].set_title("Maximal Absolute Error progression during optimization") -ax[4].semilogx(1 + np.arange(loss_history.shape[0]), loss_history[:,1], label='MaxAE') -ax[4].set_xlim(1, loss_history.shape[0] + 1) -ax[4].axhline(y=loss_history[0,1], linestyle=':', color='k') +ax[5].set_title("ULP distance") +ax[5].axhline(0, linestyle='-', c='k', linewidth=1) +ax[5].plot(X, init_ulp_error, label='init') +ax[5].plot(X, ulp_error, label='final') +ax[5].grid() +ax[5].set_xlim(lower, upper) +ax[5].legend() + + +ax[6].set_title("Absolute ULP distance\n(log-scale)") +ax[6].semilogy(X, init_abs_ulp_error, label='init') +ax[6].semilogy(X, abs_ulp_error, label='final') +ax[6].axhline(np.amax(init_abs_ulp_error), linestyle=':', c='C0') +ax[6].axhline(np.amax(abs_ulp_error), linestyle=':', c='C1') +ax[6].grid() +ax[6].set_xlim(lower, upper) +ax[6].legend() + +ax[7].set_title("Maximal ULP Error\nprogression during\noptimization") +ax[7].loglog(1 + np.arange(loss_history.shape[0]), loss_history[:,2]) +ax[7].set_xlim(1, loss_history.shape[0] + 1) +ax[7].axhline(y=loss_history[0,2], linestyle=':', color='k') +ax[7].grid() + +ax[4].set_title("LstSq Weight\n(log-scale)") +ax[4].semilogy(X, norm_weight, label='weight') ax[4].grid() +ax[4].set_xlim(lower, upper) ax[4].legend() + plt.tight_layout() plt.show() diff --git a/test/performance/fast_arctan.cpp b/test/performance/fast_arctan.cpp index ecb5bced2661..52cfeb6c36bd 100644 --- a/test/performance/fast_arctan.cpp +++ b/test/performance/fast_arctan.cpp @@ -14,10 +14,6 @@ int main(int argc, char **argv) { printf("[SKIP] Performance tests are meaningless and/or misleading under 
WebAssembly interpreter.\n"); return 0; } - if (target.has_feature(Target::WebGPU)) { - printf("[SKIP] WebGPU seems to perform bad, and fast_atan is not really faster in all scenarios.\n"); - return 0; - } Var x, y; const int test_w = 256; @@ -27,7 +23,7 @@ Expr t1 = y / float(test_h); // To make sure we time mostely the computation of the arctan, and not memory bandwidth, // we will compute many arctans per output and sum them. In my testing, GPUs suffer more - // from bandwith with this test, so we give it more arctangenses to compute per output. + // from bandwidth with this test, so we give it more arctangents to compute per output. const int test_d = target.has_gpu_feature() ? 1024 : 64; RDom rdom{0, test_d}; Expr off = rdom / float(test_d) - 0.5f; @@ -49,24 +45,30 @@ atan2_ref.vectorize(x, 8); } - Tools::BenchmarkConfig cfg = {0.2, 1.0}; double scale = 1e9 / (double(test_w) * (test_h * test_d)); + Buffer atan_out(test_w, test_h); + Buffer atan2_out(test_w, test_h); + atan_ref.compile_jit(); + atan2_ref.compile_jit(); // clang-format off - double t_atan = scale * benchmark([&]() { atan_ref.realize({test_w, test_h}); }, cfg); - double t_atan2 = scale * benchmark([&]() { atan2_ref.realize({test_w, test_h}); }, cfg); + double t_atan = scale * benchmark([&]() { atan_ref.realize( atan_out); atan_out.device_sync(); }); + double t_atan2 = scale * benchmark([&]() { atan2_ref.realize(atan2_out); atan2_out.device_sync(); }); // clang-format on struct Prec { ApproximationPrecision precision; - float epsilon; + const char *name; double atan_time{0.0f}; double atan2_time{0.0f}; } precisions_to_test[] = { - {ApproximationPrecision::MAE_1e_2, 1e-2f}, - {ApproximationPrecision::MAE_1e_3, 1e-3f}, - {ApproximationPrecision::MAE_1e_4, 1e-4f}, - {ApproximationPrecision::MAE_1e_5, 1e-5f}, - {ApproximationPrecision::MAE_1e_6, 1e-6f}}; + {ApproximationPrecision::MULPE_Poly2, "Poly2"}, + 
{ApproximationPrecision::MULPE_Poly3, "Poly3"}, + {ApproximationPrecision::MULPE_Poly4, "Poly4"}, + {ApproximationPrecision::MULPE_Poly5, "Poly5"}, + {ApproximationPrecision::MULPE_Poly6, "Poly6"}, + {ApproximationPrecision::MULPE_Poly7, "Poly7"}, + {ApproximationPrecision::MULPE_Poly8, "Poly8"}, + }; for (Prec &precision : precisions_to_test) { Func atan_f{"fast_atan"}, atan2_f{"fast_atan2"}; @@ -85,25 +87,27 @@ atan2_f.vectorize(x, 8); } + atan_f.compile_jit(); + atan2_f.compile_jit(); // clang-format off - double t_fast_atan = scale * benchmark([&]() { atan_f.realize({test_w, test_h}); }, cfg); - double t_fast_atan2 = scale * benchmark([&]() { atan2_f.realize({test_w, test_h}); }, cfg); + double t_fast_atan = scale * benchmark([&]() { atan_f.realize( atan_out); atan_out.device_sync(); }); + double t_fast_atan2 = scale * benchmark([&]() { atan2_f.realize(atan2_out); atan2_out.device_sync(); }); // clang-format on precision.atan_time = t_fast_atan; precision.atan2_time = t_fast_atan2; } - printf(" atan: %f ns per atan\n", t_atan); + printf(" atan: %f ns per atan\n", t_atan); for (const Prec &precision : precisions_to_test) { - printf(" fast_atan (MAE %.0e): %f ns per atan (%4.1f%% faster) [per invokation: %f ms]\n", - precision.epsilon, precision.atan_time, 100.0f * (1.0f - precision.atan_time / t_atan), + printf(" fast_atan (%s): %f ns per atan (%4.1f%% faster) [per invocation: %f ms]\n", + precision.name, precision.atan_time, 100.0f * (1.0f - precision.atan_time / t_atan), precision.atan_time / scale * 1e3); } printf("\n"); - printf(" atan2: %f ns per atan2\n", t_atan2); + printf(" atan2: %f ns per atan2\n", t_atan2); for (const Prec &precision : precisions_to_test) { - printf(" fast_atan2 (MAE %.0e): %f ns per atan2 (%4.1f%% faster) [per invokation: %f ms]\n", - precision.epsilon, precision.atan2_time, 100.0f * (1.0f - precision.atan2_time / t_atan2), + printf(" fast_atan2 (%s): %f ns per atan2 (%4.1f%% faster) [per invocation: 
%f ms]\n", + precision.name, precision.atan2_time, 100.0f * (1.0f - precision.atan2_time / t_atan2), precision.atan2_time / scale * 1e3); }