diff --git a/enzyme/CMakeLists.txt b/enzyme/CMakeLists.txt index 1f51afe2789d..36d6fe2f4f3b 100644 --- a/enzyme/CMakeLists.txt +++ b/enzyme/CMakeLists.txt @@ -263,6 +263,7 @@ find_library(MPFR_LIB_PATH mpfr) CHECK_INCLUDE_FILE("mpfr.h" HAS_MPFR_H) message("MPFR lib: " ${MPFR_LIB_PATH}) message("MPFR header: " ${HAS_MPFR_H}) +link_libraries(mpfr) file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/include/SCEV/ScalarEvolutionExpander.h" "${INPUT_TEXT}") diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index cedaa1d19c3f..d81f9bb0e0ba 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -1013,7 +1013,8 @@ bool isValuePotentiallyUsedAsPointer(llvm::Value *val) { for (auto u : cur->users()) { if (isa(u)) return true; - if (!cast(u)->mayReadOrWriteMemory()) { + if (isa(u) || + !cast(u)->mayReadOrWriteMemory()) { todo.push_back(u); continue; } diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index 7e75716ec5f2..aacdb3bdb946 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -44,12 +44,33 @@ set(LLVM_LINK_COMPONENTS Demangle) file(GLOB ENZYME_SRC CONFIGURE_DEPENDS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp" ) -list(REMOVE_ITEM ENZYME_SRC "eopt.cpp") +list(REMOVE_ITEM ENZYME_SRC "eopt.cpp" "Herbie.cpp") set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) list(APPEND ENZYME_SRC TypeAnalysis/TypeTree.cpp TypeAnalysis/TypeAnalysis.cpp TypeAnalysis/TypeAnalysisPrinter.cpp TypeAnalysis/RustDebugInfo.cpp) +# set(ENZYME_LINK_TARGETS) +set(ENZYME_ENABLE_HERBIE 0 CACHE BOOL "Enable Herbie") + +if(ENZYME_ENABLE_HERBIE) + include(ExternalProject) + ExternalProject_Add(herbie + GIT_REPOSITORY https://github.com/herbie-fp/herbie + GIT_TAG 5e640bd324ece7105804c7842c6026fd92808890 + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND make egg-herbie && make update && raco exe -o herbie --orig-exe --embed-dlls --vv src/herbie.rkt + BUILD_IN_SOURCE true + INSTALL_COMMAND COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/herbie/install + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_BINARY_DIR}/herbie-prefix/src/herbie/herbie ${CMAKE_CURRENT_BINARY_DIR}/herbie/install/herbie + ) + list(APPEND ENZYME_SRC Herbie.cpp) + add_compile_definitions(ENZYME_ENABLE_HERBIE=1) + set_source_files_properties(Herbie.cpp PROPERTIES COMPILE_DEFINITIONS HERBIE_BINARY="${CMAKE_CURRENT_BINARY_DIR}/herbie/install/herbie") +endif() + + # on windows `PLUGIN_TOOL` doesn't link against LLVM.dll if ((WIN32 OR CYGWIN) AND LLVM_LINK_LLVM_DYLIB) add_llvm_library( LLVMEnzyme-${LLVM_VERSION_MAJOR} @@ -71,6 +92,7 @@ if (${Clang_FOUND}) intrinsics_gen LINK_COMPONENTS LLVM + ${ENZYME_LINK_TARGETS} ) target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS) endif() @@ -185,8 +207,8 @@ target_link_options(LLDEnzymePrintFlags INTERFACE "SHELL: -Wl,-mllvm -Wl,-enzyme add_library(LLDEnzymeNoStrictAliasingFlags INTERFACE) target_link_options(LLDEnzymeNoStrictAliasingFlags INTERFACE "SHELL: -Wl,-mllvm -Wl,-enzyme-strict-aliasing=0") -# this custom target exists to prevent CMake from incorrectly assuming that -# targets that link depend on LLDEnzyme-XX can be built at the same time or +# this custom target exists to prevent CMake from incorrectly assuming that +# targets that link depend on LLDEnzyme-XX can be built at the same time or # before LLDEnzyme-XX has finished. add_custom_target(LLDEnzymeDummy "" DEPENDS LLDEnzyme-${LLVM_VERSION_MAJOR}) add_dependencies(LLDEnzymeFlags LLDEnzymeDummy) @@ -194,8 +216,8 @@ add_dependencies(LLDEnzymeFlags LLDEnzymeDummy) add_library(ClangEnzymeFlags INTERFACE) target_compile_options(ClangEnzymeFlags INTERFACE "-fplugin=$") -# this custom target exists to prevent CMake from incorrectly assuming that -# targets that link depend on ClangEnzyme-XX can be built at the same time or +# this custom target exists to prevent CMake from incorrectly assuming that +# targets that link depend on ClangEnzyme-XX can be built at the same time or # before ClangEnzyme-XX has finished. add_custom_target(ClangEnzymeDummy "" DEPENDS ClangEnzyme-${LLVM_VERSION_MAJOR}) add_dependencies(ClangEnzymeFlags ClangEnzymeDummy) diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index eba0de11f54d..db9471fc16f5 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -107,11 +107,13 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( std::string prefix; switch (mode) { - case DerivativeMode::ForwardModeError: case DerivativeMode::ForwardMode: case DerivativeMode::ForwardModeSplit: prefix = "fwddiffe"; break; + case DerivativeMode::ForwardModeError: + prefix = "fwderr"; + break; case DerivativeMode::ReverseModeCombined: case DerivativeMode::ReverseModeGradient: prefix = "diffe"; diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.h b/enzyme/Enzyme/DifferentialUseAnalysis.h index 56c49039c86f..330e3e32da59 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.h +++ b/enzyme/Enzyme/DifferentialUseAnalysis.h @@ -123,7 +123,8 @@ inline bool is_value_needed_in_reverse( } } } - if (gutils->mode == DerivativeMode::ForwardModeError && + if ((getLogFunction(gutils->oldFunc->getParent(), "enzymeLogValue") || + gutils->mode == DerivativeMode::ForwardModeError) && !gutils->isConstantValue(const_cast(inst))) { if (EnzymePrintDiffUse) llvm::errs() diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index c33f66ff95b3..ebe6137ac3ac 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -3272,6 +3272,9 @@ AnalysisKey EnzymeNewPM::Key; #include "ActivityAnalysisPrinter.h" #include "JLInstSimplify.h" #include "PreserveNVVM.h" +#ifdef ENZYME_ENABLE_HERBIE +#include "Herbie.h" +#endif #include "TypeAnalysis/TypeAnalysisPrinter.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" @@ -3385,6 +3388,10 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { OptimizerPM.addPass(llvm::SROAPass()); #endif MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizerPM))); +#ifdef ENZYME_ENABLE_HERBIE + if (EnzymeEnableFPOpt) + MPM.addPass(FPOptNewPM()); +#endif MPM.addPass(EnzymeNewPM(/*PostOpt=*/true)); MPM.addPass(PreserveNVVMNewPM(/*Begin*/ false)); #if LLVM_VERSION_MAJOR >= 16 @@ -3669,6 +3676,12 @@ void registerEnzyme(llvm::PassBuilder &PB) { MPM.addPass(EnzymeNewPM()); return true; } +#ifdef ENZYME_ENABLE_HERBIE + if (Name == "fp-opt") { + MPM.addPass(FPOptNewPM()); + return true; + } +#endif if (Name == "preserve-nvvm") { MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true)); return true; diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 3cd0203c08a9..09c08fef69a8 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -1417,6 +1417,30 @@ Function *PreProcessCache::preprocessForClone(Function *F, /*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly, Returns, "", nullptr); } + if (mode == DerivativeMode::ForwardModeError || + mode == DerivativeMode::ReverseModeCombined || + mode == DerivativeMode::ReverseModeGradient) { + if (getLogFunction(F->getParent(), "enzymeLogError") || + getLogFunction(F->getParent(), "enzymeLogValue") || + getLogFunction(F->getParent(), "enzymeLogGrad")) { + for (const auto &pair : VMap) { + if (auto *before = dyn_cast(pair.first)) { + if (!before->getType()->isFloatingPointTy()) { + continue; + } + auto *after = cast(pair.second); + after->setMetadata("enzyme_active", + MDNode::get(after->getContext(), None)); + after->setMetadata( + "enzyme_preprocess_origin", + MDTuple::get(after->getContext(), + {ConstantAsMetadata::get(ConstantInt::get( + Type::getInt64Ty(after->getContext()), + reinterpret_cast(before)))})); + } + } + } + } CloneOrigin[NewF] = F; NewF->setAttributes(F->getAttributes()); if (EnzymeNoAlias) @@ -1799,7 +1823,8 @@ Function *PreProcessCache::preprocessForClone(Function *F, FAM.invalidate(*NewF, PA); } - if (mode != DerivativeMode::ForwardMode) + if (mode != DerivativeMode::ForwardMode && + mode != DerivativeMode::ForwardModeError) ReplaceReallocs(NewF); { @@ -1835,7 +1860,8 @@ Function *PreProcessCache::preprocessForClone(Function *F, PA.preserve(); } - if (mode != DerivativeMode::ForwardMode) + if (mode != DerivativeMode::ForwardMode && + mode != DerivativeMode::ForwardModeError) ReplaceReallocs(NewF); if (mode == DerivativeMode::ReverseModePrimal || diff --git a/enzyme/Enzyme/Herbie.cpp b/enzyme/Enzyme/Herbie.cpp new file mode 100644 index 000000000000..7a681b03fa72 --- /dev/null +++ b/enzyme/Enzyme/Herbie.cpp @@ -0,0 +1,4998 @@ +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" + +#include "llvm/Analysis/TargetTransformInfo.h" + +#include "llvm/Demangle/Demangle.h" + +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Module.h" + +#include "llvm/Support/Casting.h" +#include "llvm/Support/InstructionCost.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/raw_ostream.h" +#include + +#include "llvm/Pass.h" + +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Herbie.h" +#include "Utils.h" + +using namespace llvm; +#ifdef DEBUG_TYPE +#undef DEBUG_TYPE +#endif +#define DEBUG_TYPE "fp-opt" + +extern "C" { +cl::opt EnzymeEnableFPOpt("enzyme-enable-fpopt", cl::init(false), + cl::Hidden, cl::desc("Run the FPOpt pass")); +static cl::opt + EnzymePrintFPOpt("enzyme-print-fpopt", cl::init(false), cl::Hidden, + cl::desc("Enable Enzyme to print FPOpt info")); +static cl::opt + EnzymePrintHerbie("enzyme-print-herbie", cl::init(false), cl::Hidden, + cl::desc("Enable Enzyme to print Herbie expressions")); +static cl::opt + FPOptLogPath("fpopt-log-path", cl::init(""), cl::Hidden, + cl::desc("Which log to use in the FPOpt pass")); +static cl::opt + FPOptCostModelPath("fpopt-cost-model-path", cl::init(""), cl::Hidden, + cl::desc("Use a custom cost model in the FPOpt pass")); +static cl::opt FPOptTargetFuncRegex( + "fpopt-target-func-regex", cl::init(".*"), cl::Hidden, + cl::desc("Regex pattern to match target functions in the FPOpt pass")); +static cl::opt FPOptEnableHerbie( + "fpopt-enable-herbie", cl::init(true), cl::Hidden, + cl::desc("Use Herbie to rewrite floating-point expressions")); +static cl::opt FPOptEnablePT( + "fpopt-enable-pt", cl::init(false), cl::Hidden, + cl::desc("Consider precision changes of floating-point expressions")); +static cl::opt HerbieNumThreads("herbie-num-threads", cl::init(1), + cl::Hidden, + cl::desc("Number of threads Herbie uses")); +static cl::opt HerbieTimeout("herbie-timeout", cl::init(120), cl::Hidden, + cl::desc("Herbie's timeout to use for each " + "candidate expressions.")); +static cl::opt + FPOptCachePath("fpopt-cache-path", cl::init(""), cl::Hidden, + cl::desc("Experimental: path to cache Herbie results")); +static cl::opt + HerbieNumPoints("herbie-num-pts", cl::init(1024), cl::Hidden, + cl::desc("Number of input points Herbie uses to evaluate " + "candidate expressions.")); +static cl::opt HerbieNumIters( + "herbie-num-iters", cl::init(6), cl::Hidden, + cl::desc("Number of times Herbie attempts to improve accuracy.")); +static cl::opt HerbieDisableNumerics( + "herbie-disable-numerics", cl::init(false), cl::Hidden, + cl::desc("Disable Herbie rewrite rules that produce numerical shorthands " + "expm1, log1p, fma, and hypot")); +static cl::opt + HerbieDisableTaylor("herbie-disable-taylor", cl::init(false), cl::Hidden, + cl::desc("Disable Herbie's series expansion")); +static cl::opt HerbieDisableSetupSimplify( + "herbie-disable-setup-simplify", cl::init(false), cl::Hidden, + cl::desc("Stop Herbie from pre-simplifying expressions")); +static cl::opt HerbieDisableGenSimplify( + "herbie-disable-gen-simplify", cl::init(false), cl::Hidden, + cl::desc("Stop Herbie from simplifying expressions " + "during the main improvement loop")); +static cl::opt HerbieDisableRegime( + "herbie-disable-regime", cl::init(false), cl::Hidden, + cl::desc("Stop Herbie from branching between expressions candidates")); +static cl::opt HerbieDisableBranchExpr( + "herbie-disable-branch-expr", cl::init(false), cl::Hidden, + cl::desc("Stop Herbie from branching on expressions")); +static cl::opt HerbieDisableAvgError( + "herbie-disable-avg-error", cl::init(false), cl::Hidden, + cl::desc("Make Herbie choose the candidates with the least maximum error")); +static cl::opt FPOptEnableSolver( + "fpopt-enable-solver", cl::init(false), cl::Hidden, + cl::desc("Use the solver to select desirable rewrite candidates; when " + "disabled, apply all Herbie's first choices")); +static cl::opt FPOptSolverType("fpopt-solver-type", cl::init("dp"), + cl::Hidden, + cl::desc("Which solver to use")); +static cl::opt FPOptShowTable( + "fpopt-show-table", cl::init(false), cl::Hidden, + cl::desc( + "Print the full DP table (highly verbose for large applications)")); +static cl::opt FPOptComputationCostBudget( + "fpopt-comp-cost-budget", cl::init(100000000000L), cl::Hidden, + cl::desc("The maximum computation cost budget for the solver")); +static cl::opt FPOptMaxFPCCDepth( + "fpopt-max-fpcc-depth", cl::init(10), cl::Hidden, + cl::desc("The maximum depth of a floating-point connected component")); +static cl::opt + FPOptRandomSeed("fpopt-random-seed", cl::init(239778888), cl::Hidden, + cl::desc("The random seed used in the FPOpt pass")); +static cl::opt + FPOptNumSamples("fpopt-num-samples", cl::init(1024), cl::Hidden, + cl::desc("Number of sampled points for input hypercube")); +static cl::opt + FPOptMaxMPFRPrec("fpopt-max-mpfr-prec", cl::init(1024), cl::Hidden, + cl::desc("Max precision for MPFR gold value computation")); +static cl::opt + FPOptWidenRange("fpopt-widen-range", cl::init(1), cl::Hidden, + cl::desc("Ablation study only: widen the range of input " + "hypercube by this factor")); +static cl::opt FPOptEarlyPrune( + "fpopt-early-prune", cl::init(false), cl::Hidden, + cl::desc("Prune dominated candidates in expression transformation phases")); +static cl::opt FPOptCostDominanceThreshold( + "fpopt-cost-dom-thres", cl::init(0.05), cl::Hidden, + cl::desc("The threshold for cost dominance in DP solver")); +static cl::opt FPOptAccuracyDominanceThreshold( + "fpopt-acc-dom-thres", cl::init(0.05), cl::Hidden, + cl::desc("The threshold for accuracy dominance in DP solver")); +} + +class FPNode { +public: + enum class NodeType { Node, LLValue, Const }; + +private: + const NodeType ntype; + +public: + std::string op; + std::string dtype; + std::string symbol; + SmallVector, 2> operands; + double grad; + double geometricAvg; + unsigned executions; + + explicit FPNode(const std::string &op, const std::string &dtype) + : ntype(NodeType::Node), op(op), dtype(dtype) {} + explicit FPNode(NodeType ntype, const std::string &op, + const std::string &dtype) + : ntype(ntype), op(op), dtype(dtype) {} + virtual ~FPNode() = default; + + NodeType getType() const { return ntype; } + + void addOperand(std::shared_ptr operand) { + operands.push_back(operand); + } + + virtual bool hasSymbol() const { + std::string msg = "Unexpected invocation of `hasSymbol` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); + } + + virtual std::string toFullExpression( + std::unordered_map> &valueToNodeMap) { + std::string msg = "Unexpected invocation of `toFullExpression` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); + } + + unsigned getMPFRPrec() const { + if (dtype == "f16") + return 11; + if (dtype == "f32") + return 24; + if (dtype == "f64") + return 53; + std::string msg = + "getMPFRPrec: operator " + op + " has unknown dtype " + dtype; + llvm_unreachable(msg.c_str()); + } + + virtual void markAsInput() { + std::string msg = "Unexpected invocation of `markAsInput` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); + } + + virtual void updateBounds(double lower, double upper) { + std::string msg = "Unexpected invocation of `updateBounds` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); + } + virtual double getLowerBound() const { + std::string msg = "Unexpected invocation of `getLowerBound` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); + } + virtual double getUpperBound() const { + std::string msg = "Unexpected invocation of `getUpperBound` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); + } + + virtual Value *getLLValue(IRBuilder<> &builder, + const ValueToValueMapTy *VMap = nullptr) { + Module *M = builder.GetInsertBlock()->getModule(); + + if (op == "if") { + Value *condValue = operands[0]->getLLValue(builder, VMap); + auto IP = builder.GetInsertPoint(); + + Instruction *Then, *Else; + SplitBlockAndInsertIfThenElse(condValue, &*IP, &Then, &Else); + + Then->getParent()->setName("herbie.then"); + builder.SetInsertPoint(Then); + Value *ThenVal = operands[1]->getLLValue(builder, VMap); + if (Instruction *I = dyn_cast(ThenVal)) { + I->setName("herbie.then_val"); + } + + Else->getParent()->setName("herbie.else"); + builder.SetInsertPoint(Else); + Value *ElseVal = operands[2]->getLLValue(builder, VMap); + if (Instruction *I = dyn_cast(ElseVal)) { + I->setName("herbie.else_val"); + } + + builder.SetInsertPoint(&*IP); + auto Phi = builder.CreatePHI(ThenVal->getType(), 2); + Phi->addIncoming(ThenVal, Then->getParent()); + Phi->addIncoming(ElseVal, Else->getParent()); + Phi->setName("herbie.merge"); + + return Phi; + } + + SmallVector operandValues; + for (auto operand : operands) { + operandValues.push_back(operand->getLLValue(builder, VMap)); + } + + Value *val = nullptr; + + if (op == "neg") { + val = builder.CreateFNeg(operandValues[0], "herbie.neg"); + } else if (op == "+") { + val = + builder.CreateFAdd(operandValues[0], operandValues[1], "herbie.add"); + } else if (op == "-") { + val = + builder.CreateFSub(operandValues[0], operandValues[1], "herbie.sub"); + } else if (op == "*") { + val = + builder.CreateFMul(operandValues[0], operandValues[1], "herbie.mul"); + } else if (op == "/") { + val = + builder.CreateFDiv(operandValues[0], operandValues[1], "herbie.div"); + } else if (op == "sin") { + val = builder.CreateUnaryIntrinsic(Intrinsic::sin, operandValues[0], + nullptr, "herbie.sin"); + } else if (op == "cos") { + val = builder.CreateUnaryIntrinsic(Intrinsic::cos, operandValues[0], + nullptr, "herbie.cos"); + } else if (op == "tan") { +#if LLVM_VERSION_MAJOR >= 16 // TODO: Double check version + val = builder.CreateUnaryIntrinsic(Intrinsic::tan, operandValues[0], + "herbie.tan"); +#else + // Using std::tan(f) for lower versions of LLVM. + auto *Ty = operandValues[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "tan" : "tanf"; + llvm::Function *tanFunc = M->getFunction(funcName); + if (!tanFunc) { + auto *funcTy = FunctionType::get(Ty, {Ty}, false); + tanFunc = + Function::Create(funcTy, Function::ExternalLinkage, funcName, M); + } + if (tanFunc) { + val = builder.CreateCall(tanFunc, {operandValues[0]}, "herbie.tan"); + } else { + std::string msg = + "Failed to find or declare " + funcName + " in the module."; + llvm_unreachable(msg.c_str()); + } + +#endif + } else if (op == "exp") { + val = builder.CreateUnaryIntrinsic(Intrinsic::exp, operandValues[0], + nullptr, "herbie.exp"); + } else if (op == "expm1") { + auto *Ty = operandValues[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "expm1" : "expm1f"; + llvm::Function *expm1Func = M->getFunction(funcName); + if (!expm1Func) { + auto *funcTy = FunctionType::get(Ty, {Ty}, false); + expm1Func = + Function::Create(funcTy, Function::ExternalLinkage, funcName, M); + } + if (expm1Func) { + val = builder.CreateCall(expm1Func, {operandValues[0]}, "herbie.expm1"); + } else { + std::string msg = "Failed to find or declare " + funcName + + " in the module. Consider disabling Herbie rules for " + "numerics (-herbie-disable-numerics)."; + llvm_unreachable(msg.c_str()); + } + } else if (op == "log") { + val = builder.CreateUnaryIntrinsic(Intrinsic::log, operandValues[0], + nullptr, "herbie.log"); + } else if (op == "log1p") { + auto *Ty = operandValues[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "log1p" : "log1pf"; + llvm::Function *log1pFunc = M->getFunction(funcName); + if (!log1pFunc) { + auto *funcTy = FunctionType::get(Ty, {Ty}, false); + log1pFunc = + Function::Create(funcTy, Function::ExternalLinkage, funcName, M); + } + if (log1pFunc) { + val = builder.CreateCall(log1pFunc, {operandValues[0]}, "herbie.log1p"); + } else { + std::string msg = + "Failed to find or declare log1p in the module. Consider disabling " + "Herbie rules for numerics (-herbie-disable-numerics)."; + llvm_unreachable(msg.c_str()); + } + } else if (op == "sqrt") { + val = builder.CreateUnaryIntrinsic(Intrinsic::sqrt, operandValues[0], + nullptr, "herbie.sqrt"); + } else if (op == "cbrt") { + auto *Ty = operandValues[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "cbrt" : "cbrtf"; + llvm::Function *cbrtFunc = M->getFunction(funcName); + if (!cbrtFunc) { + auto *funcTy = FunctionType::get(Ty, {Ty}, false); + cbrtFunc = + Function::Create(funcTy, Function::ExternalLinkage, funcName, M); + } + if (cbrtFunc) { + val = builder.CreateCall(cbrtFunc, {operandValues[0]}, "herbie.cbrt"); + } else { + std::string msg = + "Failed to find or declare " + funcName + + " in the module. Consider disabling " + "Herbie rules for numerics (-herbie-disable-numerics)."; + llvm_unreachable(msg.c_str()); + } + } else if (op == "pow") { + val = builder.CreateBinaryIntrinsic(Intrinsic::pow, operandValues[0], + operandValues[1], nullptr, + "herbie.pow"); + } else if (op == "fma") { + val = builder.CreateIntrinsic( + Intrinsic::fma, {operandValues[0]->getType()}, + {operandValues[0], operandValues[1], operandValues[2]}, nullptr, + "herbie.fma"); + } else if (op == "fabs") { + val = builder.CreateUnaryIntrinsic(Intrinsic::fabs, operandValues[0], + nullptr, "herbie.fabs"); + } else if (op == "hypot") { + auto *Ty = operandValues[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "hypot" : "hypotf"; + llvm::Function *hypotFunc = M->getFunction(funcName); + if (!hypotFunc) { + auto *funcTy = FunctionType::get(Ty, {Ty, Ty}, false); + hypotFunc = + Function::Create(funcTy, Function::ExternalLinkage, funcName, M); + } + if (hypotFunc) { + val = builder.CreateCall( + hypotFunc, {operandValues[0], operandValues[1]}, "herbie.hypot"); + } else { + std::string msg = + "Failed to find or declare " + funcName + + " in the module. Consider disabling " + "Herbie rules for numerics (-herbie-disable-numerics)."; + llvm_unreachable(msg.c_str()); + } + } else if (op == "==") { + val = builder.CreateFCmpOEQ(operandValues[0], operandValues[1], + "herbie.if.eq"); + } else if (op == "!=") { + val = builder.CreateFCmpONE(operandValues[0], operandValues[1], + "herbie.if.ne"); + } else if (op == "<") { + val = builder.CreateFCmpOLT(operandValues[0], operandValues[1], + "herbie.if.lt"); + } else if (op == ">") { + val = builder.CreateFCmpOGT(operandValues[0], operandValues[1], + "herbie.if.gt"); + } else if (op == "<=") { + val = builder.CreateFCmpOLE(operandValues[0], operandValues[1], + "herbie.if.le"); + } else if (op == ">=") { + val = builder.CreateFCmpOGE(operandValues[0], operandValues[1], + "herbie.if.ge"); + } else if (op == "and") { + val = builder.CreateAnd(operandValues[0], operandValues[1], + "herbie.if.and"); + } else if (op == "or") { + val = + builder.CreateOr(operandValues[0], operandValues[1], "herbie.if.or"); + } else if (op == "not") { + val = builder.CreateNot(operandValues[0], "herbie.if.not"); + } else if (op == "TRUE") { + val = ConstantInt::getTrue(builder.getContext()); + } else if (op == "FALSE") { + val = ConstantInt::getFalse(builder.getContext()); + } else { + std::string msg = "FPNode getLLValue: Unexpected operator " + op; + llvm_unreachable(msg.c_str()); + } + + return val; + } +}; + +// Represents a true LLVM Value +class FPLLValue : public FPNode { + double lb = std::numeric_limits::infinity(); + double ub = -std::numeric_limits::infinity(); + bool input = false; // Whether `llvm::Value` is an input of an FPCC + +public: + Value *value; + + explicit FPLLValue(Value *value, const std::string &op, + const std::string &dtype) + : FPNode(NodeType::LLValue, op, dtype), value(value) {} + + bool hasSymbol() const override { return !symbol.empty(); } + + std::string toFullExpression( + std::unordered_map> &valueToNodeMap) + override { + if (input) { + assert(hasSymbol() && "FPLLValue has no symbol!"); + return symbol; + } else { + assert(!operands.empty() && "FPNode has no operands!"); + std::string expr = "(" + op; + for (auto operand : operands) { + expr += " " + operand->toFullExpression(valueToNodeMap); + } + expr += ")"; + return expr; + } + } + + void markAsInput() override { input = true; } + + void updateBounds(double lower, double upper) override { + lb = std::min(lb, lower); + ub = std::max(ub, upper); + if (EnzymePrintFPOpt) + llvm::errs() << "Updated bounds for " << *value << ": [" << lb << ", " + << ub << "]\n"; + } + + double getLowerBound() const override { return lb; } + double getUpperBound() const override { return ub; } + + Value *getLLValue(IRBuilder<> &builder, + const ValueToValueMapTy *VMap = nullptr) override { + if (VMap) { + assert(VMap->count(value) && "FPLLValue not found in passed-in VMap!"); + return VMap->lookup(value); + } + return value; + } + + static bool classof(const FPNode *N) { + return N->getType() == NodeType::LLValue; + } +}; + +double stringToDouble(const std::string &str) { + char *end; + errno = 0; + double result = std::strtod(str.c_str(), &end); + + if (errno == ERANGE) { + if (result == HUGE_VAL) { + result = std::numeric_limits::infinity(); + } else if (result == -HUGE_VAL) { + result = -std::numeric_limits::infinity(); + } + } + + return result; // Denormalized values are fine +} + +class FPConst : public FPNode { + std::string strValue; + +public: + explicit FPConst(const std::string &strValue, const std::string &dtype) + : FPNode(NodeType::Const, "__const", dtype), strValue(strValue) {} + + std::string toFullExpression( + std::unordered_map> &valueToNodeMap) + override { + return strValue; + } + + bool hasSymbol() const override { + std::string msg = "Unexpected invocation of `hasSymbol` on an FPConst"; + llvm_unreachable(msg.c_str()); + } + + void markAsInput() override { return; } + + void updateBounds(double lower, double upper) override { return; } + + double getLowerBound() const override { + if (strValue == "+inf.0") { + return std::numeric_limits::infinity(); + } else if (strValue == "-inf.0") { + return -std::numeric_limits::infinity(); + } + + double constantValue; + size_t div = strValue.find('/'); + + if (div != std::string::npos) { + std::string numerator = strValue.substr(0, div); + std::string denominator = strValue.substr(div + 1); + double num = stringToDouble(numerator); + double denom = stringToDouble(denominator); + + constantValue = num / denom; + } else { + constantValue = stringToDouble(strValue); + } + + return constantValue; + } + + double getUpperBound() const override { return getLowerBound(); } + + virtual Value *getLLValue(IRBuilder<> &builder, + const ValueToValueMapTy *VMap = nullptr) override { + Type *Ty; + if (dtype == "f64") { + Ty = builder.getDoubleTy(); + } else if (dtype == "f32") { + Ty = builder.getFloatTy(); + } else { + std::string msg = "FPConst getValue: Unexpected dtype: " + dtype; + llvm_unreachable(msg.c_str()); + } + if (strValue == "+inf.0") { + return ConstantFP::getInfinity(Ty, false); + } else if (strValue == "-inf.0") { + return ConstantFP::getInfinity(Ty, true); + } + + double constantValue; + size_t div = strValue.find('/'); + + if (div != std::string::npos) { + std::string numerator = strValue.substr(0, div); + std::string denominator = strValue.substr(div + 1); + double num = stringToDouble(numerator); + double denom = stringToDouble(denominator); + + constantValue = num / denom; + } else { + constantValue = stringToDouble(strValue); + } + + // if (EnzymePrintFPOpt) + // llvm::errs() << "Returning " << strValue << " as " << dtype + // << " constant: " << constantValue << "\n"; + return ConstantFP::get(Ty, constantValue); + } + + static bool classof(const FPNode *N) { + return N->getType() == NodeType::Const; + } +}; + +void topoSort(const SetVector &insts, + SmallVectorImpl &instsSorted) { + SmallPtrSet visited; + SmallPtrSet onStack; + + std::function dfsVisit = [&](Instruction *I) { + if (visited.count(I)) + return; + visited.insert(I); + onStack.insert(I); + + auto operands = + isa(I) ? cast(I)->args() : I->operands(); + for (auto &op : operands) { + if (isa(op)) { + Instruction *oI = cast(op); + if (insts.contains(oI)) { + if (onStack.count(oI)) { + llvm_unreachable( + "topoSort: Cycle detected in instruction dependencies!"); + } + dfsVisit(oI); + } + } + } + + onStack.erase(I); + instsSorted.push_back(I); + }; + + for (auto *I : insts) { + if (!visited.count(I)) { + dfsVisit(I); + } + } + + llvm::reverse(instsSorted); +} + +bool herbiable(const Value &Val) { + const Instruction *I = dyn_cast(&Val); + if (!I) + return false; + + switch (I->getOpcode()) { + case Instruction::FNeg: + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: + case Instruction::FDiv: + return I->getType()->isFloatTy() || I->getType()->isDoubleTy(); + case Instruction::Call: { + const CallInst *CI = dyn_cast(I); + if (CI && CI->getCalledFunction() && + (CI->getType()->isFloatTy() || CI->getType()->isDoubleTy())) { + StringRef funcName = CI->getCalledFunction()->getName(); + return funcName.startswith("llvm.sin") || + funcName.startswith("llvm.cos") || + funcName.startswith("llvm.tan") || + funcName.startswith("llvm.exp") || + funcName.startswith("llvm.log") || + funcName.startswith("llvm.sqrt") || funcName.startswith("cbrt") || + funcName.startswith("llvm.pow") || + funcName.startswith("llvm.fma") || + funcName.startswith("llvm.fmuladd") || + funcName.startswith("hypot") || funcName.startswith("expm1") || + funcName.startswith("log1p"); + // llvm.fabs is deliberately excluded + } + return false; + } + default: + return false; + } +} + +enum class PrecisionChangeType { BF16, FP16, FP32, FP64, FP80, FP128 }; + +unsigned getMPFRPrec(PrecisionChangeType type) { + switch (type) { + case PrecisionChangeType::BF16: + return 8; + case PrecisionChangeType::FP16: + return 11; + case PrecisionChangeType::FP32: + return 24; + case PrecisionChangeType::FP64: + return 53; + case PrecisionChangeType::FP80: + return 64; + case PrecisionChangeType::FP128: + return 113; + default: + llvm_unreachable("Unsupported FP precision"); + } +} + +Type *getLLVMFPType(PrecisionChangeType type, LLVMContext &context) { + switch (type) { + case PrecisionChangeType::BF16: + return Type::getBFloatTy(context); + case PrecisionChangeType::FP16: + return Type::getHalfTy(context); + case PrecisionChangeType::FP32: + return Type::getFloatTy(context); + case PrecisionChangeType::FP64: + return Type::getDoubleTy(context); + case PrecisionChangeType::FP80: + return Type::getX86_FP80Ty(context); + case PrecisionChangeType::FP128: + return Type::getFP128Ty(context); + default: + llvm_unreachable("Unsupported FP precision"); + } +} + +PrecisionChangeType getPrecisionChangeType(Type *type) { + if (type->isHalfTy()) { + return PrecisionChangeType::BF16; + } else if (type->isHalfTy()) { + return PrecisionChangeType::FP16; + } else if (type->isFloatTy()) { + return PrecisionChangeType::FP32; + } else if (type->isDoubleTy()) { + return PrecisionChangeType::FP64; + } else if (type->isX86_FP80Ty()) { + return PrecisionChangeType::FP80; + } else if (type->isFP128Ty()) { + return PrecisionChangeType::FP128; + } else { + llvm_unreachable("Unsupported FP precision"); + } +} + +StringRef getPrecisionChangeTypeString(PrecisionChangeType type) { + switch (type) { + case PrecisionChangeType::BF16: + return "BF16"; + case PrecisionChangeType::FP16: + return "FP16"; + case PrecisionChangeType::FP32: + return "FP32"; + case PrecisionChangeType::FP64: + return "FP64"; + case PrecisionChangeType::FP80: + return "FP80"; + case PrecisionChangeType::FP128: + return "FP128"; + default: + return "Unknown PT type"; + } +} + +std::string getLibmFunctionForPrecision(StringRef funcName, Type *newType) { + static const std::unordered_set libmFunctions = { + "sin", "cos", "exp", "log", "sqrt", "tan", + "cbrt", "pow", "fma", "hypot", "expm1", "log1p"}; + + std::string baseName = funcName.str(); + if (baseName.back() == 'f' || baseName.back() == 'l') { + baseName.pop_back(); + } + + if (libmFunctions.count(baseName)) { + if (newType->isFloatTy()) { + return baseName + "f"; + } else if (newType->isDoubleTy()) { + return baseName; + } else if (newType->isFP128Ty() || newType->isX86_FP80Ty()) { + return baseName + "l"; + } + } + + return ""; +} + +// Floating-Point Connected Component +struct FPCC { + SetVector inputs; + SetVector outputs; + SetVector operations; + size_t outputs_rewritten = 0; + + FPCC() = default; + explicit FPCC(SetVector inputs, SetVector outputs, + SetVector operations) + : inputs(inputs), outputs(outputs), operations(operations) {} +}; + +struct PrecisionChange { + SetVector + nodes; // Only nodes with existing `llvm::Value`s can be changed + PrecisionChangeType oldType; + PrecisionChangeType newType; + + explicit PrecisionChange(SetVector &nodes, + PrecisionChangeType oldType, + PrecisionChangeType newType) + : nodes(nodes), oldType(oldType), newType(newType) {} +}; + +void changePrecision(Instruction *I, PrecisionChange &change, + MapVector &oldToNew) { + if (!herbiable(*I)) { + llvm_unreachable("Trying to tune an instruction is not herbiable"); + } + + IRBuilder<> Builder(I); + Builder.setFastMathFlags(I->getFastMathFlags()); + Type *newType = getLLVMFPType(change.newType, I->getContext()); + Value *newI = nullptr; + + if (isa(I) || isa(I)) { + SmallVector newOps; + for (auto &operand : I->operands()) { + Value *newOp = nullptr; + if (oldToNew.count(operand)) { + newOp = oldToNew[operand]; + } else { + if (Instruction *opInst = dyn_cast(operand)) { + IRBuilder<> OpBuilder(opInst->getParent(), + ++BasicBlock::iterator(opInst)); + OpBuilder.setFastMathFlags(I->getFastMathFlags()); + newOp = OpBuilder.CreateFPCast(operand, newType, "fpopt.fpcast"); + } else if (Argument *argOp = dyn_cast(operand)) { + BasicBlock &entry = argOp->getParent()->getEntryBlock(); + IRBuilder<> OpBuilder(&*entry.getFirstInsertionPt()); + OpBuilder.setFastMathFlags(I->getFastMathFlags()); + newOp = OpBuilder.CreateFPCast(operand, newType, "fpopt.fpcast"); + } else if (Constant *constOp = dyn_cast(operand)) { + newOp = ConstantExpr::getFPCast(constOp, newType); + } else { + llvm_unreachable("Unsupported operand type"); + } + oldToNew[operand] = newOp; + } + newOps.push_back(newOp); + } + newI = Builder.CreateNAryOp(I->getOpcode(), newOps); + } else if (auto *CI = dyn_cast(I)) { + SmallVector newArgs; + for (auto &arg : CI->args()) { + Value *newArg = nullptr; + if (oldToNew.count(arg)) { + newArg = oldToNew[arg]; + } else { + if (Instruction *argInst = dyn_cast(arg)) { + IRBuilder<> ArgBuilder(argInst->getParent(), + ++BasicBlock::iterator(argInst)); + ArgBuilder.setFastMathFlags(I->getFastMathFlags()); + newArg = ArgBuilder.CreateFPCast(arg, newType, "fpopt.fpcast"); + } else if (Argument *argArg = dyn_cast(arg)) { + BasicBlock &entry = argArg->getParent()->getEntryBlock(); + IRBuilder<> ArgBuilder(&*entry.getFirstInsertionPt()); + ArgBuilder.setFastMathFlags(I->getFastMathFlags()); + newArg = ArgBuilder.CreateFPCast(arg, newType, "fpopt.fpcast"); + } else if (Constant *constArg = dyn_cast(arg)) { + newArg = ConstantExpr::getFPCast(constArg, newType); + } else { + llvm_unreachable("Unsupported argument type"); + } + oldToNew[arg] = newArg; + } + newArgs.push_back(newArg); + } + auto *calledFunc = CI->getCalledFunction(); + if (calledFunc && calledFunc->isIntrinsic()) { + Intrinsic::ID intrinsicID = calledFunc->getIntrinsicID(); + if (intrinsicID != Intrinsic::not_intrinsic) { + Function *newFunc = + Intrinsic::getDeclaration(CI->getModule(), intrinsicID, {newType}); + newI = Builder.CreateCall(newFunc, newArgs); + } else { + llvm::errs() << "PT: Unknown intrinsic: " << *CI << "\n"; + llvm_unreachable("changePrecision: Unknown intrinsic call to change"); + } + } else { + StringRef funcName = calledFunc->getName(); + std::string newFuncName = getLibmFunctionForPrecision(funcName, newType); + + if (!newFuncName.empty()) { + Module *M = CI->getModule(); + SmallVector newArgTypes(newArgs.size(), newType); + + FunctionCallee newFuncCallee = M->getOrInsertFunction( + newFuncName, FunctionType::get(newType, newArgTypes, false)); + + if (Function *newFunc = dyn_cast(newFuncCallee.getCallee())) { + newI = Builder.CreateCall(newFunc, newArgs); + } else { + llvm::errs() << "PT: Failed to get " + << getPrecisionChangeTypeString(change.newType) + << " libm function for: " << *CI << "\n"; + llvm_unreachable("changePrecision: Failed to get libm function"); + } + } else { + llvm::errs() << "PT: Unknown function call: " << *CI << "\n"; + llvm_unreachable("changePrecision: Unknown function call to change"); + } + } + + } else { + llvm_unreachable("Unknown herbiable instruction"); + } + + oldToNew[I] = newI; +} + +struct PTCandidate { + // Only one PT candidate per FPCC can be applied + SmallVector changes; + double accuracyCost; + InstructionCost CompCost; + std::string desc; + std::unordered_map perOutputAccCost; + + // TODO: + explicit PTCandidate(SmallVector changes, + const std::string &desc) + : changes(std::move(changes)), desc(desc) {} + + // If `VMap` is passed, map `llvm::Value`s in `component` to their cloned + // values and change outputs in VMap to new casted outputs. + void apply(FPCC &component, ValueToValueMapTy *VMap = nullptr) { + SetVector operations; + ValueToValueMapTy clonedToOriginal; // Maps cloned outputs to old outputs + if (VMap) { + for (auto *I : component.operations) { + assert(VMap->count(I)); + operations.insert(cast(VMap->lookup(I))); + + clonedToOriginal[VMap->lookup(I)] = I; + // llvm::errs() << "Mapping back: " << *VMap->lookup(I) << " (in " + // << cast(VMap->lookup(I)) + // ->getParent() + // ->getParent() + // ->getName() + // << ") --> " << *I << " (in " + // << I->getParent()->getParent()->getName() << ")\n"; + } + } else { + operations = component.operations; + } + + for (auto &change : changes) { + SmallPtrSet seen; + SmallVector todo; + MapVector oldToNew; + + SetVector instsToChange; + for (auto node : change.nodes) { + if (!node || !node->value) { + continue; + } + assert(isa(node->value)); + auto *I = cast(node->value); + if (VMap) { + assert(VMap->count(I)); + I = cast(VMap->lookup(I)); + } + if (!operations.contains(I)) { + // Already erased by `AO.apply()`. + continue; + } + instsToChange.insert(I); + } + + SmallVector instsToChangeSorted; + topoSort(instsToChange, instsToChangeSorted); + + for (auto *I : instsToChangeSorted) { + changePrecision(I, change, oldToNew); + } + + // Restore the precisions of the last level of instructions to be changed. + // Clean up old instructions. + for (auto &[oldV, newV] : oldToNew) { + if (!isa(oldV)) { + continue; + } + + if (!instsToChange.contains(cast(oldV))) { + continue; + } + + SmallPtrSet users; + for (auto *user : oldV->users()) { + assert( + isa(user) && + "PT: Unexpected non-instruction user of a changed instruction"); + if (!instsToChange.contains(cast(user))) { + users.insert(cast(user)); + } + } + + Value *casted = nullptr; + if (!users.empty()) { + IRBuilder<> builder(cast(oldV)->getParent(), + ++BasicBlock::iterator(cast(oldV))); + casted = builder.CreateFPCast( + newV, getLLVMFPType(change.oldType, builder.getContext())); + + if (VMap) { + assert(VMap->count(clonedToOriginal[oldV])); + (*VMap)[clonedToOriginal[oldV]] = casted; + } + } + + for (auto *user : users) { + user->replaceUsesOfWith(oldV, casted); + } + + // Assumes no external uses of the old value since all corresponding new + // values are already restored to original precision and used to replace + // uses of their old value. This is also advantageous to the solvers. + for (auto *user : oldV->users()) { + assert(instsToChange.contains(cast(user)) && + "PT: Unexpected external user of a changed instruction"); + } + + if (!oldV->use_empty()) { + oldV->replaceAllUsesWith(UndefValue::get(oldV->getType())); + } + + cast(oldV)->eraseFromParent(); + + // The change is being materialized to the original component + if (!VMap) + component.operations.remove(cast(oldV)); + } + } + } +}; + +class FPEvaluator { + std::unordered_map cache; + std::unordered_map nodePrecisions; + +public: + FPEvaluator(PTCandidate *pt = nullptr) { + if (pt) { + for (const auto &change : pt->changes) { + for (auto node : change.nodes) { + nodePrecisions[node] = change.newType; + } + } + } + } + + PrecisionChangeType getNodePrecision(const FPNode *node) const { + // If the node has a new precision from PT, use it + PrecisionChangeType precType; + + auto it = nodePrecisions.find(node); + if (it != nodePrecisions.end()) { + precType = it->second; + } else { + // Otherwise, use the node's original precision + if (node->dtype == "f32") { + precType = PrecisionChangeType::FP32; + } else if (node->dtype == "f64") { + precType = PrecisionChangeType::FP64; + } else { + llvm_unreachable("Unsupported FP node precision type"); + } + } + + if (precType != PrecisionChangeType::FP32 && + precType != PrecisionChangeType::FP64) { + llvm_unreachable("Unsupported FP precision"); + } + + return precType; + } + + void evaluateNode(const FPNode *node, + const SmallMapVector &inputValues) { + if (cache.find(node) != cache.end()) + return; + + if (isa(node)) { + double constVal = node->getLowerBound(); // TODO: Can be improved + cache.emplace(node, constVal); + return; + } + + if (isa(node) && + inputValues.count(cast(node)->value)) { + double inputValue = inputValues.lookup(cast(node)->value); + cache.emplace(node, inputValue); + return; + } + + if (node->op == "if") { + evaluateNode(node->operands[0].get(), inputValues); + double cond = getResult(node->operands[0].get()); + + if (cond == 1.0) { + evaluateNode(node->operands[1].get(), inputValues); + double then_val = getResult(node->operands[1].get()); + cache.emplace(node, then_val); + } else { + evaluateNode(node->operands[2].get(), inputValues); + double else_val = getResult(node->operands[2].get()); + cache.emplace(node, else_val); + } + return; + } + + PrecisionChangeType nodePrec = getNodePrecision(node); + + for (const auto &operand : node->operands) { + evaluateNode(operand.get(), inputValues); + } + + double res = 0.0; + + if (node->op == "neg") { + double op = getResult(node->operands[0].get()); + res = (nodePrec == PrecisionChangeType::FP32) ? -static_cast(op) + : -op; + } else if (node->op == "+") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) + static_cast(op1) + : op0 + op1; + } else if (node->op == "-") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) - static_cast(op1) + : op0 - op1; + } else if (node->op == "*") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) * static_cast(op1) + : op0 * op1; + } else if (node->op == "/") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) / static_cast(op1) + : op0 / op1; + } else if (node->op == "sin") { + double op = getResult(node->operands[0].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? std::sin(static_cast(op)) + : std::sin(op); + } else if (node->op == "cos") { + double op = getResult(node->operands[0].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? std::cos(static_cast(op)) + : std::cos(op); + } else if (node->op == "tan") { + double op = getResult(node->operands[0].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? std::tan(static_cast(op)) + : std::tan(op); + } else if (node->op == "exp") { + double op = getResult(node->operands[0].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? std::exp(static_cast(op)) + : std::exp(op); + } else if (node->op == "expm1") { + double op = getResult(node->operands[0].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? std::expm1(static_cast(op)) + : std::expm1(op); + } else if (node->op == "log") { + double op = getResult(node->operands[0].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? std::log(static_cast(op)) + : std::log(op); + } else if (node->op == "log1p") { + double op = getResult(node->operands[0].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? std::log1p(static_cast(op)) + : std::log1p(op); + } else if (node->op == "sqrt") { + double op = getResult(node->operands[0].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? std::sqrt(static_cast(op)) + : std::sqrt(op); + } else if (node->op == "cbrt") { + double op = getResult(node->operands[0].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? std::cbrt(static_cast(op)) + : std::cbrt(op); + } else if (node->op == "pow") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? std::pow(static_cast(op0), static_cast(op1)) + : std::pow(op0, op1); + } else if (node->op == "fabs") { + double op = getResult(node->operands[0].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? std::fabs(static_cast(op)) + : std::fabs(op); + } else if (node->op == "hypot") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? std::hypot(static_cast(op0), static_cast(op1)) + : std::hypot(op0, op1); + } else if (node->op == "fma") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + double op2 = getResult(node->operands[2].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? std::fma(static_cast(op0), static_cast(op1), + static_cast(op2)) + : std::fma(op0, op1, op2); + } else if (node->op == "==") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) == static_cast(op1) + : op0 == op1; + res = result ? 1.0 : 0.0; + } else if (node->op == "!=") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) != static_cast(op1) + : op0 != op1; + res = result ? 1.0 : 0.0; + } else if (node->op == "<") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) < static_cast(op1) + : op0 < op1; + res = result ? 1.0 : 0.0; + } else if (node->op == ">") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) > static_cast(op1) + : op0 > op1; + res = result ? 1.0 : 0.0; + } else if (node->op == "<=") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) <= static_cast(op1) + : op0 <= op1; + res = result ? 1.0 : 0.0; + } else if (node->op == ">=") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) >= static_cast(op1) + : op0 >= op1; + res = result ? 1.0 : 0.0; + } else if (node->op == "and") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (op0 == 1.0 && op1 == 1.0) ? 1.0 : 0.0; + } else if (node->op == "or") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (op0 == 1.0 || op1 == 1.0) ? 1.0 : 0.0; + } else if (node->op == "not") { + double op = getResult(node->operands[0].get()); + res = (op == 1.0) ? 0.0 : 1.0; + } else if (node->op == "TRUE") { + res = 1.0; + } else if (node->op == "FALSE") { + res = 0.0; + } else { + std::string msg = "FPEvaluator: Unexpected operator " + node->op; + llvm_unreachable(msg.c_str()); + } + + cache.emplace(node, res); + } + + double getResult(const FPNode *node) const { + auto it = cache.find(node); + assert(it != cache.end() && "Node not evaluated yet"); + return it->second; + } +}; + +// Emulate computation using native floating-point types +void getFPValues(ArrayRef outputs, + const SmallMapVector &inputValues, + SmallVectorImpl &results, PTCandidate *pt = nullptr) { + assert(!outputs.empty()); + results.resize(outputs.size()); + + FPEvaluator evaluator(pt); + + for (const auto *output : outputs) { + evaluator.evaluateNode(output, inputValues); + } + + for (size_t i = 0; i < outputs.size(); ++i) { + results[i] = evaluator.getResult(outputs[i]); + } +} + +class MPFREvaluator { + struct CachedValue { + mpfr_t value; + unsigned prec; + + CachedValue(unsigned prec) : prec(prec) { + mpfr_init2(value, prec); + mpfr_set_zero(value, 1); + } + + CachedValue(const CachedValue &) = delete; + CachedValue &operator=(const CachedValue &) = delete; + + CachedValue(CachedValue &&other) noexcept : prec(other.prec) { + mpfr_init2(value, other.prec); + mpfr_swap(value, other.value); + } + + CachedValue &operator=(CachedValue &&other) noexcept { + if (this != &other) { + mpfr_set_prec(value, other.prec); + prec = other.prec; + mpfr_swap(value, other.value); + } + return *this; + } + + virtual ~CachedValue() { mpfr_clear(value); } + }; + + std::unordered_map cache; + unsigned prec; // Used only for ground truth evaluation + std::unordered_map nodeToNewPrec; + +public: + MPFREvaluator(unsigned prec, PTCandidate *pt = nullptr) : prec(prec) { + if (pt) { + for (const auto &change : pt->changes) { + for (auto node : change.nodes) { + nodeToNewPrec[node] = getMPFRPrec(change.newType); + } + } + } + } + + virtual ~MPFREvaluator() = default; + + unsigned getNodePrecision(const FPNode *node, bool groundTruth) const { + // If trying to evaluate the ground truth, use the current MPFR precision + if (groundTruth) + return prec; + + // If the node has a new precision for PT, use it + auto it = nodeToNewPrec.find(node); + if (it != nodeToNewPrec.end()) { + return it->second; + } + + // Otherwise, use the original precision + return node->getMPFRPrec(); + } + + // Compute the expression with MPFR at `prec` precision + // recursively. When operand is a FPConst, use its lower + // bound. When operand is a FPLLValue, get its inputs from + // `inputs`. + void evaluateNode(const FPNode *node, + const SmallMapVector &inputValues, + bool groundTruth) { + if (isa(node)) { + if (cache.find(node) != cache.end()) + return; + + double constVal = node->getLowerBound(); // TODO: Can be improved + CachedValue cv(53); + mpfr_set_d(cv.value, constVal, MPFR_RNDN); + + cache.emplace(node, std::move(cv)); + return; + } + + if (isa(node) && + inputValues.count(cast(node)->value)) { + if (cache.find(node) != cache.end()) + return; + + double inputValue = inputValues.lookup(cast(node)->value); + // llvm::errs() << "Input value for " << *cast(node)->value + // << ": " << inputValue << "\n"; + CachedValue cv(53); + mpfr_set_d(cv.value, inputValue, MPFR_RNDN); + + cache.emplace(node, std::move(cv)); + return; + } + + // Type of results of if nodes depend on the evaluated branches + if (node->op == "if") { + if (cache.find(node) != cache.end()) + return; + + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &cond = getResult(node->operands[0].get()); + + if (0 == mpfr_cmp_ui(cond, 1)) { + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &then_val = getResult(node->operands[1].get()); + cache.emplace(node, + CachedValue(cache.at(node->operands[1].get()).prec)); + mpfr_set(cache.at(node).value, then_val, MPFR_RNDN); + } else { + evaluateNode(node->operands[2].get(), inputValues, groundTruth); + mpfr_t &else_val = getResult(node->operands[2].get()); + cache.emplace(node, + CachedValue(cache.at(node->operands[2].get()).prec)); + mpfr_set(cache.at(node).value, else_val, MPFR_RNDN); + } + return; + } + + auto it = cache.find(node); + + unsigned nodePrec = getNodePrecision(node, groundTruth); + // llvm::errs() << "Precision for " << node->op << " set to " << nodePrec + // << "\n"; + + if (it != cache.end()) { + assert(cache.at(node).prec == nodePrec && "Unexpected precision change"); + return; + } else { + // Prepare for recomputation + cache.emplace(node, CachedValue(nodePrec)); + } + + mpfr_t &res = cache.at(node).value; + + // Cast operands to dest precision first -- this agrees with our + // precision tuner behavior, i.e., fpcasting operands first + if (node->op == "neg") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_neg(res, op, MPFR_RNDN); + } else if (node->op == "+") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_add(res, op0, op1, MPFR_RNDN); + } else if (node->op == "-") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_sub(res, op0, op1, MPFR_RNDN); + } else if (node->op == "*") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_mul(res, op0, op1, MPFR_RNDN); + } else if (node->op == "/") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_div(res, op0, op1, MPFR_RNDN); + } else if (node->op == "sin") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_sin(res, op, MPFR_RNDN); + } else if (node->op == "cos") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_cos(res, op, MPFR_RNDN); + } else if (node->op == "tan") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_tan(res, op, MPFR_RNDN); + } else if (node->op == "exp") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_exp(res, op, MPFR_RNDN); + } else if (node->op == "expm1") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_expm1(res, op, MPFR_RNDN); + } else if (node->op == "log") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_log(res, op, MPFR_RNDN); + } else if (node->op == "log1p") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_log1p(res, op, MPFR_RNDN); + } else if (node->op == "sqrt") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_sqrt(res, op, MPFR_RNDN); + } else if (node->op == "cbrt") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_cbrt(res, op, MPFR_RNDN); + } else if (node->op == "pow") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_pow(res, op0, op1, MPFR_RNDN); + } else if (node->op == "fma") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + evaluateNode(node->operands[2].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_t &op2 = getResult(node->operands[2].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_prec_round(op2, nodePrec, MPFR_RNDN); + mpfr_fma(res, op0, op1, op2, MPFR_RNDN); + } else if (node->op == "fabs") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_abs(res, op, MPFR_RNDN); + } else if (node->op == "hypot") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_hypot(res, op0, op1, MPFR_RNDN); + } else if (node->op == "==") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + + if (0 == mpfr_cmp(op0, op1)) { + mpfr_set_ui(res, 1, MPFR_RNDN); + } else { + mpfr_set_ui(res, 0, MPFR_RNDN); + } + } else if (node->op == "!=") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + + if (0 != mpfr_cmp(op0, op1)) { + mpfr_set_ui(res, 1, MPFR_RNDN); + } else { + mpfr_set_ui(res, 0, MPFR_RNDN); + } + } else if (node->op == "<") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + + if (0 > mpfr_cmp(op0, op1)) { + mpfr_set_ui(res, 1, MPFR_RNDN); + } else { + mpfr_set_ui(res, 0, MPFR_RNDN); + } + } else if (node->op == ">") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + + if (0 < mpfr_cmp(op0, op1)) { + mpfr_set_ui(res, 1, MPFR_RNDN); + } else { + mpfr_set_ui(res, 0, MPFR_RNDN); + } + } else if (node->op == "<=") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + + if (0 >= mpfr_cmp(op0, op1)) { + mpfr_set_ui(res, 1, MPFR_RNDN); + } else { + mpfr_set_ui(res, 0, MPFR_RNDN); + } + } else if (node->op == ">=") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + + if (0 <= mpfr_cmp(op0, op1)) { + mpfr_set_ui(res, 1, MPFR_RNDN); + } else { + mpfr_set_ui(res, 0, MPFR_RNDN); + } + } else if (node->op == "and") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + + if (0 == mpfr_cmp_ui(op0, 1) && 0 == mpfr_cmp_ui(op1, 1)) { + mpfr_set_ui(res, 1, MPFR_RNDN); + } else { + mpfr_set_ui(res, 0, MPFR_RNDN); + } + } else if (node->op == "or") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + + if (0 == mpfr_cmp_ui(op0, 1) || 0 == mpfr_cmp_ui(op1, 1)) { + mpfr_set_ui(res, 1, MPFR_RNDN); + } else { + mpfr_set_ui(res, 0, MPFR_RNDN); + } + } else if (node->op == "not") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_set_prec(res, nodePrec); + if (0 == mpfr_cmp_ui(op, 1)) { + mpfr_set_ui(res, 0, MPFR_RNDN); + } else { + mpfr_set_ui(res, 1, MPFR_RNDN); + } + } else if (node->op == "TRUE") { + mpfr_set_ui(res, 1, MPFR_RNDN); + } else if (node->op == "FALSE") { + mpfr_set_ui(res, 0, MPFR_RNDN); + } else { + std::string msg = "MPFREvaluator: Unexpected operator " + node->op; + llvm_unreachable(msg.c_str()); + } + + // if (mpfr_nan_p(res)) { + // llvm::errs() << "WARNING MPFREvaluator: NaN detected for node " + // << node->op << "\n"; + // llvm::errs() << "Problematic operand(s):"; + // for (const auto &operand : node->operands) { + // mpfr_t &op = getResult(operand.get()); + // llvm::errs() << " " << mpfr_get_d(op, MPFR_RNDN); + // } + // llvm::errs() << "\n"; + // llvm::errs() << "Sampled input values:"; + // for (const auto &[_, input] : inputValues) { + // llvm::errs() << " " << input; + // } + // llvm::errs() << "\n"; + // } + } + + mpfr_t &getResult(FPNode *node) { + assert(cache.count(node) > 0 && + "MPFREvaluator: Unexpected unevaluated node"); + return cache.at(node).value; + } +}; + +// If looking for ground truth, compute a "correct" answer with MPFR. +// For each sampled input configuration: +// 0. Ignore `FPNode.dtype`. +// 1. Compute the expression with MPFR at `prec` precision +// by calling `MPFRValueHelper`. When operand is a FPConst, use its +// lower bound. When operand is a FPLLValue, get its inputs from +// `inputs`. +// 2. Dynamically extend precisions +// until the first `groundTruthPrec` bits of significand don't change. +// Otherwise, compute the expression with MPFR at precisions specified within +// `FPNode`s or new precisions specified by `pt`. +void getMPFRValues(ArrayRef outputs, + const SmallMapVector &inputValues, + SmallVectorImpl &results, bool groundTruth = false, + const unsigned groundTruthPrec = 53, + PTCandidate *pt = nullptr) { + assert(!outputs.empty()); + results.resize(outputs.size()); + + if (!groundTruth) { + MPFREvaluator evaluator(0, pt); + // if (pt) { + // llvm::errs() << "getMPFRValues: PT candidate detected: " << pt->desc + // << "\n"; + // } else { + // llvm::errs() << "getMPFRValues: emulating original computation\n"; + // } + + for (const auto *output : outputs) { + evaluator.evaluateNode(output, inputValues, false); + } + for (size_t i = 0; i < outputs.size(); ++i) { + results[i] = mpfr_get_d(evaluator.getResult(outputs[i]), MPFR_RNDN); + } + return; + } + + unsigned curPrec = 64; + std::vector prevResExp(outputs.size(), 0); + std::vector prevResStr(outputs.size(), nullptr); + std::vector prevResSign(outputs.size(), 0); + std::vector converged(outputs.size(), false); + size_t numConverged = 0; + + while (true) { + MPFREvaluator evaluator(curPrec, nullptr); + + // llvm::errs() << "getMPFRValues: computing ground truth with precision " + // << curPrec << "\n"; + + for (const auto &output : outputs) { + evaluator.evaluateNode(output, inputValues, true); + } + + for (size_t i = 0; i < outputs.size(); ++i) { + if (converged[i]) + continue; + + mpfr_t &res = evaluator.getResult(outputs[i]); + int resSign = mpfr_sgn(res); + mpfr_exp_t resExp; + char *resStr = + mpfr_get_str(nullptr, &resExp, 2, groundTruthPrec, res, MPFR_RNDN); + + if (prevResStr[i] != nullptr && resSign == prevResSign[i] && + resExp == prevResExp[i] && strcmp(resStr, prevResStr[i]) == 0) { + converged[i] = true; + numConverged++; + mpfr_free_str(resStr); + mpfr_free_str(prevResStr[i]); + prevResStr[i] = nullptr; + continue; + } + + if (prevResStr[i]) { + mpfr_free_str(prevResStr[i]); + } + prevResStr[i] = resStr; + prevResExp[i] = resExp; + prevResSign[i] = resSign; + } + + if (numConverged == outputs.size()) { + for (size_t i = 0; i < outputs.size(); ++i) { + results[i] = mpfr_get_d(evaluator.getResult(outputs[i]), MPFR_RNDN); + } + break; + } + + curPrec *= 2; + + if (curPrec > FPOptMaxMPFRPrec) { + llvm::errs() << "getMPFRValues: MPFR precision limit reached for some " + "outputs, returning NaN\n"; + for (size_t i = 0; i < outputs.size(); ++i) { + if (!converged[i]) { + mpfr_free_str(prevResStr[i]); + results[i] = std::numeric_limits::quiet_NaN(); + } else { + results[i] = mpfr_get_d(evaluator.getResult(outputs[i]), MPFR_RNDN); + } + } + return; + } + } +} + +void getUniqueArgs(const std::string &expr, SmallSet &args) { + // TODO: Update it if we use let expr in the future + std::regex argPattern("v\\d+"); + + std::sregex_iterator begin(expr.begin(), expr.end(), argPattern); + std::sregex_iterator end; + + while (begin != end) { + args.insert(begin->str()); + ++begin; + } +} + +void getSampledPoints( + ArrayRef inputs, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap, + SmallVector, 4> &sampledPoints) { + std::default_random_engine gen; + gen.seed(FPOptRandomSeed); + std::uniform_real_distribution<> dis; + + SmallMapVector, 4> hypercube; + for (const auto input : inputs) { + const auto node = valueToNodeMap.at(input); + + double lower = node->getLowerBound(); + double upper = node->getUpperBound(); + + hypercube.insert({input, {lower, upper}}); + } + + // llvm::errs() << "Hypercube:\n"; + // for (const auto &entry : hypercube) { + // Value *val = entry.first; + // double lower = entry.second[0]; + // double upper = entry.second[1]; + // llvm::errs() << valueToNodeMap.at(val)->symbol << ": [" << lower << ", " + // << upper << "]\n"; + // } + + // Sample `FPOptNumSamples` points from the hypercube. Store it in + // `sampledPoints`. + sampledPoints.clear(); + sampledPoints.resize(FPOptNumSamples); + for (int i = 0; i < FPOptNumSamples; ++i) { + SmallMapVector point; + for (const auto &entry : hypercube) { + Value *val = entry.first; + double lower = entry.second[0]; + double upper = entry.second[1]; + double sample = dis(gen, decltype(dis)::param_type{lower, upper}); + point.insert({val, sample}); + } + sampledPoints[i] = point; + // llvm::errs() << "Sample " << i << ":\n"; + // for (const auto &entry : point) { + // llvm::errs() << valueToNodeMap.at(entry.first)->symbol << ": " + // << entry.second << "\n"; + // } + } +} + +void getSampledPoints( + const std::string &expr, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap, + SmallVector, 4> &sampledPoints) { + SmallSet argStrSet; + getUniqueArgs(expr, argStrSet); + + SmallVector inputs; + for (const auto &argStr : argStrSet) { + inputs.push_back(symbolToValueMap.at(argStr)); + } + + getSampledPoints(inputs, valueToNodeMap, symbolToValueMap, sampledPoints); +} + +std::shared_ptr parseHerbieExpr( + const std::string &expr, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + // if (EnzymePrintFPOpt) + // llvm::errs() << "Parsing: " << expr << "\n"; + std::string trimmedExpr = expr; + trimmedExpr.erase(0, trimmedExpr.find_first_not_of(" ")); + trimmedExpr.erase(trimmedExpr.find_last_not_of(" ") + 1); + + // Arguments + if (trimmedExpr.front() != '(' && trimmedExpr.front() != '#') { + return valueToNodeMap[symbolToValueMap[trimmedExpr]]; + } + + // Constants + std::regex constantPattern( + "^#s\\(literal\\s+([-+]?\\d+(/\\d+)?|[-+]?inf\\.0)\\s+(\\w+)\\)$"); + + std::smatch matches; + if (std::regex_match(trimmedExpr, matches, constantPattern)) { + std::string value = matches[1].str(); + std::string dtype = matches[3].str(); + if (dtype == "binary64") { + dtype = "f64"; + } else if (dtype == "binary32") { + dtype = "f32"; + } else { + std::string msg = + "Herbie expr parser: Unexpected constant dtype: " + dtype; + llvm_unreachable(msg.c_str()); + } + // if (EnzymePrintFPOpt) + // llvm::errs() << "Herbie expr parser: Found __const " << value + // << " with dtype " << dtype << "\n"; + return std::make_shared(value, dtype); + } + + if (trimmedExpr.front() != '(' || trimmedExpr.back() != ')') { + llvm::errs() << "Unexpected subexpression: " << trimmedExpr << "\n"; + assert(0 && "Failed to parse Herbie expression"); + } + + trimmedExpr = trimmedExpr.substr(1, trimmedExpr.size() - 2); + + // Get the operator + auto endOp = trimmedExpr.find(' '); + std::string fullOp = trimmedExpr.substr(0, endOp); + + size_t pos = fullOp.find('.'); + + std::string dtype; + std::string op; + if (pos != std::string::npos) { + op = fullOp.substr(0, pos); + dtype = fullOp.substr(pos + 1); + assert(dtype == "f64" || dtype == "f32"); + // llvm::errs() << "Herbie expr parser: Found operator " << op + // << " with dtype " << dtype << "\n"; + } else { + op = fullOp; + // llvm::errs() << "Herbie expr parser: Found operator " << op << "\n"; + } + + auto node = std::make_shared(op, dtype); + + int depth = 0; + auto start = trimmedExpr.find_first_not_of(" ", endOp); + std::string::size_type curr; + for (curr = start; curr < trimmedExpr.size(); ++curr) { + if (trimmedExpr[curr] == '(') + depth++; + if (trimmedExpr[curr] == ')') + depth--; + if (depth == 0 && trimmedExpr[curr] == ' ') { + node->addOperand(parseHerbieExpr(trimmedExpr.substr(start, curr - start), + valueToNodeMap, symbolToValueMap)); + start = curr + 1; + } + } + if (start < curr) { + node->addOperand(parseHerbieExpr(trimmedExpr.substr(start, curr - start), + valueToNodeMap, symbolToValueMap)); + } + + return node; +} + +TargetTransformInfo::OperandValueKind getOperandValueKind(const Value *V) { + if (isa(V)) { + assert(!isa(V)); + return TargetTransformInfo::OK_UniformConstantValue; + } + return TargetTransformInfo::OK_AnyValue; +} + +TargetTransformInfo::OperandValueProperties +getOperandValueProperties(const Value *V) { + // TODO: Power of 2? + return TargetTransformInfo::OP_None; +} + +InstructionCost getInstructionCompCost(const Instruction *I, + const TargetTransformInfo &TTI) { + if (!FPOptCostModelPath.empty()) { + static std::map, InstructionCost> + CostModel; + static bool Loaded = false; + + if (!Loaded) { + std::ifstream CostFile(FPOptCostModelPath); + if (!CostFile.is_open()) { + std::string msg = + "Cost model file could not be opened: " + FPOptCostModelPath; + llvm_unreachable(msg.c_str()); + } + + std::string Line; + while (std::getline(CostFile, Line)) { + std::istringstream SS(Line); + std::string OpcodeStr, PrecisionStr; + std::string CostStr; + + if (!std::getline(SS, OpcodeStr, ',')) { + std::string msg = "Unexpected line in custom cost model: " + Line; + llvm_unreachable(msg.c_str()); + } + if (!std::getline(SS, PrecisionStr, ',')) { + std::string msg = "Unexpected line in custom cost model: " + Line; + llvm_unreachable(msg.c_str()); + } + if (!std::getline(SS, CostStr)) { + std::string msg = "Unexpected line in custom cost model: " + Line; + llvm_unreachable(msg.c_str()); + } + // llvm::errs() << "Cost model: " << OpcodeStr << ", " << PrecisionStr + // << ", " << CostStr << "\n"; + + CostModel[{OpcodeStr, PrecisionStr}] = std::stoi(CostStr); + } + + Loaded = true; + } + + std::string OpcodeName; + switch (I->getOpcode()) { + case Instruction::FNeg: + OpcodeName = "fneg"; + break; + case Instruction::FAdd: + OpcodeName = "fadd"; + break; + case Instruction::FSub: + OpcodeName = "fsub"; + break; + case Instruction::FMul: + OpcodeName = "fmul"; + break; + case Instruction::FDiv: + OpcodeName = "fdiv"; + break; + case Instruction::FCmp: + OpcodeName = "fcmp"; + break; + case Instruction::FPExt: + OpcodeName = "fpext"; + break; + case Instruction::FPTrunc: + OpcodeName = "fptrunc"; + break; + case Instruction::PHI: + return 0; + case Instruction::Call: { + auto *Call = cast(I); + if (auto CalledFunc = Call->getCalledFunction()) { + if (CalledFunc->isIntrinsic()) { + switch (CalledFunc->getIntrinsicID()) { + case Intrinsic::sin: + OpcodeName = "sin"; + break; + case Intrinsic::cos: + OpcodeName = "cos"; + break; + case Intrinsic::exp: + OpcodeName = "exp"; + break; + case Intrinsic::log: + OpcodeName = "log"; + break; + case Intrinsic::sqrt: + OpcodeName = "sqrt"; + break; + case Intrinsic::fabs: + OpcodeName = "fabs"; + break; + case Intrinsic::fma: + OpcodeName = "fma"; + break; + case Intrinsic::pow: + OpcodeName = "pow"; + break; + default: { + std::string msg = "Custom cost model: unsupported intrinsic " + + CalledFunc->getName().str(); + llvm_unreachable(msg.c_str()); + } + } + } else { + std::string FuncName = CalledFunc->getName().str(); + if (FuncName.back() == 'f' || FuncName.back() == 'l') { + FuncName.pop_back(); + } + + if (FuncName == "sin") { + OpcodeName = "sin"; + } else if (FuncName == "cos") { + OpcodeName = "cos"; + } else if (FuncName == "tan") { + OpcodeName = "tan"; + } else if (FuncName == "exp") { + OpcodeName = "exp"; + } else if (FuncName == "log") { + OpcodeName = "log"; + } else if (FuncName == "sqrt") { + OpcodeName = "sqrt"; + } else if (FuncName == "expm1") { + OpcodeName = "expm1"; + } else if (FuncName == "log1p") { + OpcodeName = "log1p"; + } else if (FuncName == "cbrt") { + OpcodeName = "cbrt"; + } else if (FuncName == "pow") { + OpcodeName = "pow"; + } else if (FuncName == "fabs") { + OpcodeName = "fabs"; + } else if (FuncName == "fma") { + OpcodeName = "fma"; + } else if (FuncName == "hypot") { + OpcodeName = "hypot"; + } else { + std::string msg = + "Custom cost model: unknown function call " + FuncName; + llvm_unreachable(msg.c_str()); + } + } + } else { + llvm_unreachable("Custom cost model: unknown function call"); + } + break; + } + default: + std::string msg = "Custom cost model: unexpected opcode " + + std::string(I->getOpcodeName()); + llvm_unreachable(msg.c_str()); + } + + std::string PrecisionName; + Type *Ty = I->getType(); + if (I->getOpcode() == Instruction::FCmp) { + Ty = I->getOperand(0)->getType(); + } + + if (Ty->isBFloatTy()) { + PrecisionName = "bf16"; + } else if (Ty->isHalfTy()) { + PrecisionName = "half"; + } else if (Ty->isFloatTy()) { + PrecisionName = "float"; + } else if (Ty->isDoubleTy()) { + PrecisionName = "double"; + } else if (Ty->isX86_FP80Ty()) { + PrecisionName = "fp80"; + } else if (Ty->isFP128Ty()) { + PrecisionName = "fp128"; + } else { + std::string msg = "Custom cost model: unsupported precision type!"; + llvm_unreachable(msg.c_str()); + } + + if (I->getOpcode() == Instruction::FPExt || + I->getOpcode() == Instruction::FPTrunc) { + Type *SrcTy = I->getOperand(0)->getType(); + std::string SrcPrecisionName; + + if (SrcTy->isBFloatTy()) { + SrcPrecisionName = "bf16"; + } else if (SrcTy->isHalfTy()) { + SrcPrecisionName = "half"; + } else if (SrcTy->isFloatTy()) { + SrcPrecisionName = "float"; + } else if (SrcTy->isDoubleTy()) { + SrcPrecisionName = "double"; + } else if (SrcTy->isX86_FP80Ty()) { + SrcPrecisionName = "fp80"; + } else if (SrcTy->isFP128Ty()) { + SrcPrecisionName = "fp128"; + } else { + std::string msg = "Custom cost model: unsupported precision type!"; + llvm_unreachable(msg.c_str()); + } + + OpcodeName += "_" + SrcPrecisionName + "_to_" + PrecisionName; + PrecisionName = SrcPrecisionName; + } + + auto Key = std::make_pair(OpcodeName, PrecisionName); + auto It = CostModel.find(Key); + if (It != CostModel.end()) { + return It->second; + } + + std::string msg = "Custom cost model: entry not found for " + OpcodeName + + " @ " + PrecisionName; + llvm::errs() << "Unexpected Intruction: " << *I << "\n"; + llvm_unreachable(msg.c_str()); + } + + llvm::errs() + << "IMPORTANT: Custom cost model not provided, using default cost!\n"; + + unsigned Opcode = I->getOpcode(); + switch (Opcode) { + case Instruction::FNeg: { + SmallVector Args(I->operands()); + return TTI.getArithmeticInstrCost( + Opcode, I->getType(), TargetTransformInfo::TCK_Latency, + getOperandValueKind(I->getOperand(0)), TargetTransformInfo::OK_AnyValue, + getOperandValueProperties(I->getOperand(0)), + TargetTransformInfo::OP_None, Args, I); + } + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: + case Instruction::FDiv: { + SmallVector Args(I->operands()); + return TTI.getArithmeticInstrCost( + Opcode, I->getType(), TargetTransformInfo::TCK_Latency, + getOperandValueKind(I->getOperand(0)), + getOperandValueKind(I->getOperand(1)), + getOperandValueProperties(I->getOperand(0)), + getOperandValueProperties(I->getOperand(1)), Args, I); + } + case Instruction::FCmp: { + const auto *FCI = cast(I); + return TTI.getCmpSelInstrCost(Opcode, FCI->getType(), /* CondTy */ nullptr, + FCI->getPredicate(), + TargetTransformInfo::TCK_Latency, I); + } + case Instruction::PHI: { + return TTI.getInstructionCost(I, TargetTransformInfo::TCK_Latency); + } + default: { + if (const auto *Call = dyn_cast(I)) { + if (Function *CalledFunc = Call->getCalledFunction()) { + if (CalledFunc->isIntrinsic()) { + auto IID = CalledFunc->getIntrinsicID(); + SmallVector OperandTypes; + SmallVector Args; + for (auto &Arg : Call->args()) { + OperandTypes.push_back(Arg->getType()); + Args.push_back(Arg.get()); + } + + IntrinsicCostAttributes ICA(IID, Call->getType(), Args, OperandTypes, + Call->getFastMathFlags(), + cast(I)); + return TTI.getIntrinsicInstrCost(ICA, + TargetTransformInfo::TCK_Latency); + } else { + SmallVector ArgTypes; + for (auto &Arg : Call->args()) + ArgTypes.push_back(Arg->getType()); + + return TTI.getCallInstrCost(CalledFunc, Call->getType(), ArgTypes, + TargetTransformInfo::TCK_Latency); + } + } + } + llvm::errs() << "WARNING: Using default cost for " << *I << "\n"; + return TTI.getInstructionCost(I, TargetTransformInfo::TCK_Latency); + } + } +} + +InstructionCost computeMaxCost( + BasicBlock *BB, std::unordered_map &MaxCost, + std::unordered_set &Visited, const TargetTransformInfo &TTI) { + if (MaxCost.find(BB) != MaxCost.end()) + return MaxCost[BB]; + + if (!Visited.insert(BB).second) + return 0; + + InstructionCost BBCost = 0; + for (const Instruction &I : *BB) { + if (I.isTerminator()) + continue; + + auto instCost = getInstructionCompCost(&I, TTI); + + // if (EnzymePrintFPOpt) + // llvm::errs() << "Cost of " << I << " is: " << instCost << "\n"; + + BBCost += instCost; + } + + InstructionCost succCost = 0; + + if (!succ_empty(BB)) { + InstructionCost maxSuccCost = 0; + for (BasicBlock *Succ : successors(BB)) { + InstructionCost succBBCost = computeMaxCost(Succ, MaxCost, Visited, TTI); + if (succBBCost > maxSuccCost) + maxSuccCost = succBBCost; + } + // llvm::errs() << "Max succ cost: " << maxSuccCost << "\n"; + succCost = maxSuccCost; + } + + InstructionCost totalCost = BBCost + succCost; + // llvm::errs() << "BB " << BB->getName() << " cost: " << totalCost << "\n"; + MaxCost[BB] = totalCost; + Visited.erase(BB); + return totalCost; +} + +InstructionCost getCompCost(Function *F, const TargetTransformInfo &TTI) { + std::unordered_map MaxCost; + std::unordered_set Visited; + + BasicBlock *EntryBB = &F->getEntryBlock(); + InstructionCost TotalCost = computeMaxCost(EntryBB, MaxCost, Visited, TTI); + // llvm::errs() << "Total cost: " << TotalCost << "\n"; + return TotalCost; +} + +// Sum up the cost of `output` and its FP operands recursively up to `inputs` +// (exclusive). +InstructionCost getCompCost(const SmallVector &outputs, + const SetVector &inputs, + const TargetTransformInfo &TTI) { + assert(!outputs.empty()); + SmallPtrSet seen; + SmallVector todo; + InstructionCost cost = 0; + + todo.insert(todo.end(), outputs.begin(), outputs.end()); + while (!todo.empty()) { + auto cur = todo.pop_back_val(); + if (!seen.insert(cur).second) + continue; + + if (inputs.contains(cur)) + continue; + + if (auto *I = dyn_cast(cur)) { + // TODO: unfair to ignore branches when calculating cost + auto instCost = getInstructionCompCost(I, TTI); + + // if (EnzymePrintFPOpt) + // llvm::errs() << "Cost of " << *I << " is: " << instCost << "\n"; + + // Only add the cost of the instruction if it is not an input + cost += instCost; + + auto operands = + isa(I) ? cast(I)->args() : I->operands(); + for (auto &operand : operands) { + todo.push_back(operand); + } + } + } + + return cost; +} + +InstructionCost getCompCost( + const std::string &expr, Module *M, const TargetTransformInfo &TTI, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + const FastMathFlags &FMF) { + // llvm::errs() << "Evaluating cost of " << expr << "\n"; + SmallSet argStrSet; + getUniqueArgs(expr, argStrSet); + + SetVector args; + SmallVector argTypes; + SmallVector argNames; + for (const auto &argStr : argStrSet) { + Value *argValue = symbolToValueMap[argStr]; + args.insert(argValue); + argTypes.push_back(argValue->getType()); + argNames.push_back(argStr); + } + + auto parsedNode = parseHerbieExpr(expr, valueToNodeMap, symbolToValueMap); + + // Materialize the expression in a temporary function + FunctionType *FT = + FunctionType::get(Type::getVoidTy(M->getContext()), argTypes, false); + Function *tempFunction = + Function::Create(FT, Function::InternalLinkage, "tempFunc", M); + + ValueToValueMapTy VMap; + Function::arg_iterator AI = tempFunction->arg_begin(); + for (const auto &argStr : argNames) { + VMap[symbolToValueMap[argStr]] = &*AI; + ++AI; + } + + BasicBlock *entry = + BasicBlock::Create(M->getContext(), "entry", tempFunction); + Instruction *ReturnInst = ReturnInst::Create(M->getContext(), entry); + + IRBuilder<> builder(ReturnInst); + + builder.setFastMathFlags(FMF); + parsedNode->getLLValue(builder, &VMap); + + InstructionCost cost = getCompCost(tempFunction, TTI); + + tempFunction->eraseFromParent(); + return cost; +} + +InstructionCost getCompCost(FPCC &component, const TargetTransformInfo &TTI, + PTCandidate &pt) { + assert(!component.outputs.empty()); + + InstructionCost cost = 0; + + Function *F = cast(component.outputs[0])->getFunction(); + + ValueToValueMapTy VMap; + Function *FClone = CloneFunction(F, VMap); + FClone->setName(F->getName() + "_clone"); + + pt.apply(component, &VMap); + // output values in VMap are changed to the new casted values + // llvm::errs() << "\nDEBUG: " << pt.desc << "\n"; + // FClone->print(llvm::errs()); + + SmallPtrSet clonedInputs; + for (auto &input : component.inputs) { + clonedInputs.insert(VMap[input]); + } + + SmallPtrSet clonedOutputs; + for (auto &output : component.outputs) { + clonedOutputs.insert(VMap[output]); + } + + SmallPtrSet seen; + SmallVector todo; + + todo.insert(todo.end(), clonedOutputs.begin(), clonedOutputs.end()); + while (!todo.empty()) { + auto cur = todo.pop_back_val(); + if (!seen.insert(cur).second) + continue; + + if (clonedInputs.contains(cur)) + continue; + + if (auto *I = dyn_cast(cur)) { + auto instCost = getInstructionCompCost(I, TTI); + // llvm::errs() << "Cost of " << *I << " is: " << instCost << "\n"; + + cost += instCost; + + auto operands = + isa(I) ? cast(I)->args() : I->operands(); + for (auto &operand : operands) { + todo.push_back(operand); + } + } + } + + FClone->eraseFromParent(); + + return cost; +} + +struct RewriteCandidate { + // Only one rewrite candidate per output `llvm::Value` can be applied + InstructionCost CompCost; + double herbieCost; // Unused for now + double herbieAccuracy; + double accuracyCost; + std::string expr; + + RewriteCandidate(double cost, double accuracy, std::string expression) + : herbieCost(cost), herbieAccuracy(accuracy), expr(expression) {} +}; + +void splitFPCC(FPCC &CC, SmallVector &newCCs) { + std::unordered_map shortestDistances; + + for (auto &op : CC.operations) { + shortestDistances[op] = std::numeric_limits::max(); + } + + // find the shortest distance from inputs to each operation + for (auto &input : CC.inputs) { + SmallVector, 8> todo; + for (auto user : input->users()) { + if (auto *I = dyn_cast(user); I && CC.operations.count(I)) { + todo.emplace_back(I, 1); + } + } + + while (!todo.empty()) { + auto [cur, dist] = todo.pop_back_val(); + if (dist < shortestDistances[cur]) { + shortestDistances[cur] = dist; + for (auto user : cur->users()) { + if (auto *I = dyn_cast(user); + I && CC.operations.count(I)) { + todo.emplace_back(I, dist + 1); + } + } + } + } + } + + // llvm::errs() << "Shortest distances:\n"; + // for (auto &[op, dist] : shortestDistances) { + // llvm::errs() << *op << ": " << dist << "\n"; + // } + + int maxDepth = + std::max_element(shortestDistances.begin(), shortestDistances.end(), + [](const auto &lhs, const auto &rhs) { + return lhs.second < rhs.second; + }) + ->second; + + if (maxDepth <= FPOptMaxFPCCDepth) { + newCCs.push_back(CC); + return; + } + + newCCs.resize(maxDepth / FPOptMaxFPCCDepth + 1); + + // Split `operations` based on the shortest distance + for (const auto &[op, dist] : shortestDistances) { + newCCs[dist / FPOptMaxFPCCDepth].operations.insert(op); + } + + // Reconstruct `inputs` and `outputs` for new components + for (auto &newCC : newCCs) { + for (auto &op : newCC.operations) { + auto operands = + isa(op) ? cast(op)->args() : op->operands(); + for (auto &operand : operands) { + if (newCC.inputs.count(operand)) { + continue; + } + + // Original non-herbiable operands or herbiable intermediate operations + if (CC.inputs.count(operand) || + !newCC.operations.count(cast(operand))) { + newCC.inputs.insert(operand); + } + } + + for (auto user : op->users()) { + if (auto *I = dyn_cast(user); + I && !newCC.operations.count(I)) { + newCC.outputs.insert(op); + } + } + } + } + + if (EnzymePrintFPOpt) { + llvm::errs() << "Splitting the FPCC into " << newCCs.size() + << " components\n"; + } +} + +void collectExprInsts(Value *V, const SetVector &inputs, + SmallPtrSetImpl &exprInsts, + SmallPtrSetImpl &visited) { + if (!V || inputs.contains(V) || visited.contains(V)) { + return; + } + + visited.insert(V); + + if (auto *I = dyn_cast(V)) { + exprInsts.insert(I); + + auto operands = + isa(I) ? cast(I)->args() : I->operands(); + + for (auto &op : operands) { + collectExprInsts(op, inputs, exprInsts, visited); + } + } +} + +class ApplicableOutput; +class ApplicableFPCC; + +struct SolutionStep { + std::variant item; + size_t candidateIndex; + + SolutionStep(ApplicableOutput *ao_, size_t idx) + : item(ao_), candidateIndex(idx) {} + + SolutionStep(ApplicableFPCC *acc_, size_t idx) + : item(acc_), candidateIndex(idx) {} +}; + +class ApplicableOutput { +public: + FPCC *component; + Value *oldOutput; + std::string expr; + double grad; + unsigned executions; + const TargetTransformInfo *TTI; + double initialAccCost; // Requires manual initialization + InstructionCost initialCompCost; // Requires manual initialization + double initialHerbieCost; // Requires manual initialization + double initialHerbieAccuracy; // Requires manual initialization + SmallVector candidates; + SmallPtrSet erasableInsts; + + explicit ApplicableOutput(FPCC &component, Value *oldOutput, std::string expr, + double grad, unsigned executions, + const TargetTransformInfo &TTI) + : component(&component), oldOutput(oldOutput), expr(expr), grad(grad), + executions(executions), TTI(&TTI) { + initialCompCost = getCompCost({oldOutput}, component.inputs, TTI); + findErasableInstructions(); + } + + void + apply(size_t candidateIndex, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + // 4) parse the output string solution from herbieland + // 5) convert into a solution in llvm vals/instructions + + // if (EnzymePrintFPOpt) + // llvm::errs() << "Parsing Herbie output: " << herbieOutput << "\n"; + auto parsedNode = parseHerbieExpr(candidates[candidateIndex].expr, + valueToNodeMap, symbolToValueMap); + // if (EnzymePrintFPOpt) + // llvm::errs() << "Parsed Herbie output: " + // << parsedNode->toFullExpression(valueToNodeMap) << "\n"; + + IRBuilder<> builder(cast(oldOutput)->getParent(), + ++BasicBlock::iterator(cast(oldOutput))); + builder.setFastMathFlags(cast(oldOutput)->getFastMathFlags()); + + // auto *F = cast(oldOutput)->getParent()->getParent(); + // llvm::errs() << "Before: " << *F << "\n"; + Value *newOutput = parsedNode->getLLValue(builder); + assert(newOutput && "Failed to get value from parsed node"); + + oldOutput->replaceAllUsesWith(newOutput); + symbolToValueMap[valueToNodeMap[oldOutput]->symbol] = newOutput; + valueToNodeMap[newOutput] = std::make_shared( + newOutput, "__no", valueToNodeMap[oldOutput]->dtype); + + for (auto *I : erasableInsts) { + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + I->eraseFromParent(); + component->operations.remove(I); // Avoid a second removal + cast(valueToNodeMap[I].get())->value = nullptr; + } + + // llvm::errs() << "After: " << *F << "\n"; + + component->outputs_rewritten++; + } + + // Lower is better + InstructionCost getCompCostDelta(size_t candidateIndex) { + InstructionCost erasableCost = 0; + + for (auto *I : erasableInsts) { + erasableCost += getInstructionCompCost(I, *TTI); + } + + return (candidates[candidateIndex].CompCost - erasableCost) * executions; + } + + // Lower is better + double getAccCostDelta(size_t candidateIndex) { + return candidates[candidateIndex].accuracyCost - initialAccCost; + } + + void findErasableInstructions() { + SmallPtrSet visited; + SmallPtrSet exprInsts; + collectExprInsts(oldOutput, component->inputs, exprInsts, visited); + visited.clear(); + + SetVector instsToProcess(exprInsts.begin(), exprInsts.end()); + + SmallVector instsToProcessSorted; + topoSort(instsToProcess, instsToProcessSorted); + + // `oldOutput` is trivially erasable + erasableInsts.clear(); + erasableInsts.insert(cast(oldOutput)); + + for (auto *I : reverse(instsToProcessSorted)) { + if (erasableInsts.contains(I)) + continue; + + bool usedOutside = false; + for (auto user : I->users()) { + if (auto *userI = dyn_cast(user)) { + if (erasableInsts.contains(userI)) { + continue; + } + } + // If the user is not an intruction or the user instruction is not an + // erasable instruction, then the current instruction is not erasable + // llvm::errs() << "Can't erase " << *I << " because of " << *user << + // "\n"; + usedOutside = true; + break; + } + + if (!usedOutside) { + erasableInsts.insert(I); + } + } + + // llvm::errs() << "Erasable instructions:\n"; + // for (auto *I : erasableInsts) { + // llvm::errs() << *I << "\n"; + // } + // llvm::errs() << "End of erasable instructions\n"; + } +}; + +void setUnifiedAccuracyCost( + ApplicableFPCC &ACC, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap); + +class ApplicableFPCC { +public: + FPCC *component; + const TargetTransformInfo &TTI; + double initialAccCost; // Requires manual initialization + InstructionCost initialCompCost; + unsigned executions; // Requires manual initialization + std::unordered_map perOutputInitialAccCost; + + SmallVector candidates; + + // Caches for adjusted cost calculations + using ApplicableOutputSet = std::set; + struct CacheKey { + size_t candidateIndex; + ApplicableOutputSet applicableOutputs; + + bool operator==(const CacheKey &other) const { + return candidateIndex == other.candidateIndex && + applicableOutputs == other.applicableOutputs; + } + }; + + struct CacheKeyHash { + std::size_t operator()(const CacheKey &key) const { + std::size_t seed = std::hash{}(key.candidateIndex); + for (const auto *ao : key.applicableOutputs) { + seed ^= std::hash{}(ao) + 0x9e3779b9 + + (seed << 6) + (seed >> 2); + } + return seed; + } + }; + + std::unordered_map + compCostDeltaCache; + std::unordered_map accCostDeltaCache; + + explicit ApplicableFPCC(FPCC &fpcc, const TargetTransformInfo &TTI) + : component(&fpcc), TTI(TTI) { + initialCompCost = + getCompCost({component->outputs.begin(), component->outputs.end()}, + component->inputs, TTI); + } + + void apply(size_t candidateIndex) { + if (candidateIndex >= candidates.size()) { + llvm_unreachable("Invalid candidate index"); + } + + // Traverse all the instructions to be changed precisions in a + // topological order with respect to operand dependencies. Insert FP casts + // between llvm::Value inputs and first level of instructions to be changed. + // Restore precisions of the last level of instructions to be changed. + candidates[candidateIndex].apply(*component); + } + + // Lower is better + InstructionCost getCompCostDelta(size_t candidateIndex) { + // TODO: adjust this based on erasured instructions + return (candidates[candidateIndex].CompCost - initialCompCost) * executions; + } + + // Lower is better + double getAccCostDelta(size_t candidateIndex) { + return candidates[candidateIndex].accuracyCost - initialAccCost; + } + + InstructionCost + getAdjustedCompCostDelta(size_t candidateIndex, + const SmallVectorImpl &steps) { + ApplicableOutputSet applicableOutputs; + for (const auto &step : steps) { + if (auto *ptr = std::get_if(&step.item)) { + if ((*ptr)->component == component) { + applicableOutputs.insert(*ptr); + } + } + } + + CacheKey key{candidateIndex, applicableOutputs}; + + auto cacheIt = compCostDeltaCache.find(key); + if (cacheIt != compCostDeltaCache.end()) { + return cacheIt->second; + } + + FPCC newComponent = *this->component; + + for (auto &step : steps) { + if (auto *ptr = std::get_if(&step.item)) { + const auto &AO = **ptr; + if (AO.component == component) { + // Eliminate erasadable instructions from the adjusted ACC + newComponent.operations.remove_if( + [&AO](Instruction *I) { return AO.erasableInsts.contains(I); }); + newComponent.outputs.remove(cast(AO.oldOutput)); + } + } + } + + // If all outputs are rewritten, then the adjusted ACC is empty + if (newComponent.outputs.empty()) { + compCostDeltaCache[key] = 0; + return 0; + } + + InstructionCost initialCompCost = + getCompCost({newComponent.outputs.begin(), newComponent.outputs.end()}, + newComponent.inputs, TTI); + + InstructionCost candidateCompCost = + getCompCost(newComponent, TTI, candidates[candidateIndex]); + + InstructionCost adjustedCostDelta = + (candidateCompCost - initialCompCost) * executions; + // llvm::errs() << "Initial cost: " << initialCompCost << "\n"; + // llvm::errs() << "Candidate cost: " << candidateCompCost << "\n"; + // llvm::errs() << "Num executions: " << executions << "\n"; + // llvm::errs() << "Adjusted cost delta: " << adjustedCostDelta << "\n\n"; + + compCostDeltaCache[key] = adjustedCostDelta; + return adjustedCostDelta; + } + + double getAdjustedAccCostDelta( + size_t candidateIndex, SmallVectorImpl &steps, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + ApplicableOutputSet applicableOutputs; + for (const auto &step : steps) { + if (auto *ptr = std::get_if(&step.item)) { + if ((*ptr)->component == component) { + applicableOutputs.insert(*ptr); + } + } + } + + CacheKey key{candidateIndex, applicableOutputs}; + + auto cacheIt = accCostDeltaCache.find(key); + if (cacheIt != accCostDeltaCache.end()) { + return cacheIt->second; + } + + double totalCandidateAccCost = 0.0; + double totalInitialAccCost = 0.0; + + // Collect erased output nodes + SmallPtrSet stepNodes; + for (const auto &step : steps) { + if (auto *ptr = std::get_if(&step.item)) { + const auto &AO = **ptr; + if (AO.component == component) { + auto it = valueToNodeMap.find(AO.oldOutput); + assert(it != valueToNodeMap.end() && it->second); + stepNodes.insert(it->second.get()); + } + } + } + + // Iterate over all output nodes and sum costs for nodes not erased + for (auto &[node, cost] : perOutputInitialAccCost) { + if (!stepNodes.count(node)) { + totalInitialAccCost += cost; + } + } + + for (auto &[node, cost] : candidates[candidateIndex].perOutputAccCost) { + if (!stepNodes.count(node)) { + totalCandidateAccCost += cost; + } + } + + double adjustedAccCostDelta = totalCandidateAccCost - totalInitialAccCost; + + accCostDeltaCache[key] = adjustedAccCostDelta; + return adjustedAccCostDelta; + } +}; + +void setUnifiedAccuracyCost( + ApplicableOutput &AO, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + + SmallVector, 4> sampledPoints; + getSampledPoints(AO.component->inputs.getArrayRef(), valueToNodeMap, + symbolToValueMap, sampledPoints); + + SmallVector goldVals; + goldVals.resize(FPOptNumSamples); + double initAC = 0.; + + unsigned numValidSamples = 0; + for (const auto &pair : enumerate(sampledPoints)) { + ArrayRef outputs = {valueToNodeMap[AO.oldOutput].get()}; + SmallVector results; + getMPFRValues(outputs, pair.value(), results, true, 53); + double goldVal = results[0]; + // llvm::errs() << "DEBUG AO gold value: " << goldVal << "\n"; + goldVals[pair.index()] = goldVal; + + getFPValues(outputs, pair.value(), results); + double realVal = results[0]; + // llvm::errs() << "DEBUG AO real value: " << realVal << "\n"; + + if (!std::isnan(goldVal) && !std::isnan(realVal)) { + initAC += std::log1p(std::fabs(goldVal - realVal)); + numValidSamples++; + } + } + + AO.initialAccCost = std::expm1(initAC / numValidSamples) * std::fabs(AO.grad); + // llvm::errs() << "DEBUG calculated AO initial accuracy cost: " + // << AO.initialAccCost << "\n"; + assert(numValidSamples && "No valid samples for AO -- try increasing the " + "number of samples"); + assert(!std::isnan(AO.initialAccCost)); + + for (auto &candidate : AO.candidates) { + const auto &expr = candidate.expr; + auto parsedNode = parseHerbieExpr(expr, valueToNodeMap, symbolToValueMap); + double ac = 0.; + + numValidSamples = 0; + for (const auto &pair : enumerate(sampledPoints)) { + // Compute the "gold" value & real value for each sampled point + // Compute an average of (difference * gradient) + // TODO: Consider geometric average??? + assert(valueToNodeMap.count(AO.oldOutput)); + + // llvm::errs() << "Computing real output for candidate: " << expr << + // "\n"; + + // llvm::errs() << "Current input values:\n"; + // for (const auto &entry : pair.value()) { + // llvm::errs() << valueToNodeMap[entry.first]->symbol << ": " + // << entry.second << "\n"; + // } + + // llvm::errs() << "Gold value: " << goldVals[pair.index()] << "\n"; + + ArrayRef outputs = {parsedNode.get()}; + SmallVector results; + getFPValues(outputs, pair.value(), results); + double realVal = results[0]; + + // llvm::errs() << "Real value: " << realVal << "\n"; + double goldVal = goldVals[pair.index()]; + if (!std::isnan(goldVal) && !std::isnan(realVal)) { + ac += std::log1p(std::fabs(goldVal - realVal)); + numValidSamples++; + } + } + assert(numValidSamples && "No valid samples for AO -- try increasing the " + "number of samples"); + candidate.accuracyCost = + std::expm1(ac / numValidSamples) * std::fabs(AO.grad); + assert(!std::isnan(candidate.accuracyCost)); + } +} + +void setUnifiedAccuracyCost( + ApplicableFPCC &ACC, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + + SmallVector, 4> sampledPoints; + getSampledPoints(ACC.component->inputs.getArrayRef(), valueToNodeMap, + symbolToValueMap, sampledPoints); + + SmallMapVector, 4> + goldVals; // output -> gold vals + for (auto *output : ACC.component->outputs) { + auto *node = valueToNodeMap[output].get(); + goldVals[node].resize(FPOptNumSamples); + ACC.perOutputInitialAccCost[node] = 0.; + } + + SmallVector outputs; + for (auto *output : ACC.component->outputs) { + outputs.push_back(valueToNodeMap[output].get()); + } + + std::unordered_map numValidSamplesPerOutput; + for (auto *output : outputs) { + numValidSamplesPerOutput[output] = 0; + } + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + + // Get ground truth values for all outputs + getMPFRValues(outputs, pair.value(), results, true, 53); + for (const auto &[output, result] : zip(outputs, results)) { + goldVals[output][pair.index()] = result; + // llvm::errs() << "DEBUG ACC gold value: " << result << "\n"; + } + + // Emulate FPCC with parsed precision + getFPValues(outputs, pair.value(), results); + + for (const auto &[output, result] : zip(outputs, results)) { + // llvm::errs() << "DEBUG ACC real value: " << result << "\n"; + double goldVal = goldVals[output][pair.index()]; + if (!std::isnan(goldVal) && !std::isnan(result)) { + double diff = std::fabs(goldVal - result); + ACC.perOutputInitialAccCost[output] += std::log1p(diff); + numValidSamplesPerOutput[output]++; + } + } + } + + // Normalize accuracy costs and compute aggregated initialAccCost + ACC.initialAccCost = 0.0; + for (auto *output : outputs) { + unsigned numValidSamples = numValidSamplesPerOutput[output]; + assert(numValidSamples && "No valid samples for at least one output node " + "-- try increasing the number of samples"); + // Local error --> global error + ACC.perOutputInitialAccCost[output] = + std::expm1(ACC.perOutputInitialAccCost[output] / numValidSamples) * + std::fabs(output->grad); + // llvm::errs() << "DEBUG calculated ACC per output initial accuracy cost: " + // << ACC.perOutputInitialAccCost[output] << "\n"; + ACC.initialAccCost += ACC.perOutputInitialAccCost[output]; + } + assert(!std::isnan(ACC.initialAccCost)); + + // Compute accuracy costs for each PT candidate + for (auto &candidate : ACC.candidates) { + std::unordered_map numValidSamplesPerOutput; + for (auto *output : outputs) { + candidate.perOutputAccCost[output] = 0.; + numValidSamplesPerOutput[output] = 0; + } + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getFPValues(outputs, pair.value(), results, &candidate); + + for (const auto &[output, result] : zip(outputs, results)) { + double goldVal = goldVals[output][pair.index()]; + if (!std::isnan(goldVal) && !std::isnan(result)) { + double diff = std::fabs(goldVal - result); + // Sum up local errors + candidate.perOutputAccCost[output] += std::log1p(diff); + numValidSamplesPerOutput[output]++; + } + } + } + + // Normalize accuracy costs and compute aggregated accuracyCost + candidate.accuracyCost = 0.0; + for (auto *output : outputs) { + unsigned numValidSamples = numValidSamplesPerOutput[output]; + assert(numValidSamples && "No valid samples for output -- try increasing " + "the number of samples"); + // Local error --> global error + candidate.perOutputAccCost[output] = + std::expm1(candidate.perOutputAccCost[output] / numValidSamples) * + std::fabs(output->grad); + // llvm::errs() + // << "DEBUG calculated ACC per output candidate accuracy cost: " + // << candidate.perOutputAccCost[output] << "\n"; + candidate.accuracyCost += candidate.perOutputAccCost[output]; + } + assert(!std::isnan(candidate.accuracyCost)); + } +} + +bool improveViaHerbie( + const std::vector &inputExprs, + std::vector &AOs, Module *M, + const TargetTransformInfo &TTI, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + int componentIndex) { + std::string Program = HERBIE_BINARY; + llvm::errs() << "random seed: " << std::to_string(FPOptRandomSeed) << "\n"; + + SmallVector BaseArgs = { + Program, "report", + "--seed", std::to_string(FPOptRandomSeed), + "--timeout", std::to_string(HerbieTimeout), + "--threads", std::to_string(HerbieNumThreads), + "--num-points", std::to_string(HerbieNumPoints), + "--num-iters", std::to_string(HerbieNumIters)}; + + BaseArgs.push_back("--disable"); + BaseArgs.push_back("generate:proofs"); + + if (HerbieDisableNumerics) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("rules:numerics"); + } + + if (HerbieDisableSetupSimplify) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("setup:simplify"); + } + + if (HerbieDisableGenSimplify) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("generate:simplify"); + } + + if (HerbieDisableTaylor) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("generate:taylor"); + } + + if (HerbieDisableRegime) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("reduce:regimes"); + } + + if (HerbieDisableBranchExpr) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("reduce:branch-expressions"); + } + + if (HerbieDisableAvgError) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("reduce:avg-error"); + } + + SmallVector> BaseArgsList; + + if (!HerbieDisableTaylor) { + SmallVector Args1 = BaseArgs; + BaseArgsList.push_back(Args1); + + SmallVector Args2 = BaseArgs; + Args2.push_back("--disable"); + Args2.push_back("generate:taylor"); + BaseArgsList.push_back(Args2); + } else { + BaseArgsList.push_back(BaseArgs); + } + + std::vector> seenExprs(AOs.size()); + + bool success = false; + + for (size_t baseArgsIndex = 0; baseArgsIndex < BaseArgsList.size(); + ++baseArgsIndex) { + const auto &BaseArgs = BaseArgsList[baseArgsIndex]; + std::string content; + bool cached = false; + std::string cacheFilePath; + + if (!FPOptCachePath.empty()) { + cacheFilePath = FPOptCachePath + "/cachedHerbieOutput_" + + std::to_string(componentIndex) + "_" + + std::to_string(baseArgsIndex) + ".txt"; + std::ifstream cacheFile(cacheFilePath); + if (cacheFile) { + content.assign((std::istreambuf_iterator(cacheFile)), + std::istreambuf_iterator()); + cacheFile.close(); + llvm::errs() << "Using cached Herbie output from " << cacheFilePath + << "\n"; + cached = true; + } + } + + if (cached) { + llvm::errs() << "Herbie output: " << content << "\n"; + + Expected parsed = json::parse(content); + if (!parsed) { + llvm::errs() << "Failed to parse Herbie result!\n"; + continue; + } + + json::Object *obj = parsed->getAsObject(); + json::Array &tests = *obj->getArray("tests"); + + for (size_t testIndex = 0; testIndex < tests.size(); ++testIndex) { + auto &test = *tests[testIndex].getAsObject(); + + StringRef bestExpr = test.getString("output").getValue(); + StringRef ID = test.getString("name").getValue(); + + if (bestExpr == "#f") { + continue; + } + + int index = std::stoi(ID.str()); + if (index >= AOs.size()) { + llvm::errs() << "Invalid AO index: " << index << "\n"; + continue; + } + + ApplicableOutput &AO = AOs[index]; + auto &seenExprSet = seenExprs[index]; + + double bits = test.getNumber("bits").getValue(); + json::Array &costAccuracy = *test.getArray("cost-accuracy"); + + json::Array &initial = *costAccuracy[0].getAsArray(); + double initialCostVal = initial[0].getAsNumber().getValue(); + double initialCost = 1.0; + double initialAccuracy = + 1.0 - initial[1].getAsNumber().getValue() / bits; + + AO.initialHerbieCost = initialCost; + AO.initialHerbieAccuracy = initialAccuracy; + + if (seenExprSet.count(bestExpr.str()) == 0) { + seenExprSet.insert(bestExpr.str()); + + json::Array &best = *costAccuracy[1].getAsArray(); + double bestCost = best[0].getAsNumber().getValue() / initialCostVal; + double bestAccuracy = 1.0 - best[1].getAsNumber().getValue() / bits; + + RewriteCandidate bestCandidate(bestCost, bestAccuracy, + bestExpr.str()); + bestCandidate.CompCost = getCompCost( + bestExpr.str(), M, TTI, valueToNodeMap, symbolToValueMap, + cast(AO.oldOutput)->getFastMathFlags()); + AO.candidates.push_back(bestCandidate); + } + + json::Array &alternatives = *costAccuracy[2].getAsArray(); + + // Handle alternatives + for (size_t j = 0; j < alternatives.size(); ++j) { + json::Array &entry = *alternatives[j].getAsArray(); + StringRef expr = entry[2].getAsString().getValue(); + + if (seenExprSet.count(expr.str()) != 0) { + continue; + } + seenExprSet.insert(expr.str()); + + double cost = entry[0].getAsNumber().getValue() / initialCostVal; + double accuracy = 1.0 - entry[1].getAsNumber().getValue() / bits; + + RewriteCandidate candidate(cost, accuracy, expr.str()); + candidate.CompCost = + getCompCost(expr.str(), M, TTI, valueToNodeMap, symbolToValueMap, + cast(AO.oldOutput)->getFastMathFlags()); + AO.candidates.push_back(candidate); + } + + setUnifiedAccuracyCost(AO, valueToNodeMap, symbolToValueMap); + + success = true; + } + + continue; + } + + SmallString<32> tmpin, tmpout; + + if (llvm::sys::fs::createUniqueFile("herbie_input_%%%%%%%%%%%%%%%%", tmpin, + llvm::sys::fs::perms::owner_all)) { + llvm::errs() << "Failed to create a unique input file.\n"; + continue; + } + + if (llvm::sys::fs::createUniqueDirectory("herbie_output_%%%%%%%%%%%%%%%%", + tmpout)) { + llvm::errs() << "Failed to create a unique output directory.\n"; + llvm::sys::fs::remove(tmpin); + continue; + } + + std::ofstream input(tmpin.c_str()); + if (!input) { + llvm::errs() << "Failed to open input file.\n"; + llvm::sys::fs::remove(tmpin); + llvm::sys::fs::remove(tmpout); + continue; + } + for (const auto &expr : inputExprs) { + input << expr << "\n"; + } + input.close(); + + SmallVector Args = BaseArgs; + Args.push_back(tmpin); + Args.push_back(tmpout); + + std::string ErrMsg; + bool ExecutionFailed = false; + + if (EnzymePrintFPOpt) { + llvm::errs() << "Executing Herbie with arguments: "; + for (const auto &arg : Args) { + llvm::errs() << arg << " "; + } + llvm::errs() << "\n"; + } + + llvm::sys::ExecuteAndWait(Program, Args, /*Env=*/llvm::None, + /*Redirects=*/llvm::None, + /*SecondsToWait=*/0, /*MemoryLimit=*/0, &ErrMsg, + &ExecutionFailed); + + std::remove(tmpin.c_str()); + if (ExecutionFailed) { + llvm::errs() << "Execution failed: " << ErrMsg << "\n"; + llvm::sys::fs::remove(tmpout); + continue; + } + + std::ifstream output((tmpout + "/results.json").str()); + if (!output) { + llvm::errs() << "Failed to open output file.\n"; + llvm::sys::fs::remove(tmpout); + continue; + } + content.assign((std::istreambuf_iterator(output)), + std::istreambuf_iterator()); + output.close(); + llvm::sys::fs::remove(tmpout.c_str()); + + llvm::errs() << "Herbie output: " << content << "\n"; + + if (!FPOptCachePath.empty()) { + llvm::sys::fs::create_directories(FPOptCachePath, true); + std::ofstream cacheFile(cacheFilePath); + if (!cacheFile) { + llvm_unreachable("Failed to open cache file for writing"); + } else { + cacheFile << content; + cacheFile.close(); + llvm::errs() << "Saved Herbie output to cache file " << cacheFilePath + << "\n"; + } + } + + Expected parsed = json::parse(content); + if (!parsed) { + llvm::errs() << "Failed to parse Herbie result!\n"; + continue; + } + + json::Object *obj = parsed->getAsObject(); + json::Array &tests = *obj->getArray("tests"); + + for (size_t testIndex = 0; testIndex < tests.size(); ++testIndex) { + auto &test = *tests[testIndex].getAsObject(); + + StringRef bestExpr = test.getString("output").getValue(); + + if (bestExpr == "#f") { + continue; + } + + StringRef ID = test.getString("name").getValue(); + int index = std::stoi(ID.str()); + if (index >= AOs.size()) { + llvm::errs() << "Invalid AO index: " << index << "\n"; + continue; + } + + ApplicableOutput &AO = AOs[index]; + auto &seenExprSet = seenExprs[index]; + + double bits = test.getNumber("bits").getValue(); + json::Array &costAccuracy = *test.getArray("cost-accuracy"); + + json::Array &initial = *costAccuracy[0].getAsArray(); + double initialCostVal = initial[0].getAsNumber().getValue(); + double initialCost = 1.0; + double initialAccuracy = 1.0 - initial[1].getAsNumber().getValue() / bits; + + AO.initialHerbieCost = initialCost; + AO.initialHerbieAccuracy = initialAccuracy; + + if (seenExprSet.count(bestExpr.str()) == 0) { + seenExprSet.insert(bestExpr.str()); + + json::Array &best = *costAccuracy[1].getAsArray(); + double bestCost = best[0].getAsNumber().getValue() / initialCostVal; + double bestAccuracy = 1.0 - best[1].getAsNumber().getValue() / bits; + + RewriteCandidate bestCandidate(bestCost, bestAccuracy, bestExpr.str()); + bestCandidate.CompCost = getCompCost( + bestExpr.str(), M, TTI, valueToNodeMap, symbolToValueMap, + cast(AO.oldOutput)->getFastMathFlags()); + AO.candidates.push_back(bestCandidate); + } + + json::Array &alternatives = *costAccuracy[2].getAsArray(); + + // Handle alternatives + for (size_t j = 0; j < alternatives.size(); ++j) { + json::Array &entry = *alternatives[j].getAsArray(); + StringRef expr = entry[2].getAsString().getValue(); + + if (seenExprSet.count(expr.str()) != 0) { + continue; + } + seenExprSet.insert(expr.str()); + + double cost = entry[0].getAsNumber().getValue() / initialCostVal; + double accuracy = 1.0 - entry[1].getAsNumber().getValue() / bits; + + RewriteCandidate candidate(cost, accuracy, expr.str()); + candidate.CompCost = + getCompCost(expr.str(), M, TTI, valueToNodeMap, symbolToValueMap, + cast(AO.oldOutput)->getFastMathFlags()); + AO.candidates.push_back(candidate); + } + + setUnifiedAccuracyCost(AO, valueToNodeMap, symbolToValueMap); + + success = true; + } + } + + return success; +} + +std::string getHerbieOperator(const Instruction &I) { + switch (I.getOpcode()) { + case Instruction::FNeg: + return "neg"; + case Instruction::FAdd: + return "+"; + case Instruction::FSub: + return "-"; + case Instruction::FMul: + return "*"; + case Instruction::FDiv: + return "/"; + case Instruction::Call: { + const CallInst *CI = dyn_cast(&I); + assert(CI && CI->getCalledFunction() && + "getHerbieOperator: Call without a function"); + + std::string funcName = CI->getCalledFunction()->getName().str(); + + // Special cases + if (startsWith(funcName, "cbrt")) + return "cbrt"; + + std::regex regex("llvm\\.(\\w+)\\.?.*"); + std::smatch matches; + if (std::regex_search(funcName, matches, regex) && matches.size() > 1) { + if (matches[1].str() == "fmuladd") + return "fma"; + return matches[1].str(); + } + assert(0 && "getHerbieOperator: Unknown callee"); + } + default: + assert(0 && "getHerbieOperator: Unknown operator"); + } +} + +struct ValueInfo { + double minRes; + double maxRes; + unsigned executions; + double geometricAvg; + SmallVector lower; + SmallVector upper; +}; + +bool extractValueFromLog(const std::string &logPath, + const std::string &functionName, size_t blockIdx, + size_t instIdx, ValueInfo &data) { + std::ifstream file(logPath); + if (!file.is_open()) { + llvm_unreachable("Failed to open log file"); + } + + std::string line; + std::regex valuePattern("^Value:" + functionName + ":" + + std::to_string(blockIdx) + ":" + + std::to_string(instIdx) + "$"); + std::regex newEntryPattern("^(Value|Grad):"); + + while (getline(file, line)) { + if (!line.empty() && line.back() == '\r') { + line.pop_back(); + } + + if (std::regex_search(line, valuePattern)) { + std::string minResLine, maxResLine, executionsLine, geometricAvgLine; + if (getline(file, minResLine) && getline(file, maxResLine) && + getline(file, executionsLine) && getline(file, geometricAvgLine)) { + std::regex minResPattern(R"(MinRes = ([\d\.eE+-]+))"); + std::regex maxResPattern(R"(MaxRes = ([\d\.eE+-]+))"); + std::regex executionsPattern(R"(Executions = (\d+))"); + std::regex geometricAvgPattern(R"(Geometric Average = ([\d\.eE+-]+))"); + + std::smatch minResMatch, maxResMatch, executionsMatch, + geometricAvgMatch; + if (std::regex_search(minResLine, minResMatch, minResPattern) && + std::regex_search(maxResLine, maxResMatch, maxResPattern) && + std::regex_search(executionsLine, executionsMatch, + executionsPattern) && + std::regex_search(geometricAvgLine, geometricAvgMatch, + geometricAvgPattern)) { + data.minRes = stringToDouble(minResMatch[1]); + data.maxRes = stringToDouble(maxResMatch[1]); + data.executions = std::stol(executionsMatch[1]); + data.geometricAvg = stringToDouble(geometricAvgMatch[1]); + } else { + std::string error = + "Failed to parse stats for: Function: " + functionName + + ", BlockIdx: " + std::to_string(blockIdx) + + ", InstIdx: " + std::to_string(instIdx); + llvm_unreachable(error.c_str()); + } + } + + std::regex rangePattern( + R"(Operand\[\d+\] = \[([\d\.eE+-]+), ([\d\.eE+-]+)\])"); + while (getline(file, line)) { + if (std::regex_search(line, newEntryPattern)) { + // Ablation study only: widen the range by `FPOptWidenRange` times + if (FPOptWidenRange != 1) { + double center = (data.minRes + data.maxRes) / 2.0; + double half_range = (data.maxRes - data.minRes) / 2.0; + double new_half_range = half_range * FPOptWidenRange; + data.minRes = center - new_half_range; + data.maxRes = center + new_half_range; + + for (size_t i = 0; i < data.lower.size(); ++i) { + double op_center = (data.lower[i] + data.upper[i]) / 2.0; + double op_half_range = (data.upper[i] - data.lower[i]) / 2.0; + double op_new_half_range = op_half_range * FPOptWidenRange; + data.lower[i] = op_center - op_new_half_range; + data.upper[i] = op_center + op_new_half_range; + } + } + + // All operands have been extracted + return true; + } + + std::smatch rangeMatch; + if (std::regex_search(line, rangeMatch, rangePattern)) { + data.lower.push_back(stringToDouble(rangeMatch[1])); + data.upper.push_back(stringToDouble(rangeMatch[2])); + } + } + } + } + + std::string error = + "Failed to extract value info for: Function: " + functionName + + ", BlockIdx: " + std::to_string(blockIdx) + + ", InstIdx: " + std::to_string(instIdx); + + return false; +} + +bool extractGradFromLog(const std::string &logPath, + const std::string &functionName, size_t blockIdx, + size_t instIdx, double &grad) { + std::ifstream file(logPath); + if (!file.is_open()) { + llvm_unreachable("Failed to open log file"); + } + + std::string line; + std::regex gradPattern("^Grad:" + functionName + ":" + + std::to_string(blockIdx) + ":" + + std::to_string(instIdx) + "$"); + + while (getline(file, line)) { + if (!line.empty() && line.back() == '\r') { + line.pop_back(); + } + + if (std::regex_search(line, gradPattern)) { + + // Extract Grad data + std::regex gradExtractPattern(R"(Grad = ([\d\.eE+-]+))"); + std::smatch gradMatch; + if (getline(file, line) && + std::regex_search(line, gradMatch, gradExtractPattern)) { + grad = stringToDouble(gradMatch[1]); + return true; + } + } + } + + llvm::errs() << "Failed to extract gradient for: Function: " << functionName + << ", BlockIdx: " << blockIdx << ", InstIdx: " << instIdx + << "\n"; + return false; +} + +bool isLogged(const std::string &logPath, const std::string &functionName) { + std::ifstream file(logPath); + if (!file.is_open()) { + assert(0 && "Failed to open log file"); + } + + std::regex functionRegex("^Value:" + functionName); + + std::string line; + while (std::getline(file, line)) { + if (std::regex_search(line, functionRegex)) { + return true; + } + } + + return false; +} + +std::string getPrecondition( + const SmallSet &args, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap) { + std::string preconditions; + + for (const auto &arg : args) { + const auto node = valueToNodeMap.at(symbolToValueMap.at(arg)); + double lower = node->getLowerBound(); + double upper = node->getUpperBound(); + + std::ostringstream lowerStr, upperStr; + lowerStr << std::setprecision(std::numeric_limits::max_digits10) + << std::scientific << lower; + upperStr << std::setprecision(std::numeric_limits::max_digits10) + << std::scientific << upper; + + preconditions += " (<=" + (std::isinf(lower) ? "" : " " + lowerStr.str()) + + " " + arg + + (std::isinf(upper) ? "" : " " + upperStr.str()) + ")"; + } + + return preconditions.empty() ? "TRUE" : "(and" + preconditions + ")"; +} + +// Given the cost budget `FPOptComputationCostBudget`, we want to minimize the +// accuracy cost of the rewritten expressions. +bool accuracyGreedySolver( + SmallVector &AOs, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + bool changed = false; + llvm::errs() << "Starting accuracy greedy solver with computation budget: " + << FPOptComputationCostBudget << "\n"; + InstructionCost totalComputationCost = 0; + + for (auto &AO : AOs) { + int bestCandidateIndex = -1; + double bestAccuracyCost = std::numeric_limits::infinity(); + InstructionCost bestCandidateComputationCost; + + for (auto &candidate : enumerate(AO.candidates)) { + size_t i = candidate.index(); + auto candCompCost = AO.getCompCostDelta(i); + auto candAccCost = AO.getAccCostDelta(i); + llvm::errs() << "Candidate " << i << " for " << AO.expr + << " has accuracy cost: " << candAccCost + << " and computation cost: " << candCompCost << "\n"; + + // See if the candidate fits within the computation cost budget + if (totalComputationCost + candCompCost <= FPOptComputationCostBudget) { + // Select the candidate with the lowest accuracy cost + if (candAccCost < bestAccuracyCost) { + llvm::errs() << "Candidate " << i << " selected!\n"; + bestCandidateIndex = i; + bestAccuracyCost = candAccCost; + bestCandidateComputationCost = candCompCost; + } + } + } + + if (bestCandidateIndex != -1) { + AO.apply(bestCandidateIndex, valueToNodeMap, symbolToValueMap); + changed = true; + totalComputationCost += bestCandidateComputationCost; + llvm::errs() << "Updated total computation cost: " << totalComputationCost + << "\n\n"; + } + } + + return changed; +} + +bool accuracyDPSolver( + SmallVector &AOs, SmallVector &ACCs, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + bool changed = false; + llvm::errs() << "Starting accuracy DP solver with computation budget: " + << FPOptComputationCostBudget << "\n"; + + using CostMap = std::map; + using SolutionMap = std::map>; + + CostMap costToAccuracyMap; + costToAccuracyMap[0] = 0; + SolutionMap costToSolutionMap; + costToSolutionMap[0] = {}; + CostMap newCostToAccuracyMap; + SolutionMap newCostToSolutionMap; + CostMap prunedCostToAccuracyMap; + SolutionMap prunedCostToSolutionMap; + + std::string cacheFilePath = FPOptCachePath + "/table.json"; + + if (llvm::sys::fs::exists(cacheFilePath)) { + llvm::errs() << "Cache file found. Loading DP tables from cache.\n"; + + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFile(cacheFilePath); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Error reading cache file: " << ec.message() << "\n"; + return changed; + } + llvm::StringRef buffer = fileOrErr.get()->getBuffer(); + llvm::Expected jsonOrErr = llvm::json::parse(buffer); + if (!jsonOrErr) { + llvm::errs() << "Error parsing JSON from cache file: " + << llvm::toString(jsonOrErr.takeError()) << "\n"; + return changed; + } + + llvm::json::Object *jsonObj = jsonOrErr->getAsObject(); + if (!jsonObj) { + llvm::errs() << "Invalid JSON format in cache file.\n"; + return changed; + } + + if (llvm::json::Object *costAccMap = + jsonObj->getObject("costToAccuracyMap")) { + for (auto &pair : *costAccMap) { + InstructionCost compCost(std::stoll(pair.first.str())); + double accCost = pair.second.getAsNumber().getValue(); + costToAccuracyMap[compCost] = accCost; + } + } else { + llvm_unreachable("Invalid costToAccuracyMap in cache file."); + } + + if (llvm::json::Object *costSolMap = + jsonObj->getObject("costToSolutionMap")) { + for (auto &pair : *costSolMap) { + InstructionCost compCost(std::stoll(pair.first.str())); + SmallVector solutionSteps; + + llvm::json::Array *stepsArray = pair.second.getAsArray(); + if (!stepsArray) { + llvm::errs() << "Invalid steps array in cache file.\n"; + return changed; + } + + for (llvm::json::Value &stepVal : *stepsArray) { + llvm::json::Object *stepObj = stepVal.getAsObject(); + if (!stepObj) { + llvm_unreachable("Invalid step object in cache file."); + } + + StringRef itemType = stepObj->getString("itemType").getValue(); + size_t candidateIndex = + stepObj->getInteger("candidateIndex").getValue(); + size_t itemIndex = stepObj->getInteger("itemIndex").getValue(); + + if (itemType == "AO") { + if (itemIndex >= AOs.size()) { + llvm_unreachable("Invalid ApplicableOutput index in cache file."); + } + solutionSteps.emplace_back(&AOs[itemIndex], candidateIndex); + } else if (itemType == "ACC") { + if (itemIndex >= ACCs.size()) { + llvm_unreachable("Invalid ApplicableFPCC index in cache file."); + } + solutionSteps.emplace_back(&ACCs[itemIndex], candidateIndex); + } else { + llvm_unreachable("Invalid itemType in cache file."); + } + } + + costToSolutionMap[compCost] = solutionSteps; + } + } else { + llvm::errs() << "costToSolutionMap not found in cache file.\n"; + return changed; + } + + llvm::errs() << "Loaded DP tables from cache.\n"; + + } else { + llvm::errs() << "Cache file not found. Proceeding to solve DP.\n"; + + std::unordered_map aoPtrToIndex; + for (size_t i = 0; i < AOs.size(); ++i) { + aoPtrToIndex[&AOs[i]] = i; + } + std::unordered_map accPtrToIndex; + for (size_t i = 0; i < ACCs.size(); ++i) { + accPtrToIndex[&ACCs[i]] = i; + } + + int AOCounter = 0; + + for (auto &AO : AOs) { + // It is possible to apply zero candidate for an AO. + // When no candidate is applied, the resulting accuracy cost + // and solution steps remain the same. + newCostToAccuracyMap = costToAccuracyMap; + newCostToSolutionMap = costToSolutionMap; + + for (const auto &pair : costToAccuracyMap) { + InstructionCost currCompCost = pair.first; + double currAccCost = pair.second; + + for (auto &candidate : enumerate(AO.candidates)) { + size_t i = candidate.index(); + auto candCompCost = AO.getCompCostDelta(i); + auto candAccCost = AO.getAccCostDelta(i); + + // Don't ever try to apply a strictly useless candidate + if (candCompCost >= 0 && candAccCost >= 0.) { + continue; + } + + InstructionCost newCompCost = currCompCost + candCompCost; + double newAccCost = currAccCost + candAccCost; + + // if (EnzymePrintFPOpt) + // llvm::errs() << "AO candidate " << i + // << " has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << + // "\n"; + + if (newCostToAccuracyMap.find(newCompCost) == + newCostToAccuracyMap.end() || + newCostToAccuracyMap[newCompCost] > newAccCost) { + newCostToAccuracyMap[newCompCost] = newAccCost; + newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost]; + newCostToSolutionMap[newCompCost].emplace_back(&AO, i); + // if (EnzymePrintFPOpt) + // llvm::errs() << "Updating accuracy map (AO candidate " << i + // << "): computation cost " << newCompCost + // << " -> accuracy cost " << newAccCost << "\n"; + } + } + } + + // TODO: Do not prune AO parts of the DP table since AOs influence ACCs + if (!FPOptEarlyPrune) { + costToAccuracyMap = newCostToAccuracyMap; + costToSolutionMap = newCostToSolutionMap; + + llvm::errs() << "##### Finished processing " << ++AOCounter << " of " + << AOs.size() << " AOs #####\n"; + llvm::errs() << "Current DP table sizes: " << costToAccuracyMap.size() + << "\n"; + continue; + } + + for (const auto &l : newCostToAccuracyMap) { + InstructionCost currCompCost = l.first; + double currAccCost = l.second; + + bool dominated = false; + for (const auto &r : newCostToAccuracyMap) { + InstructionCost otherCompCost = r.first; + double otherAccCost = r.second; + + if (currCompCost - otherCompCost > + std::fabs(FPOptCostDominanceThreshold * + otherCompCost.getValue().getValue()) && + currAccCost - otherAccCost >= + std::fabs(FPOptAccuracyDominanceThreshold * otherAccCost)) { + // if (EnzymePrintFPOpt) + // llvm::errs() << "AO candidate with computation cost: " + // << currCompCost + // << " and accuracy cost: " << currAccCost + // << " is dominated by candidate with computation + // cost:" + // << otherCompCost + // << " and accuracy cost: " << otherAccCost << "\n"; + dominated = true; + break; + } + } + + if (!dominated) { + prunedCostToAccuracyMap[currCompCost] = currAccCost; + prunedCostToSolutionMap[currCompCost] = + newCostToSolutionMap[currCompCost]; + } + } + + costToAccuracyMap = prunedCostToAccuracyMap; + costToSolutionMap = prunedCostToSolutionMap; + prunedCostToAccuracyMap.clear(); + prunedCostToSolutionMap.clear(); + + llvm::errs() << "##### Finished processing " << ++AOCounter << " of " + << AOs.size() << " AOs #####\n"; + llvm::errs() << "Current DP table sizes: " << costToAccuracyMap.size() + << "\n"; + } + + int ACCCounter = 0; + + for (auto &ACC : ACCs) { + // It is possible to apply zero candidate for an ACC. + // When no candidate is applied, the resulting accuracy cost + // and solution steps remain the same. + newCostToAccuracyMap = costToAccuracyMap; + newCostToSolutionMap = costToSolutionMap; + + for (const auto &pair : costToAccuracyMap) { + InstructionCost currCompCost = pair.first; + double currAccCost = pair.second; + + for (auto &candidate : enumerate(ACC.candidates)) { + size_t i = candidate.index(); + auto candCompCost = + ACC.getAdjustedCompCostDelta(i, costToSolutionMap[currCompCost]); + auto candAccCost = + ACC.getAdjustedAccCostDelta(i, costToSolutionMap[currCompCost], + valueToNodeMap, symbolToValueMap); + + // Don't ever try to apply a strictly useless candidate + if (candCompCost >= 0 && candAccCost >= 0.) { + continue; + } + + InstructionCost newCompCost = currCompCost + candCompCost; + double newAccCost = currAccCost + candAccCost; + + // if (EnzymePrintFPOpt) + // llvm::errs() << "ACC candidate " << i << " (" + // << candidate.value().desc + // << ") has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << + // "\n"; + + if (newCostToAccuracyMap.find(newCompCost) == + newCostToAccuracyMap.end() || + newCostToAccuracyMap[newCompCost] > newAccCost) { + newCostToAccuracyMap[newCompCost] = newAccCost; + newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost]; + newCostToSolutionMap[newCompCost].emplace_back(&ACC, i); + // if (EnzymePrintFPOpt) { + // llvm::errs() << "ACC candidate " << i << " (" + // << candidate.value().desc + // << ") added; has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << + // "\n"; + // llvm::errs() << "Updating accuracy map (ACC candidate " << i + // << "): computation cost " << newCompCost + // << " -> accuracy cost " << newAccCost << "\n"; + // } + } + } + } + + for (const auto &l : newCostToAccuracyMap) { + InstructionCost currCompCost = l.first; + double currAccCost = l.second; + + bool dominated = false; + for (const auto &r : newCostToAccuracyMap) { + InstructionCost otherCompCost = r.first; + double otherAccCost = r.second; + + if (currCompCost - otherCompCost > + std::fabs(FPOptCostDominanceThreshold * + otherCompCost.getValue().getValue()) && + currAccCost - otherAccCost >= + std::fabs(FPOptAccuracyDominanceThreshold * otherAccCost)) { + // if (EnzymePrintFPOpt) + // llvm::errs() << "ACC candidate with computation cost: " + // << currCompCost + // << " and accuracy cost: " << currAccCost + // << " is dominated by candidate with computation + // cost:" + // << otherCompCost + // << " and accuracy cost: " << otherAccCost << "\n"; + dominated = true; + break; + } + } + + if (!dominated) { + prunedCostToAccuracyMap[currCompCost] = currAccCost; + prunedCostToSolutionMap[currCompCost] = + newCostToSolutionMap[currCompCost]; + } + } + + costToAccuracyMap = prunedCostToAccuracyMap; + costToSolutionMap = prunedCostToSolutionMap; + prunedCostToAccuracyMap.clear(); + prunedCostToSolutionMap.clear(); + + llvm::errs() << "##### Finished processing " << ++ACCCounter << " of " + << ACCs.size() << " ACCs #####\n"; + llvm::errs() << "Current DP table sizes: " << costToAccuracyMap.size() + << "\n"; + } + + json::Object jsonObj; + + json::Object costAccMap; + for (const auto &pair : costToAccuracyMap) { + costAccMap[std::to_string(pair.first.getValue().getValue())] = + pair.second; + } + jsonObj["costToAccuracyMap"] = std::move(costAccMap); + + json::Object costSolMap; + for (const auto &pair : costToSolutionMap) { + json::Array stepsArray; + for (const auto &step : pair.second) { + json::Object stepObj; + stepObj["candidateIndex"] = static_cast(step.candidateIndex); + + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + stepObj["itemType"] = "AO"; + size_t index = aoPtrToIndex[item]; + stepObj["itemIndex"] = static_cast(index); + } else if constexpr (std::is_same_v) { + stepObj["itemType"] = "ACC"; + size_t index = accPtrToIndex[item]; + stepObj["itemIndex"] = static_cast(index); + } + }, + step.item); + stepsArray.push_back(std::move(stepObj)); + } + costSolMap[std::to_string(pair.first.getValue().getValue())] = + std::move(stepsArray); + } + jsonObj["costToSolutionMap"] = std::move(costSolMap); + + std::error_code EC; + llvm::raw_fd_ostream cacheFile(cacheFilePath, EC, llvm::sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error writing cache file: " << EC.message() << "\n"; + } else { + cacheFile << llvm::formatv("{0:2}", llvm::json::Value(std::move(jsonObj))) + << "\n"; + cacheFile.close(); + llvm::errs() << "DP tables cached to file.\n"; + } + } + + if (EnzymePrintFPOpt) { + if (FPOptShowTable) { + llvm::errs() << "\n*** DP Table ***\n"; + for (const auto &pair : costToAccuracyMap) { + llvm::errs() << "Computation cost: " << pair.first + << ", Accuracy cost: " << pair.second << "\n"; + llvm::errs() << "\tSolution steps: \n"; + for (const auto &step : costToSolutionMap[pair.first]) { + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + llvm::errs() + << "\t\t" << item->expr << " --(" << step.candidateIndex + << ")-> " << item->candidates[step.candidateIndex].expr + << "\n"; + } else if constexpr (std::is_same_v) { + llvm::errs() << "\t\tACC: " + << item->candidates[step.candidateIndex].desc + << " (#" << step.candidateIndex << ")\n"; + } else { + llvm_unreachable( + "accuracyDPSolver: Unexpected type of solution step"); + } + }, + step.item); + } + } + llvm::errs() << "*** End of DP Table ***\n\n"; + } + llvm::errs() << "*** Critical Computation Costs ***\n"; + for (const auto &pair : costToAccuracyMap) { + llvm::errs() << pair.first << ","; + } + llvm::errs() << "\n"; + llvm::errs() << "*** End of Critical Computation Costs ***\n\n"; + } + + llvm::errs() << "Critical computation cost range: [" + << costToAccuracyMap.begin()->first << ", " + << costToAccuracyMap.rbegin()->first << "]\n"; + + double minAccCost = std::numeric_limits::infinity(); + InstructionCost bestCompCost = 0; + for (const auto &pair : costToAccuracyMap) { + InstructionCost compCost = pair.first; + double accCost = pair.second; + + if (compCost <= FPOptComputationCostBudget && accCost < minAccCost) { + minAccCost = accCost; + bestCompCost = compCost; + } + } + + if (minAccCost == std::numeric_limits::infinity()) { + llvm::errs() << "No solution found within the computation cost budget!\n"; + return changed; + } + + llvm::errs() << "Minimum accuracy cost within budget: " << minAccCost << "\n"; + llvm::errs() << "Computation cost budget used: " << bestCompCost << "\n"; + + if (bestCompCost == 0 && minAccCost == 0) { + llvm::errs() << "WARNING: DP Solver recommended no optimization given the " + "current computation cost budget.\n"; + return changed; + } + + assert(costToSolutionMap.find(bestCompCost) != costToSolutionMap.end() && + "FPOpt DP solver: expected a solution!"); + + llvm::errs() << "\n!!! DP solver: Applying solution ... !!!\n"; + for (const auto &solution : costToSolutionMap[bestCompCost]) { + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + llvm::errs() << "Applying solution for " << item->expr << " --(" + << solution.candidateIndex << ")-> " + << item->candidates[solution.candidateIndex].expr + << "\n"; + item->apply(solution.candidateIndex, valueToNodeMap, + symbolToValueMap); + } else if constexpr (std::is_same_v) { + llvm::errs() << "Applying solution for ACC: " + << item->candidates[solution.candidateIndex].desc + << " (#" << solution.candidateIndex << ")\n"; + item->apply(solution.candidateIndex); + } else { + llvm_unreachable( + "accuracyDPSolver: Unexpected type of solution step"); + } + }, + solution.item); + changed = true; + } + llvm::errs() << "!!! DP Solver: Solution applied !!!\n\n"; + + return changed; +} + +// Run (our choice of) floating point optimizations on function `F`. +// Return whether or not we change the function. +bool fpOptimize(Function &F, const TargetTransformInfo &TTI) { + const std::string functionName = F.getName().str(); + std::string demangledName = llvm::demangle(functionName); + size_t pos = demangledName.find('('); + if (pos != std::string::npos) { + demangledName = demangledName.substr(0, pos); + } + + std::regex targetFuncRegex(FPOptTargetFuncRegex); + if (!std::regex_match(demangledName, targetFuncRegex)) { + if (EnzymePrintFPOpt) + llvm::errs() << "Skipping function: " << demangledName + << " (demangled) since it does not match the target regex\n"; + return false; + } + + if (!FPOptLogPath.empty()) { + if (!isLogged(FPOptLogPath, functionName)) { + if (EnzymePrintFPOpt) + llvm::errs() << "Skipping matched function: " << functionName + << " since this function is not found in the log\n"; + return false; + } + } + + bool changed = false; + + int symbolCounter = 0; + auto getNextSymbol = [&symbolCounter]() -> std::string { + return "v" + std::to_string(symbolCounter++); + }; + + // Extract change: + + // E1) create map for all instructions I, map[I] = FPLLValue(I) + // E2) for all instructions, if herbiable(I), map[I] = FPNode(operation(I), + // map[operands(I)]) + // E3) floodfill for all starting locations I to find all distinct graphs / + // outputs. + + /* + B1: + x = sin(arg) + + B2: + y = 1 - x * x + + + -> result y = cos(arg)^2 + +B1: + nothing + +B2: + costmp = cos(arg) + y = costmp * costmp + + */ + + std::unordered_map> valueToNodeMap; + std::unordered_map symbolToValueMap; + + llvm::errs() << "FPOpt: Starting Floodfill for " << F.getName() << "\n"; + + for (auto &BB : F) { + for (auto &I : BB) { + if (!herbiable(I)) { + valueToNodeMap[&I] = + std::make_shared(&I, "__nh", "__nh"); // Non-herbiable + if (EnzymePrintFPOpt) + llvm::errs() << "Registered FPLLValue for non-herbiable instruction: " + << I << "\n"; + continue; + } + + std::string dtype; + if (I.getType()->isFloatTy()) { + dtype = "f32"; + } else if (I.getType()->isDoubleTy()) { + dtype = "f64"; + } else { + llvm_unreachable("Unexpected floating point type for instruction"); + } + auto node = std::make_shared(&I, getHerbieOperator(I), dtype); + + auto operands = + isa(I) ? cast(I).args() : I.operands(); + for (auto &operand : operands) { + if (!valueToNodeMap.count(operand)) { + if (auto Arg = dyn_cast(operand)) { + std::string dtype; + if (Arg->getType()->isFloatTy()) { + dtype = "f32"; + } else if (Arg->getType()->isDoubleTy()) { + dtype = "f64"; + } else { + llvm_unreachable("Unexpected floating point type for argument"); + } + valueToNodeMap[operand] = + std::make_shared(Arg, "__arg", dtype); + if (EnzymePrintFPOpt) + llvm::errs() << "Registered FPNode for argument: " << *Arg + << "\n"; + } else if (auto C = dyn_cast(operand)) { + SmallString<10> value; + C->getValueAPF().toString(value); + std::string dtype; + if (C->getType()->isFloatTy()) { + dtype = "f32"; + } else if (C->getType()->isDoubleTy()) { + dtype = "f64"; + } else { + llvm_unreachable("Unexpected floating point type for constant"); + } + valueToNodeMap[operand] = + std::make_shared(value.c_str(), dtype); + if (EnzymePrintFPOpt) + llvm::errs() << "Registered FPNode for " << dtype + << " constant: " << value << "\n"; + } else if (auto GV = dyn_cast(operand)) { + assert( + GV->getType()->getPointerElementType()->isFloatingPointTy() && + "Global variable is not floating point type"); + std::string dtype; + if (GV->getType()->getPointerElementType()->isFloatTy()) { + dtype = "f32"; + } else if (GV->getType()->getPointerElementType()->isDoubleTy()) { + dtype = "f64"; + } else { + llvm_unreachable( + "Unexpected floating point type for global variable"); + } + valueToNodeMap[operand] = + std::make_shared(GV, "__gv", dtype); + if (EnzymePrintFPOpt) + llvm::errs() << "Registered FPNode for global variable: " << *GV + << "\n"; + } else { + assert(0 && "Unknown operand"); + } + } + node->addOperand(valueToNodeMap[operand]); + } + valueToNodeMap[&I] = node; + } + } + + SmallSet component_seen; + SmallVector connected_components; + for (auto &BB : F) { + for (auto &I : BB) { + // Not a herbiable instruction, doesn't make sense to create graph node + // out of. + if (!herbiable(I)) { + if (EnzymePrintFPOpt) + llvm::errs() << "Skipping non-herbiable instruction: " << I << "\n"; + continue; + } + + // Instruction is already in a set + if (component_seen.contains(&I)) { + if (EnzymePrintFPOpt) + llvm::errs() << "Skipping already seen instruction: " << I << "\n"; + continue; + } + + // if (!FPOptLogPath.empty()) { + // auto node = valueToNodeMap[&I]; + // ValueInfo valueInfo; + // auto blockIt = std::find_if( + // I.getFunction()->begin(), I.getFunction()->end(), + // [&](const auto &block) { return &block == I.getParent(); }); + // assert(blockIt != I.getFunction()->end() && "Block not found"); + // size_t blockIdx = std::distance(I.getFunction()->begin(), blockIt); + // auto instIt = + // std::find_if(I.getParent()->begin(), I.getParent()->end(), + // [&](const auto &curr) { return &curr == &I; }); + // assert(instIt != I.getParent()->end() && "Instruction not found"); + // size_t instIdx = std::distance(I.getParent()->begin(), instIt); + + // bool found = extractValueFromLog(FPOptLogPath, functionName, + // blockIdx, + // instIdx, valueInfo); + // if (!found) { + // llvm::errs() << "Instruction " << I << " has no execution + // logged!\n"; continue; + // } + // } + + if (EnzymePrintFPOpt) + llvm::errs() << "Starting floodfill from: " << I << "\n"; + + SmallVector todo; + SetVector input_seen; + SetVector output_seen; + SetVector operation_seen; + todo.push_back(&I); + while (!todo.empty()) { + auto cur = todo.pop_back_val(); + assert(valueToNodeMap.count(cur) && "Node not found in valueToNodeMap"); + + // We now can assume that this is a herbiable expression + // Since we can only herbify instructions, let's assert that + assert(isa(cur)); + auto I2 = cast(cur); + + // Don't repeat any instructions we've already seen (to avoid loops + // for phi nodes) + if (operation_seen.contains(I2)) { + if (EnzymePrintFPOpt) + llvm::errs() << "Skipping already seen instruction: " << *I2 + << "\n"; + continue; + } + + // Assume that a herbiable expression can only be in one connected + // component. + assert(!component_seen.contains(cur)); + + if (EnzymePrintFPOpt) + llvm::errs() << "Insert to operation_seen and component_seen: " << *I2 + << "\n"; + operation_seen.insert(I2); + component_seen.insert(cur); + + auto operands = + isa(I2) ? cast(I2)->args() : I2->operands(); + + for (auto &operand_ : enumerate(operands)) { + auto &operand = operand_.value(); + auto i = operand_.index(); + if (!herbiable(*operand)) { + if (EnzymePrintFPOpt) + llvm::errs() << "Non-herbiable input found: " << *operand << "\n"; + + // Don't mark constants as input `llvm::Value`s + if (!isa(operand)) + input_seen.insert(operand); + + // look up error log to get bounds of non-herbiable inputs + if (!FPOptLogPath.empty()) { + ValueInfo valueInfo; + auto blockIt = std::find_if( + I2->getFunction()->begin(), I2->getFunction()->end(), + [&](const auto &block) { return &block == I2->getParent(); }); + assert(blockIt != I2->getFunction()->end() && "Block not found"); + size_t blockIdx = + std::distance(I2->getFunction()->begin(), blockIt); + auto instIt = + std::find_if(I2->getParent()->begin(), I2->getParent()->end(), + [&](const auto &curr) { return &curr == I2; }); + assert(instIt != I2->getParent()->end() && + "Instruction not found"); + size_t instIdx = std::distance(I2->getParent()->begin(), instIt); + + extractValueFromLog(FPOptLogPath, functionName, blockIdx, instIdx, + valueInfo); + auto node = valueToNodeMap[operand]; + node->updateBounds(valueInfo.lower[i], valueInfo.upper[i]); + + if (EnzymePrintFPOpt) { + llvm::errs() << "Range of " << *operand << " is [" + << node->getLowerBound() << ", " + << node->getUpperBound() << "]\n"; + } + } + } else { + if (EnzymePrintFPOpt) + llvm::errs() << "Adding operand to todo list: " << *operand + << "\n"; + todo.push_back(operand); + } + } + + for (auto U : I2->users()) { + if (auto I3 = dyn_cast(U)) { + if (!herbiable(*I3)) { + if (EnzymePrintFPOpt) + llvm::errs() << "Output instruction found: " << *I2 << "\n"; + output_seen.insert(I2); + } else { + if (EnzymePrintFPOpt) + llvm::errs() << "Adding user to todo list: " << *I3 << "\n"; + todo.push_back(I3); + } + } + } + } + + // Don't bother with graphs without any herbiable operations + if (!operation_seen.empty()) { + if (EnzymePrintFPOpt) { + llvm::errs() << "Found a connected component with " + << operation_seen.size() << " operations and " + << input_seen.size() << " inputs and " + << output_seen.size() << " outputs\n"; + + llvm::errs() << "Inputs:\n"; + + for (auto &input : input_seen) { + llvm::errs() << *input << "\n"; + } + + llvm::errs() << "Outputs:\n"; + for (auto &output : output_seen) { + llvm::errs() << *output << "\n"; + } + + llvm::errs() << "Operations:\n"; + for (auto &operation : operation_seen) { + llvm::errs() << *operation << "\n"; + } + } + + // TODO: Further check + if (operation_seen.size() == 1) { + if (EnzymePrintFPOpt) + llvm::errs() << "Skipping trivial connected component\n"; + continue; + } + + FPCC origCC{input_seen, output_seen, operation_seen}; + SmallVector newCCs; + splitFPCC(origCC, newCCs); + + for (auto &CC : newCCs) { + for (auto *input : CC.inputs) { + valueToNodeMap[input]->markAsInput(); + } + } + + if (!FPOptLogPath.empty()) { + for (auto &CC : newCCs) { + // Extract grad and value info for all instructions. + for (auto &op : CC.operations) { + double grad = 0; + auto blockIt = std::find_if( + op->getFunction()->begin(), op->getFunction()->end(), + [&](const auto &block) { return &block == op->getParent(); }); + assert(blockIt != op->getFunction()->end() && "Block not found"); + size_t blockIdx = + std::distance(op->getFunction()->begin(), blockIt); + auto instIt = + std::find_if(op->getParent()->begin(), op->getParent()->end(), + [&](const auto &curr) { return &curr == op; }); + assert(instIt != op->getParent()->end() && + "Instruction not found"); + size_t instIdx = std::distance(op->getParent()->begin(), instIt); + bool found = extractGradFromLog(FPOptLogPath, functionName, + blockIdx, instIdx, grad); + + auto node = valueToNodeMap[op]; + node->grad = grad; + + if (found) { + ValueInfo valueInfo; + extractValueFromLog(FPOptLogPath, functionName, blockIdx, + instIdx, valueInfo); + node->executions = valueInfo.executions; + node->geometricAvg = valueInfo.geometricAvg; + node->updateBounds(valueInfo.minRes, valueInfo.maxRes); + + if (EnzymePrintFPOpt) { + llvm::errs() + << "Range of " << *op << " is [" << node->getLowerBound() + << ", " << node->getUpperBound() << "]\n"; + } + + if (EnzymePrintFPOpt) + llvm::errs() + << "Grad of " << *op << " is: " << node->grad << "\n" + << "Execution count of " << *op + << " is: " << node->executions << "\n"; + } else { // Unknown bounds + if (EnzymePrintFPOpt) + llvm::errs() + << "Grad of " << *op + << " are not found in the log; using 0 instead\n"; + } + } + } + } + + connected_components.insert(connected_components.end(), newCCs.begin(), + newCCs.end()); + } + } + } + + llvm::errs() << "FPOpt: Found " << connected_components.size() + << " connected components in " << F.getName() << "\n"; + + // 1) Identify subgraphs of the computation which can be entirely represented + // in herbie-style arithmetic + // 2) Make the herbie FP-style expression by + // converting llvm instructions into herbie string (FPNode ....) + if (connected_components.empty()) { + if (EnzymePrintFPOpt) + llvm::errs() << "No herbiable connected components found\n"; + return false; + } + + SmallVector AOs; + SmallVector ACCs; + + int componentCounter = 0; + + for (auto &component : connected_components) { + assert(component.inputs.size() > 0 && "No inputs found for component"); + if (FPOptEnableHerbie) { + for (const auto &input : component.inputs) { + auto node = valueToNodeMap[input]; + if (node->op == "__const") { + // Constants don't need a symbol + continue; + } + if (!node->hasSymbol()) { + node->symbol = getNextSymbol(); + } + symbolToValueMap[node->symbol] = input; + if (EnzymePrintFPOpt) + llvm::errs() << "assigning symbol: " << node->symbol << " to " + << *input << "\n"; + } + + std::vector herbieInputs; + std::vector newAOs; + int outputCounter = 0; + + assert(component.outputs.size() > 0 && "No outputs found for component"); + for (auto &output : component.outputs) { + // 3) run fancy opts + double grad = valueToNodeMap[output]->grad; + unsigned executions = valueToNodeMap[output]->executions; + + // TODO: For now just skip if grad is 0 + if (!FPOptLogPath.empty() && grad == 0.) { + llvm::errs() << "Skipping algebraic rewriting for " << *output + << " since gradient is 0\n"; + continue; + } + + // TODO: Herbie properties + std::string expr = + valueToNodeMap[output]->toFullExpression(valueToNodeMap); + SmallSet args; + getUniqueArgs(expr, args); + + std::string properties = ":herbie-conversions ([binary64 binary32])"; + if (valueToNodeMap[output]->dtype == "f32") { + properties += " :precision binary32"; + } else if (valueToNodeMap[output]->dtype == "f64") { + properties += " :precision binary64"; + } else { + llvm_unreachable("Unexpected dtype"); + } + + if (!FPOptLogPath.empty()) { + std::string precondition = + getPrecondition(args, valueToNodeMap, symbolToValueMap); + properties += " :pre " + precondition; + } + + ApplicableOutput AO(component, output, expr, grad, executions, TTI); + properties += " :name \"" + std::to_string(outputCounter++) + "\""; + + std::string argStr; + for (const auto &arg : args) { + if (!argStr.empty()) + argStr += " "; + argStr += arg; + } + + std::string herbieInput = + "(FPCore (" + argStr + ") " + properties + " " + expr + ")"; + if (EnzymePrintHerbie) + llvm::errs() << "Herbie input:\n" << herbieInput << "\n"; + + herbieInputs.push_back(herbieInput); + newAOs.push_back(AO); + } + + if (!herbieInputs.empty()) { + if (!improveViaHerbie(herbieInputs, newAOs, F.getParent(), TTI, + valueToNodeMap, symbolToValueMap, + componentCounter)) { + if (EnzymePrintHerbie) + llvm::errs() << "Failed to optimize expressions using Herbie!\n"; + } + + AOs.insert(AOs.end(), newAOs.begin(), newAOs.end()); + } + } + + if (FPOptEnablePT) { + // Sort `component.operations` by the gradient and construct + // `PrecisionChange`s. + ApplicableFPCC ACC(component, TTI); + auto *o0 = component.outputs[0]; + ACC.executions = valueToNodeMap[o0]->executions; + + const SmallVector precTypes{ + PrecisionChangeType::FP32, + PrecisionChangeType::FP64, + }; + + // TODO: since we are only doing FP64 -> FP32, we can skip more expensive + // operations for now. + static const std::unordered_set Funcs = { + "sin", "cos", "tan", "exp", "log", "sqrt", "expm1", + "log1p", "cbrt", "pow", "fabs", "hypot", "fma"}; + + SmallVector operations; + for (auto *I : component.operations) { + assert(isa(valueToNodeMap[I].get()) && + "Corrupted FPNode for original instructions"); + auto node = cast(valueToNodeMap[I].get()); + if (Funcs.count(node->op) != 0) { + operations.push_back(node); + } + } + + // Sort operations by the gradient + llvm::sort(operations, [](const auto &a, const auto &b) { + return std::fabs(a->grad * a->geometricAvg) < + std::fabs(b->grad * b->geometricAvg); + }); + + // Create PrecisionChanges for 0-10%, 0-20%, ..., up to 0-100% + for (int percent = 10; percent <= 100; percent += 10) { + size_t numToChange = operations.size() * percent / 100; + + SetVector opsToChange(operations.begin(), + operations.begin() + numToChange); + + if (EnzymePrintFPOpt && !opsToChange.empty()) { + llvm::errs() << "Created PrecisionChange for " << percent + << "% of Funcs (" << numToChange << ")\n"; + llvm::errs() << "Subset sensitivity score range: [" + << std::fabs(opsToChange.front()->grad * + opsToChange.front()->geometricAvg) + << ", " + << std::fabs(opsToChange.back()->grad * + opsToChange.back()->geometricAvg) + << "]\n"; + } + + for (auto prec : precTypes) { + std::string precStr = getPrecisionChangeTypeString(prec).str(); + std::string desc = + "Funcs 0% -- " + std::to_string(percent) + "% -> " + precStr; + + PrecisionChange change( + opsToChange, + getPrecisionChangeType(component.outputs[0]->getType()), prec); + + SmallVector changes{std::move(change)}; + PTCandidate candidate{std::move(changes), desc}; + candidate.CompCost = getCompCost(component, TTI, candidate); + ACC.candidates.push_back(std::move(candidate)); + } + } + + // Create candidates by considering all operations without filtering + SmallVector allOperations; + for (auto *I : component.operations) { + assert(isa(valueToNodeMap[I].get()) && + "Corrupted FPNode for original instructions"); + auto node = cast(valueToNodeMap[I].get()); + allOperations.push_back(node); + } + + // Sort all operations by the gradient + llvm::sort(allOperations, [](const auto &a, const auto &b) { + return std::fabs(a->grad * a->geometricAvg) < + std::fabs(b->grad * b->geometricAvg); + }); + + // Create PrecisionChanges for 0-10%, 0-20%, ..., up to 0-100% + for (int percent = 10; percent <= 100; percent += 10) { + size_t numToChange = allOperations.size() * percent / 100; + + SetVector opsToChange(allOperations.begin(), + allOperations.begin() + numToChange); + + if (EnzymePrintFPOpt && !opsToChange.empty()) { + llvm::errs() << "Created PrecisionChange for " << percent + << "% of all operations (" << numToChange << ")\n"; + llvm::errs() << "Subset sensitivity score range: [" + << std::fabs(opsToChange.front()->grad * + opsToChange.front()->geometricAvg) + << ", " + << std::fabs(opsToChange.back()->grad * + opsToChange.back()->geometricAvg) + << "]\n"; + } + + for (auto prec : precTypes) { + std::string precStr = getPrecisionChangeTypeString(prec).str(); + std::string desc = + "All 0% -- " + std::to_string(percent) + "% -> " + precStr; + + PrecisionChange change( + opsToChange, + getPrecisionChangeType(component.outputs[0]->getType()), prec); + + SmallVector changes{std::move(change)}; + PTCandidate candidate{std::move(changes), desc}; + candidate.CompCost = getCompCost(component, TTI, candidate); + ACC.candidates.push_back(std::move(candidate)); + } + } + + setUnifiedAccuracyCost(ACC, valueToNodeMap, symbolToValueMap); + + ACCs.push_back(std::move(ACC)); + } + llvm::errs() << "##### Finished synthesizing candidates for " + << ++componentCounter << " of " << connected_components.size() + << " connected components! #####\n"; + } + + // Perform rewrites + if (EnzymePrintFPOpt) { + if (FPOptEnableHerbie) { + for (auto &AO : AOs) { + // TODO: Solver + // Available Parameters: + // 1. gradients at the output llvm::Value + // 2. costs of the potential rewrites from Herbie (lower is preferred) + // 3. percentage accuracies of potential rewrites (higher is better) + // 4*. TTI costs of potential rewrites (TODO: need to consider branches) + // 5*. Custom error estimates of potential rewrites (TODO) + + llvm::errs() << "\n################################\n"; + llvm::errs() << "Initial AccuracyCost: " << AO.initialAccCost << "\n"; + llvm::errs() << "Initial ComputationCost: " << AO.initialCompCost + << "\n"; + llvm::errs() << "Initial HerbieCost: " << AO.initialHerbieCost << "\n"; + llvm::errs() << "Initial HerbieAccuracy: " << AO.initialHerbieAccuracy + << "\n"; + llvm::errs() << "Initial Expression: " << AO.expr << "\n"; + llvm::errs() << "Grad: " << AO.grad << "\n\n"; + llvm::errs() << "Candidates:\n"; + llvm::errs() << "Δ AccCost\t\tΔ " + "CompCost\t\tHerbieCost\t\tAccuracy\t\tExpression\n"; + llvm::errs() << "--------------------------------\n"; + for (size_t i = 0; i < AO.candidates.size(); ++i) { + auto &candidate = AO.candidates[i]; + llvm::errs() << AO.getAccCostDelta(i) << "\t\t" + << AO.getCompCostDelta(i) << "\t\t" + << candidate.herbieCost << "\t\t" + << candidate.herbieAccuracy << "\t\t" << candidate.expr + << "\n"; + } + llvm::errs() << "################################\n\n"; + } + } + if (FPOptEnablePT) { + for (auto &ACC : ACCs) { + llvm::errs() << "\n################################\n"; + llvm::errs() << "Initial AccuracyCost: " << ACC.initialAccCost << "\n"; + llvm::errs() << "Initial ComputationCost: " << ACC.initialCompCost + << "\n"; + llvm::errs() << "Candidates:\n"; + llvm::errs() << "Δ AccCost\t\tΔ CompCost\t\tDescription\n" + << "---------------------------\n"; + for (size_t i = 0; i < ACC.candidates.size(); ++i) { + auto &candidate = ACC.candidates[i]; + llvm::errs() << ACC.getAccCostDelta(i) << "\t\t" + << ACC.getCompCostDelta(i) << "\t\t" << candidate.desc + << "\n"; + } + llvm::errs() << "################################\n\n"; + } + } + } + + if (!FPOptEnableSolver) { + if (FPOptEnableHerbie) { + for (auto &AO : AOs) { + AO.apply(0, valueToNodeMap, symbolToValueMap); + changed = true; + } + } + + // TODO: just for testing + if (FPOptEnablePT) { + for (auto &ACC : ACCs) { + ACC.apply(0); + changed = true; + } + } + } else { + // TODO: Solver + if (FPOptLogPath.empty()) { + llvm::errs() << "FPOpt: Solver enabled but no log file is provided\n"; + return false; + } + if (FPOptSolverType == "greedy") { + changed = accuracyGreedySolver(AOs, valueToNodeMap, symbolToValueMap); + } else if (FPOptSolverType == "dp") { + changed = accuracyDPSolver(AOs, ACCs, valueToNodeMap, symbolToValueMap); + } else { + llvm::errs() << "FPOpt: Unknown solver type: " << FPOptSolverType << "\n"; + return false; + } + } + + llvm::errs() << "FPOpt: Finished optimizing " << F.getName() << "\n"; + + // Cleanup + if (changed) { + for (auto &component : connected_components) { + if (component.outputs_rewritten != component.outputs.size()) { + if (EnzymePrintFPOpt) + llvm::errs() << "Skip erasing a connect component: only rewrote " + << component.outputs_rewritten << " of " + << component.outputs.size() << " outputs\n"; + continue; // Intermediate operations cannot be erased safely + } + for (auto *I : component.operations) { + if (EnzymePrintFPOpt) + llvm::errs() << "Erasing: " << *I << "\n"; + if (!I->use_empty()) { + I->replaceAllUsesWith(UndefValue::get(I->getType())); + } + I->eraseFromParent(); + } + } + + llvm::errs() << "FPOpt: Finished cleaning up " << F.getName() << "\n"; + } + + if (EnzymePrintFPOpt) { + llvm::errs() << "FPOpt: Finished Optimization\n"; + // F.print(llvm::errs()); + } + + return changed; +} + +namespace { + +class FPOpt final : public FunctionPass { +public: + static char ID; + FPOpt() : FunctionPass(ID) {} + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + FunctionPass::getAnalysisUsage(AU); + } + bool runOnFunction(Function &F) override { + auto &TTI = getAnalysis().getTTI(F); + return fpOptimize(F, TTI); + } +}; + +} // namespace + +char FPOpt::ID = 0; + +static RegisterPass X("fp-opt", + "Run Enzyme/Herbie Floating point optimizations"); + +FunctionPass *createFPOptPass() { return new FPOpt(); } + +#include +#include + +#include "llvm/IR/LegacyPassManager.h" + +extern "C" void AddFPOptPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createFPOptPass()); +} + +FPOptNewPM::Result FPOptNewPM::run(llvm::Module &M, + llvm::ModuleAnalysisManager &MAM) { + bool changed = false; + FunctionAnalysisManager &FAM = + MAM.getResult(M).getManager(); + for (auto &F : M) { + if (!F.isDeclaration()) { + const auto &TTI = FAM.getResult(F); + changed |= fpOptimize(F, TTI); + } + } + + return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} +llvm::AnalysisKey FPOptNewPM::Key; diff --git a/enzyme/Enzyme/Herbie.h b/enzyme/Enzyme/Herbie.h new file mode 100644 index 000000000000..8f6d3a72cd6c --- /dev/null +++ b/enzyme/Enzyme/Herbie.h @@ -0,0 +1,36 @@ +#ifndef ENZYME_HERBIE_H +#define ENZYME_HERBIE_H + +#include + +#include "llvm/IR/PassManager.h" +#include "llvm/Passes/PassPlugin.h" + +#include "llvm/Support/CommandLine.h" + +namespace llvm { +class FunctionPass; +} + +extern "C" { +extern llvm::cl::opt EnzymeEnableFPOpt; +} + +llvm::FunctionPass *createFPOptPass(); + +class FPOptNewPM final : public llvm::AnalysisInfoMixin { + friend struct llvm::AnalysisInfoMixin; + +private: + static llvm::AnalysisKey Key; + +public: + using Result = llvm::PreservedAnalyses; + FPOptNewPM() {} + + Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM); + + static bool isRequired() { return true; } +}; + +#endif // ENZYME_HERBIE_H diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 2e78f7c083cd..7ade6b516c17 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -3618,6 +3618,44 @@ llvm::Value *get1ULP(llvm::IRBuilder<> &builder, llvm::Value *res) { return absres; } +llvm::Function *getLogFunction(llvm::Module *M, llvm::StringRef demangledName) { + if (demangledName != "enzymeLogError" && demangledName != "enzymeLogGrad" && + demangledName != "enzymeLogValue") { + llvm_unreachable("Unknown log function"); + } + for (llvm::Function &F : *M) { + if (startsWith(llvm::demangle(F.getName().str()), demangledName)) { + return &F; + } + } + return nullptr; // Return nullptr if no matching function is found +} + +std::string getLogIdentifier(llvm::Instruction &I) { + assert(I.hasMetadata("enzyme_preprocess_origin")); + auto *CMD = cast( + I.getMetadata("enzyme_preprocess_origin")->getOperand(0)); + uintptr_t ptrValue = cast(CMD->getValue())->getZExtValue(); + auto *OI = reinterpret_cast(ptrValue); + + llvm::StringRef functionName = OI->getFunction()->getName(); + int blockIdx = -1, instIdx = -1; + auto blockIt = llvm::find_if(*OI->getFunction(), [&](const auto &block) { + return &block == OI->getParent(); + }); + if (blockIt != OI->getFunction()->end()) { + blockIdx = std::distance(OI->getFunction()->begin(), blockIt); + } + auto instIt = llvm::find_if(*OI->getParent(), + [&](const auto &curr) { return &curr == OI; }); + if (instIt != OI->getParent()->end()) { + instIdx = std::distance(OI->getParent()->begin(), instIt); + } + + return functionName.str() + ":" + std::to_string(blockIdx) + ":" + + std::to_string(instIdx); +} + llvm::Value *EmitNoDerivativeError(const std::string &message, llvm::Instruction &inst, GradientUtils *gutils, diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index f825af47607c..2fcf98641082 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -55,6 +55,8 @@ #include "llvm/IR/IntrinsicsAMDGPU.h" #include "llvm/IR/IntrinsicsNVPTX.h" +#include "llvm/Demangle/Demangle.h" + #include #include @@ -561,6 +563,8 @@ static inline DIFFE_TYPE whatType(llvm::Type *arg, DerivativeMode mode, } llvm::Value *get1ULP(llvm::IRBuilder<> &builder, llvm::Value *res); +llvm::Function *getLogFunction(llvm::Module *M, llvm::StringRef demangledName); +std::string getLogIdentifier(llvm::Instruction &I); static inline DIFFE_TYPE whatType(llvm::Type *arg, DerivativeMode mode) { std::set seen; diff --git a/enzyme/benchmarks/ReverseMode/ode/Makefile.make b/enzyme/benchmarks/ReverseMode/ode/Makefile.make index 64127c682fc5..42038d29e9ca 100644 --- a/enzyme/benchmarks/ReverseMode/ode/Makefile.make +++ b/enzyme/benchmarks/ReverseMode/ode/Makefile.make @@ -1,28 +1,34 @@ -# RUN: cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" make -B ode-unopt.ll ode-opt.ll ode.o results.txt VERBOSE=1 -f %s +# RUN: cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" BENCH="%bench" LOAD="%loadEnzyme" LOADCLANG="%loadClangEnzyme" make -B tuned.exe VERBOSE=1 -f %s .PHONY: clean clean: - rm -f *.ll *.o results.txt + rm -f *.exe *.o results.txt ode.txt -ode-adept-unopt.ll: ode-adept.cpp - clang++ $(BENCH) $^ -O2 -fno-use-cxa-atexit -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm - #clang++ $(BENCH) $^ -O1 -Xclang -disable-llvm-passes -fno-use-cxa-atexit -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm +logger.exe: ode.cpp fp-logger.cpp + clang++ $(BENCH) $(LOADCLANG) ode.cpp fp-logger.cpp -O3 -mllvm -enzyme-inline -ffast-math -fno-finite-math-only -o $@ -DLOGGING -ode-unopt.ll: ode.cpp - clang++ $(BENCH) $^ -O2 -fno-use-cxa-atexit -fno-exceptions -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm - #clang++ $(BENCH) $^ -O1 -Xclang -disable-llvm-passes -fno-use-cxa-atexit -fno-exceptions -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm -Xclang -new-struct-path-tbaa - -ode-raw.ll: ode-adept-unopt.ll ode-unopt.ll - opt ode-unopt.ll $(LOAD) -enzyme -o ode-enzyme.ll -S - llvm-link ode-adept-unopt.ll ode-enzyme.ll -o $@ - -%-opt.ll: %-raw.ll - opt $^ -o $@ -S - #opt $^ -O2 -o $@ -S - -ode.o: ode-opt.ll - clang++ -O2 $^ -o $@ $(BENCHLINK) -lm - -results.txt: ode.o +ode.txt: logger.exe ./$^ 1000000 | tee $@ + +tuned.exe: ode.cpp ode.txt + clang++ $(BENCH) $(LOADCLANG) ode.cpp -O3 -ffast-math -fno-finite-math-only -o $@ \ + -mllvm -enzyme-inline \ + -mllvm -enzyme-enable-fpopt \ + -mllvm -fpopt-log-path=ode.txt \ + -mllvm -fpopt-enable-solver \ + -mllvm -fpopt-enable-pt \ + -mllvm -fpopt-target-func-regex=foobar \ + -mllvm -fpopt-comp-cost-budget=0 \ + -mllvm -fpopt-num-samples=1024 \ + -mllvm -fpopt-cost-dom-thres=0.0 \ + -mllvm -fpopt-acc-dom-thres=0.0 \ + -mllvm -enzyme-print-fpopt \ + -mllvm -fpopt-show-table \ + -mllvm -fpopt-cache-path=cache \ + -mllvm -herbie-timeout=1000 \ + -mllvm -herbie-num-threads=12 \ + -mllvm --fpopt-cost-model-path=cm.csv + +results.txt: tuned.exe + ./$^ 1000000 | tee $@ \ No newline at end of file diff --git a/enzyme/benchmarks/ReverseMode/ode/cm.csv b/enzyme/benchmarks/ReverseMode/ode/cm.csv new file mode 100644 index 000000000000..2122011dfbba --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ode/cm.csv @@ -0,0 +1,40 @@ +fneg,float,22 +fadd,float,22 +fsub,float,22 +fmul,float,22 +fdiv,float,22 +fcmp,float,16 +fpext_float_to_double,float,22 +sin,float,100 +cos,float,103 +tan,float,246 +exp,float,142 +log,float,83 +sqrt,float,33 +expm1,float,106 +log1p,float,110 +cbrt,float,584 +pow,float,159 +fabs,float,20 +hypot,float,608 +fma,float,75 +fneg,double,19 +fadd,double,19 +fsub,double,19 +fmul,double,19 +fdiv,double,28 +fcmp,double,15 +fptrunc_double_to_float,double,19 +sin,double,829 +cos,double,859 +tan,double,998 +exp,double,192 +log,double,104 +sqrt,double,52 +expm1,double,75 +log1p,double,118 +cbrt,double,244 +pow,double,225 +fabs,double,20 +hypot,double,382 +fma,double,74 \ No newline at end of file diff --git a/enzyme/benchmarks/ReverseMode/ode/fp-logger.cpp b/enzyme/benchmarks/ReverseMode/ode/fp-logger.cpp new file mode 100644 index 000000000000..545e31d2564f --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ode/fp-logger.cpp @@ -0,0 +1,165 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fp-logger.hpp" + +class ValueInfo { +public: + double minRes = std::numeric_limits::max(); + double maxRes = std::numeric_limits::lowest(); + std::vector minOperands; + std::vector maxOperands; + unsigned executions = 0; + double logSum = 0.0; + unsigned logCount = 0; + + void update(double res, const double *operands, unsigned numOperands) { + minRes = std::min(minRes, res); + maxRes = std::max(maxRes, res); + if (minOperands.empty()) { + minOperands.resize(numOperands, std::numeric_limits::max()); + maxOperands.resize(numOperands, std::numeric_limits::lowest()); + } + for (unsigned i = 0; i < numOperands; ++i) { + minOperands[i] = std::min(minOperands[i], operands[i]); + maxOperands[i] = std::max(maxOperands[i], operands[i]); + } + ++executions; + + if (!std::isnan(res)) { + logSum += std::log1p(std::fabs(res)); + ++logCount; + } + } + + double getGeometricAverage() const { + if (logCount == 0) { + return 0.; + } + return std::expm1(logSum / logCount); + } +}; + +class ErrorInfo { +public: + double minErr = std::numeric_limits::max(); + double maxErr = std::numeric_limits::lowest(); + + void update(double err) { + minErr = std::min(minErr, err); + maxErr = std::max(maxErr, err); + } +}; + +class GradInfo { +public: + double logSum = 0.0; + unsigned count = 0; + + void update(double grad) { + if (!std::isnan(grad)) { + logSum += std::log1p(std::fabs(grad)); + ++count; + } + } + + double getGeometricAverage() const { + if (count == 0) { + return 0.; + } + return std::expm1(logSum / count); + } +}; + +class Logger { +private: + std::unordered_map valueInfo; + std::unordered_map errorInfo; + std::unordered_map gradInfo; + +public: + void updateValue(const std::string &id, double res, unsigned numOperands, + const double *operands) { + auto &info = valueInfo.emplace(id, ValueInfo()).first->second; + info.update(res, operands, numOperands); + } + + void updateError(const std::string &id, double err) { + auto &info = errorInfo.emplace(id, ErrorInfo()).first->second; + info.update(err); + } + + void updateGrad(const std::string &id, double grad) { + auto &info = gradInfo.emplace(id, GradInfo()).first->second; + info.update(grad); + } + + void print() const { + std::cout << std::scientific + << std::setprecision(std::numeric_limits::max_digits10); + + for (const auto &pair : valueInfo) { + const auto &id = pair.first; + const auto &info = pair.second; + std::cout << "Value:" << id << "\n"; + std::cout << "\tMinRes = " << info.minRes << "\n"; + std::cout << "\tMaxRes = " << info.maxRes << "\n"; + std::cout << "\tExecutions = " << info.executions << "\n"; + std::cout << "\tGeometric Average = " << info.getGeometricAverage() + << "\n"; + for (unsigned i = 0; i < info.minOperands.size(); ++i) { + std::cout << "\tOperand[" << i << "] = [" << info.minOperands[i] << ", " + << info.maxOperands[i] << "]\n"; + } + } + + for (const auto &pair : errorInfo) { + const auto &id = pair.first; + const auto &info = pair.second; + std::cout << "Error:" << id << "\n"; + std::cout << "\tMinErr = " << info.minErr << "\n"; + std::cout << "\tMaxErr = " << info.maxErr << "\n"; + } + + for (const auto &pair : gradInfo) { + const auto &id = pair.first; + const auto &info = pair.second; + std::cout << "Grad:" << id << "\n"; + std::cout << "\tGrad = " << info.getGeometricAverage() << "\n"; + } + } +}; + +Logger *logger = nullptr; + +void initializeLogger() { logger = new Logger(); } + +void destroyLogger() { + delete logger; + logger = nullptr; +} + +void printLogger() { logger->print(); } + +void enzymeLogError(const char *id, double err) { + assert(logger && "Logger is not initialized"); + logger->updateError(id, err); +} + +void enzymeLogGrad(const char *id, double grad) { + assert(logger && "Logger is not initialized"); + logger->updateGrad(id, grad); +} + +void enzymeLogValue(const char *id, double res, unsigned numOperands, + double *operands) { + assert(logger && "Logger is not initialized"); + logger->updateValue(id, res, numOperands, operands); +} \ No newline at end of file diff --git a/enzyme/benchmarks/ReverseMode/ode/fp-logger.hpp b/enzyme/benchmarks/ReverseMode/ode/fp-logger.hpp new file mode 100644 index 000000000000..657aa947ae36 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ode/fp-logger.hpp @@ -0,0 +1,8 @@ +void initializeLogger(); +void destroyLogger(); +void printLogger(); + +void enzymeLogError(const char *id, double err); +void enzymeLogGrad(const char *id, double grad); +void enzymeLogValue(const char *id, double res, unsigned numOperands, + double *operands); diff --git a/enzyme/benchmarks/ReverseMode/ode/ode-adept.cpp b/enzyme/benchmarks/ReverseMode/ode/ode-adept.cpp deleted file mode 100644 index c0075e1bf159..000000000000 --- a/enzyme/benchmarks/ReverseMode/ode/ode-adept.cpp +++ /dev/null @@ -1,101 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -using adept::adouble; - -static float tdiff(struct timeval *start, struct timeval *end) { - return (end->tv_sec-start->tv_sec) + 1e-6*(end->tv_usec-start->tv_usec); -} - -#define BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS -#define BOOST_NO_EXCEPTIONS -#include -#include - -#include - -#include - -using namespace std; -using namespace boost::numeric::odeint; - -#include - -double foobar(double t, uint64_t iters); - -typedef boost::array< adouble , 1 > astate_type; - -void alorenz( const astate_type &x , astate_type &dxdt , adouble t ) -{ - const double a = 1.2; - dxdt[0] = -a * x[0]; -} - -adouble afoobar(adouble t, uint64_t iters) { - astate_type x = { 1.0 }; // initial conditions - - adouble start = 0.0; - adouble step = t/adouble(iters); - typedef controlled_runge_kutta< runge_kutta_dopri5< astate_type , typename astate_type::value_type , astate_type , adouble > > stepper_type; - //typedef euler< astate_type , typename astate_type::value_type , astate_type , adouble > stepper_type; - integrate_const( stepper_type(), alorenz , x , start , t, step ); - - //x[0] += -1.2 * step * x[0]; - - //printf("final result t=%f x(t)=%f, exp(-1.2* t)=%f\n", t, x[0], exp(- 1.2 * t)); - return x[0]; -} - -double afoobar_and_gradient(double xin, double& xgrad, uint64_t iters) { - adept::Stack stack; - adouble x = xin; - stack.new_recording(); - adouble y = afoobar(x, iters); - y.set_gradient(1.0); - stack.compute_adjoint(); - xgrad = x.get_gradient(); - return y.value(); -} - -void adept_sincos(double inp, uint64_t iters) { - { - struct timeval start, end; - gettimeofday(&start, NULL); - - double res = foobar(inp, iters); - - gettimeofday(&end, NULL); - printf("Adept real %0.6f res=%f\n", tdiff(&start, &end), res); - } - - { - struct timeval start, end; - gettimeofday(&start, NULL); - - adept::Stack stack; - // stack.new_recording(); - adouble resa = afoobar(inp, iters); - double res = resa.value(); - - gettimeofday(&end, NULL); - printf("Adept forward %0.6f res=%f\n", tdiff(&start, &end), res); - } - - { - struct timeval start, end; - gettimeofday(&start, NULL); - - double res2 = 0; - afoobar_and_gradient(inp, res2, iters); - - gettimeofday(&end, NULL); - printf("Adept combined %0.6f res'=%f\n", tdiff(&start, &end), res2); - } -} diff --git a/enzyme/benchmarks/ReverseMode/ode/ode.cpp b/enzyme/benchmarks/ReverseMode/ode/ode.cpp index 2473bf355cf7..423f3e42425b 100644 --- a/enzyme/benchmarks/ReverseMode/ode/ode.cpp +++ b/enzyme/benchmarks/ReverseMode/ode/ode.cpp @@ -1,34 +1,42 @@ +#include #include #include #include -#include -#include -#include -#include #include +#include -template -Return __enzyme_autodiff(T...); +#ifdef LOGGING +#include "fp-logger.hpp" + +void thisIsNeverCalledAndJustForTheLinker() { + enzymeLogError("", 0.0); + enzymeLogGrad("", 0.0); + enzymeLogValue("", 0.0, 2, nullptr); +} + +template Return __enzyme_autodiff(T...); +#endif static float tdiff(struct timeval *start, struct timeval *end) { - return (end->tv_sec-start->tv_sec) + 1e-6*(end->tv_usec-start->tv_usec); + return (end->tv_sec - start->tv_sec) + 1e-6 * (end->tv_usec - start->tv_usec); } #define BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS #define BOOST_NO_EXCEPTIONS -#include #include +#include #include -#include #include -void boost::throw_exception(std::exception const & e) { - //do nothing +#include +void boost::throw_exception(std::exception const &e) { + //do nothing } #if BOOST_VERSION >= 107300 -void boost::throw_exception(std::exception const & e, boost::source_location const & loc) { - //do nothing +void boost::throw_exception(std::exception const &e, + boost::source_location const &loc) { + //do nothing } #endif @@ -37,70 +45,66 @@ using namespace boost::numeric::odeint; #include -typedef boost::array< double , 1 > state_type; +typedef boost::array state_type; -void lorenz( const state_type &x , state_type &dxdt , double t ) -{ - const double a = 1.2; - dxdt[0] = -a * x[0]; +void lorenz(const state_type &x, state_type &dxdt, double t) { + const double a = 1.2; + dxdt[0] = -a * x[0]; } - double foobar(double t, uint64_t iters) { - state_type x = { 1.0 }; // initial conditions + state_type x = {1.0}; // initial conditions - typedef controlled_runge_kutta< runge_kutta_dopri5< state_type , typename state_type::value_type , state_type , double > > stepper_type; - //typedef euler< state_type , typename state_type::value_type , state_type , double > stepper_type; - integrate_const( stepper_type(), lorenz , x , 0.0 , t, t/iters ); + typedef controlled_runge_kutta> + stepper_type; + //typedef euler< state_type , typename state_type::value_type , state_type , double > stepper_type; + integrate_const(stepper_type(), lorenz, x, 0.0, t, t / iters); - //printf("final result t=%f x(t)=%f, exp(-1.2* t)=%f\n", t, x[0], exp(- 1.2 * t)); - return x[0]; + //printf("final result t=%f x(t)=%f, exp(-1.2* t)=%f\n", t, x[0], exp(- 1.2 * t)); + return x[0]; } -void adept_sincos(double inp, uint64_t iters); - static void enzyme_sincos(double inp, uint64_t iters) { { - struct timeval start, end; - gettimeofday(&start, NULL); + struct timeval start, end; + gettimeofday(&start, NULL); - double res = foobar(inp, iters); + double res = foobar(inp, iters); - gettimeofday(&end, NULL); - printf("Enzyme real %0.6f res=%f\n", tdiff(&start, &end), res); + gettimeofday(&end, NULL); + printf("Enzyme real %0.6f res=%f\n", tdiff(&start, &end), res); } { - struct timeval start, end; - gettimeofday(&start, NULL); - - double res = foobar(inp, iters); + struct timeval start, end; + gettimeofday(&start, NULL); - gettimeofday(&end, NULL); - printf("Enzyme forward %0.6f res=%f\n", tdiff(&start, &end), res); - } - - { - struct timeval start, end; - gettimeofday(&start, NULL); - double res2; - - res2 = __enzyme_autodiff(foobar, inp, iters); - - gettimeofday(&end, NULL); - printf("Enzyme combined %0.6f res'=%f\n", tdiff(&start, &end), res2); +#ifdef LOGGING + double res = __enzyme_autodiff(foobar, inp, iters); +#else + double res = foobar(inp, iters); +#endif } } -int main(int argc, char** argv) { +int main(int argc, char **argv) { +#ifdef LOGGING + initializeLogger(); +#endif - int max_iters = atoi(argv[1]) ; + int max_iters = atoi(argv[1]); double inp = 2.1; - for(int iters=max_iters/20; iters<=max_iters; iters+=max_iters/20) { + for (int iters = max_iters / 20; iters <= max_iters; + iters += max_iters / 20) { printf("iters=%d\n", iters); - adept_sincos(inp, iters); enzyme_sincos(inp, iters); } + +#ifdef LOGGING + printLogger(); + destroyLogger(); +#endif } diff --git a/enzyme/benchmarks/lit.site.cfg.py.in b/enzyme/benchmarks/lit.site.cfg.py.in index adfac5c63608..9fa7c32fea9a 100644 --- a/enzyme/benchmarks/lit.site.cfg.py.in +++ b/enzyme/benchmarks/lit.site.cfg.py.in @@ -93,11 +93,18 @@ config.substitutions.append(('%loadBC', '' )) config.substitutions.append(('%BClibdir', '@ENZYME_SOURCE_DIR@/bclib/')) -newPM = (' -fpass-plugin=@ENZYME_BINARY_DIR@/Enzyme/ClangEnzyme-' + config.llvm_ver + config.llvm_shlib_ext +ldPM = (((" -fno-experimental-new-pass-manager" if int(config.llvm_ver) < 14 else "-flegacy-pass-manager") if int(config.llvm_ver) >= 13 else "") + ' -Xclang -load -Xclang @ENZYME_BINARY_DIR@/Enzyme/ClangEnzyme-' + config.llvm_ver + config.llvm_shlib_ext) +newPM = ((" -fexperimental-new-pass-manager" if int(config.llvm_ver) < 13 else "") + + ' -fpass-plugin=@ENZYME_BINARY_DIR@/Enzyme/ClangEnzyme-' + config.llvm_ver + config.llvm_shlib_ext + + ' -Xclang -load -Xclang @ENZYME_BINARY_DIR@/Enzyme/ClangEnzyme-' + config.llvm_ver + config.llvm_shlib_ext) + +if len("@ENZYME_BINARY_DIR@") == 0: + oldPM = ((" -fno-experimental-new-pass-manager" if int(config.llvm_ver) < 14 else "-flegacy-pass-manager") if int(config.llvm_ver) >= 13 else "") + newPM = (" -fexperimental-new-pass-manager" if int(config.llvm_ver) < 13 else "") +config.substitutions.append(('%loadClangEnzyme', oldPM if int(config.llvm_ver) < 15 else newPM)) config.substitutions.append(('%newLoadClangEnzyme', newPM)) -config.substitutions.append(('%LoadClangEnzyme', newPM)) # Let the main config do the real work. lit_config.load_config(config, "@ENZYME_SOURCE_DIR@/benchmarks/lit.cfg.py") diff --git a/enzyme/test/Enzyme/CMakeLists.txt b/enzyme/test/Enzyme/CMakeLists.txt index 7998da45a735..80a107db5a30 100644 --- a/enzyme/test/Enzyme/CMakeLists.txt +++ b/enzyme/test/Enzyme/CMakeLists.txt @@ -8,6 +8,7 @@ add_subdirectory(ForwardModeSplit) add_subdirectory(ForwardModeVector) add_subdirectory(BatchMode) add_subdirectory(ProbProg) +add_subdirectory(FPOpt) add_subdirectory(JLSimplify) # Run regression and unit tests diff --git a/enzyme/test/Enzyme/FPOpt/CMakeLists.txt b/enzyme/test/Enzyme/FPOpt/CMakeLists.txt new file mode 100644 index 000000000000..2e077bf27861 --- /dev/null +++ b/enzyme/test/Enzyme/FPOpt/CMakeLists.txt @@ -0,0 +1,12 @@ +# Run regression and unit tests +add_lit_testsuite(check-enzyme-fpopt "Running enzyme floating-point optimization regression tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} + ARGS -v +) + +set_target_properties(check-enzyme-fpopt PROPERTIES FOLDER "Tests") + +#add_lit_testsuites(ENZYME ${CMAKE_CURRENT_SOURCE_DIR} + # DEPENDS ${ENZYME_TEST_DEPS} +#) diff --git a/enzyme/test/Enzyme/FPOpt/add.ll b/enzyme/test/Enzyme/FPOpt/add.ll new file mode 100644 index 000000000000..1a02c75bbdb6 --- /dev/null +++ b/enzyme/test/Enzyme/FPOpt/add.ll @@ -0,0 +1,14 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -fp-opt -enzyme-print-fpopt -enzyme-print-herbie -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="fp-opt" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = fadd fast double %x, %y + ret double %0 +} + +; CHECK: define double @tester(double %x, double %y) +; CHECK: entry: +; CHECK-NEXT: %[[i0:.+]] = fadd fast double %x, %y +; CHECK-NEXT: ret double %[[i0]] diff --git a/enzyme/test/Enzyme/FPOpt/cancel1.ll b/enzyme/test/Enzyme/FPOpt/cancel1.ll new file mode 100644 index 000000000000..ec0550ec1c14 --- /dev/null +++ b/enzyme/test/Enzyme/FPOpt/cancel1.ll @@ -0,0 +1,14 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -fp-opt -enzyme-print-fpopt -enzyme-print-herbie -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="fp-opt" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = fadd fast double %x, %y + %1 = fsub fast double %0, %x + ret double %1 +} + +; CHECK: define double @tester(double %x, double %y) +; CHECK: entry: +; CHECK-NEXT: ret double %y diff --git a/enzyme/test/Enzyme/FPOpt/if.ll b/enzyme/test/Enzyme/FPOpt/if.ll new file mode 100644 index 000000000000..dfe19636abbd --- /dev/null +++ b/enzyme/test/Enzyme/FPOpt/if.ll @@ -0,0 +1,67 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -fp-opt -enzyme-print-fpopt -enzyme-print-herbie -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="fp-opt" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %a, double %b, double %c) { +entry: + %0 = fmul double %a, %c + %1 = fmul double 4.000000e+00, %0 + %2 = fmul double %b, %b + %3 = fsub double %2, %1 + %4 = call double @llvm.sqrt.f64(double %3) + %5 = fsub double 0.000000e+00, %b + %6 = fsub double %5, %4 + %7 = fmul double 2.000000e+00, %a + %8 = fdiv double %6, %7 + ret double %8 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.sqrt.f64(double) + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.fmuladd.f64(double, double, double) + +; CHECK: define double @tester(double %a, double %b, double %c) +; CHECK: entry: +; CHECK-NEXT: %[[i0:.+]] = fcmp fast ole double %b, -6.800000e+00 +; CHECK-NEXT: br i1 %[[i0]], label %[[i1:.+]], label %[[i2:.+]] + +; CHECK: [[i1]]: +; CHECK-NEXT: %[[i3:.+]] = fdiv fast double %c, %b +; CHECK-NEXT: %[[i4:.+]] = fneg fast double %c +; CHECK-NEXT: %[[i5:.+]] = fdiv fast double %b, %[[i4]] +; CHECK-NEXT: %[[i6:.+]] = fdiv fast double %[[i3]], %[[i5]] +; CHECK-NEXT: %[[i7:.+]] = fmul fast double %a, %[[i6]] +; CHECK-NEXT: %[[i8:.+]] = fsub fast double %[[i7]], %c +; CHECK-NEXT: %[[i9:.+]] = fdiv fast double %[[i8]], %b +; CHECK-NEXT: br label %[[i10:.+]] + +; CHECK: [[i2]]: +; CHECK-NEXT: %[[i11:.+]] = fcmp fast ole double %b, 0x52682111B1052222 +; CHECK-NEXT: br i1 %[[i11]], label %[[i12:.+]], label %[[i13:.+]] + +; CHECK: [[i12]]: +; CHECK-NEXT: %[[i14:.+]] = fmul fast double %c, -4.000000e+00 +; CHECK-NEXT: %[[i15:.+]] = fmul fast double %b, %b +; CHECK-NEXT: %[[i16:.+]] = call fast double @llvm.fmuladd.f64(double %a, double %[[i14]], double %[[i15]]) +; CHECK-NEXT: %[[i17:.+]] = call fast double @llvm.sqrt.f64(double %[[i16]]) +; CHECK-NEXT: %[[i18:.+]] = fadd fast double %b, %[[i17]] +; CHECK-NEXT: %[[i19:.+]] = fneg fast double %a +; CHECK-NEXT: %[[i20:.+]] = fmul fast double 2.000000e+00, %[[i19]] +; CHECK-NEXT: %[[i21:.+]] = fdiv fast double %[[i18]], %[[i20]] +; CHECK-NEXT: br label %[[i22:.+]] + +; CHECK: [[i13]]: +; CHECK-NEXT: %[[i23:.+]] = fdiv fast double %c, %b +; CHECK-NEXT: %[[i24:.+]] = fdiv fast double %b, %a +; CHECK-NEXT: %[[i25:.+]] = fsub fast double %[[i23]], %[[i24]] +; CHECK-NEXT: br label %[[i26:.+]] + +; CHECK: [[i26]]: +; CHECK-NEXT: %[[i27:.+]] = phi fast double [ %[[i21]], %[[i12]] ], [ %[[i25]], %[[i13]] ] +; CHECK-NEXT: br label %[[i28:.+]] + +; CHECK: [[i28]]: +; CHECK-NEXT: %[[i29:.+]] = phi fast double [ %[[i9]], %[[i1]] ], [ %[[i27]], %[[i22]] ] +; CHECK-NEXT: ret double %[[i29]] diff --git a/enzyme/test/Enzyme/FPOpt/pt1.ll b/enzyme/test/Enzyme/FPOpt/pt1.ll new file mode 100644 index 000000000000..8e8f1f6c3412 --- /dev/null +++ b/enzyme/test/Enzyme/FPOpt/pt1.ll @@ -0,0 +1,26 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -fp-opt -fpopt-enable-herbie=0 -fpopt-enable-pt=1 -enzyme-print-fpopt -enzyme-print-herbie -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="fp-opt" -fpopt-enable-herbie=0 -fpopt-enable-pt=1 -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = call fast double @llvm.cos.f64(double %x) + %1 = fmul fast double %0, %0 + %2 = fsub fast double 1.000000e+00, %1 + ret double %2 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.cos.f64(double) + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.sin.f64(double) + +; CHECK: define double @tester(double %x) +; CHECK: entry: +; CHECK-NEXT: %[[i0:.+]] = fptrunc double %x to half +; CHECK-NEXT: %[[i1:.+]] = call fast half @llvm.cos.f16(half %[[i0]]) +; CHECK-NEXT: %[[i2:.+]] = fmul fast half %[[i1]], %[[i1]] +; CHECK-NEXT: %[[i3:.+]] = fsub fast half 0xH3C00, %[[i2]] +; CHECK-NEXT: %[[i4:.+]] = fpext half %[[i3]] to double +; CHECK-NEXT: ret double %[[i4]] diff --git a/enzyme/test/Enzyme/FPOpt/reassociate1.ll b/enzyme/test/Enzyme/FPOpt/reassociate1.ll new file mode 100644 index 000000000000..514ab4de8442 --- /dev/null +++ b/enzyme/test/Enzyme/FPOpt/reassociate1.ll @@ -0,0 +1,16 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -fp-opt -enzyme-print-fpopt -enzyme-print-herbie -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="fp-opt" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = fadd fast double %x, %y + %1 = fadd fast double %0, %x + ret double %1 +} + +; CHECK: define double @tester(double %x, double %y) +; CHECK: entry: +; CHECK-NEXT: %[[i0:.+]] = call fast double @llvm.fmuladd.f64(double %x, double 2.000000e+00, double %y) +; CHECK-NEXT: ret double %[[i0]] + diff --git a/enzyme/test/Enzyme/FPOpt/trig1.ll b/enzyme/test/Enzyme/FPOpt/trig1.ll new file mode 100644 index 000000000000..37a1ea637c35 --- /dev/null +++ b/enzyme/test/Enzyme/FPOpt/trig1.ll @@ -0,0 +1,23 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -fp-opt -enzyme-print-fpopt -enzyme-print-herbie -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="fp-opt" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = call fast double @llvm.cos.f64(double %x) + %1 = fmul fast double %0, %0 + %2 = fsub fast double 1.000000e+00, %1 + ret double %2 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.cos.f64(double) + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.sin.f64(double) + +; CHECK: define double @tester(double %x) +; CHECK: entry: +; CHECK-NEXT: %[[i0:.+]] = call fast double @llvm.sin.f64(double %x) +; CHECK-NEXT: %[[i1:.+]] = call fast double @llvm.pow.f64(double %[[i0]], double 2.000000e+00) +; CHECK-NEXT: ret double %[[i1]] diff --git a/enzyme/test/Enzyme/FPOpt/trig2.ll b/enzyme/test/Enzyme/FPOpt/trig2.ll new file mode 100644 index 000000000000..0a05e20d5cad --- /dev/null +++ b/enzyme/test/Enzyme/FPOpt/trig2.ll @@ -0,0 +1,89 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -fp-opt -enzyme-print-fpopt -enzyme-print-herbie -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="fp-opt" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define void @tester() { +entry: + %arr = alloca double, i64 5, align 8 + + ; Compute addresses for array elements + %ptr1 = getelementptr inbounds double, double* %arr, i64 0 + %ptr2 = getelementptr inbounds double, double* %arr, i64 1 + %ptr3 = getelementptr inbounds double, double* %arr, i64 2 + %ptr4 = getelementptr inbounds double, double* %arr, i64 3 + %ptr5 = getelementptr inbounds double, double* %arr, i64 4 + + ; Store constants into the array + store double 3.141592653589793, double* %ptr1, align 8 + store double 2.718281828459045, double* %ptr2, align 8 + store double 1.4142135623730951, double* %ptr3, align 8 + store double 0.5772156649015329, double* %ptr4, align 8 + store double 0.6931471805599453, double* %ptr5, align 8 + + ; Load constants from the array + %val1 = load double, double* %ptr1, align 8 + %val2 = load double, double* %ptr2, align 8 + %val3 = load double, double* %ptr3, align 8 + %val4 = load double, double* %ptr4, align 8 + %val5 = load double, double* %ptr5, align 8 + + ; Perform computations + %cos_val = call fast double @llvm.cos.f64(double %val1) + %sin_val = call fast double @llvm.sin.f64(double %val2) + %exp_val = call fast double @llvm.exp.f64(double %val3) + %log_val = call fast double @llvm.log.f64(double %val4) + %sum_val = fadd fast double %cos_val, %sin_val + + ; Store results back to the array + store double %cos_val, double* %ptr1, align 8 + store double %sin_val, double* %ptr2, align 8 + store double %exp_val, double* %ptr3, align 8 + store double %log_val, double* %ptr4, align 8 + store double %sum_val, double* %ptr5, align 8 + + ret void +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.log.f64(double) + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.exp.f64(double) + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.cos.f64(double) + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.sin.f64(double) + +; CHECK: define void @tester() +; CHECK: entry: +; CHECK-NEXT: %arr = alloca double, i64 5, align 8 +; CHECK-NEXT: %ptr1 = getelementptr inbounds double, double* %arr, i64 0 +; CHECK-NEXT: %ptr2 = getelementptr inbounds double, double* %arr, i64 1 +; CHECK-NEXT: %ptr3 = getelementptr inbounds double, double* %arr, i64 2 +; CHECK-NEXT: %ptr4 = getelementptr inbounds double, double* %arr, i64 3 +; CHECK-NEXT: %ptr5 = getelementptr inbounds double, double* %arr, i64 4 +; CHECK-NEXT: store double 0x400921FB54442D18, double* %ptr1, align 8 +; CHECK-NEXT: store double 0x4005BF0A8B145769, double* %ptr2, align 8 +; CHECK-NEXT: store double 0x3FF6A09E667F3BCD, double* %ptr3, align 8 +; CHECK-NEXT: store double 0x3FE2788CFC6FB619, double* %ptr4, align 8 +; CHECK-NEXT: store double 0x3FE62E42FEFA39EF, double* %ptr5, align 8 +; CHECK-NEXT: %val1 = load double, double* %ptr1, align 8 +; CHECK-NEXT: %val2 = load double, double* %ptr2, align 8 +; CHECK-NEXT: %val3 = load double, double* %ptr3, align 8 +; CHECK-NEXT: %val4 = load double, double* %ptr4, align 8 +; CHECK-NEXT: %val5 = load double, double* %ptr5, align 8 +; CHECK-NEXT: %[[i0:.+]] = call fast double @llvm.cos.f64(double %val1) +; CHECK-NEXT: %[[i1:.+]] = call fast double @llvm.sin.f64(double %val2) +; CHECK-NEXT: %exp_val = call fast double @llvm.exp.f64(double %val3) +; CHECK-NEXT: %log_val = call fast double @llvm.log.f64(double %val4) +; CHECK-NEXT: %[[i2:.+]] = call fast double @llvm.cos.f64(double %val1) +; CHECK-NEXT: %[[i3:.+]] = call fast double @llvm.sin.f64(double %val2) +; CHECK-NEXT: %[[i4:.+]] = fadd fast double %[[i2]], %[[i3]] +; CHECK-NEXT: store double %[[i0]], double* %ptr1, align 8 +; CHECK-NEXT: store double %[[i1]], double* %ptr2, align 8 +; CHECK-NEXT: store double %exp_val, double* %ptr3, align 8 +; CHECK-NEXT: store double %log_val, double* %ptr4, align 8 +; CHECK-NEXT: store double %[[i4]], double* %ptr5, align 8 +; CHECK-NEXT: ret void \ No newline at end of file diff --git a/enzyme/test/Enzyme/FPOpt/trig3.ll b/enzyme/test/Enzyme/FPOpt/trig3.ll new file mode 100644 index 000000000000..6ad8445126e1 --- /dev/null +++ b/enzyme/test/Enzyme/FPOpt/trig3.ll @@ -0,0 +1,23 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -fp-opt -enzyme-print-fpopt -enzyme-print-herbie -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="fp-opt" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = call fast float @llvm.cos.f32(float %x) + %1 = fmul fast float %0, %0 + %2 = fsub fast float 1.000000e+00, %1 + ret float %2 +} + +; Function Attrs: nounwind readnone speculatable +declare float @llvm.cos.f32(float) + +; Function Attrs: nounwind readnone speculatable +declare float @llvm.sin.f32(float) + +; CHECK: define float @tester(float %x) +; CHECK: entry: +; CHECK-NEXT: %[[i0:.+]] = call fast float @llvm.sin.f32(float %x) +; CHECK-NEXT: %[[i1:.+]] = call fast float @llvm.pow.f32(float %[[i0]], float 2.000000e+00) +; CHECK-NEXT: ret float %[[i1]] diff --git a/enzyme/test/Enzyme/ForwardError/add.ll b/enzyme/test/Enzyme/ForwardError/add.ll index 265759867f7a..d75b5646df65 100644 --- a/enzyme/test/Enzyme/ForwardError/add.ll +++ b/enzyme/test/Enzyme/ForwardError/add.ll @@ -18,7 +18,7 @@ entry: declare double @__enzyme_error_estimate(double (double, double)*, ...) -; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'") +; CHECK: define internal double @fwderrtester(double %x, double %"x'", double %y, double %"y'") ; CHECK-NEXT: entry: ; CHECK-NEXT: %[[i0:.+]] = fadd double %x, %y ; CHECK-NEXT: %[[i1:.+]] = fmul fast double %"x'", %x diff --git a/enzyme/test/Enzyme/ForwardError/cos.ll b/enzyme/test/Enzyme/ForwardError/cos.ll index 3d75b9115e2c..62f8765c7980 100644 --- a/enzyme/test/Enzyme/ForwardError/cos.ll +++ b/enzyme/test/Enzyme/ForwardError/cos.ll @@ -23,21 +23,33 @@ declare double @llvm.sin.f64(double) ; Function Attrs: nounwind declare double @__enzyme_error_estimate(double (double)*, ...) +; Function Attrs: mustprogress noinline optnone ssp uwtable +declare void @enzymeLogValue(i8* noundef %id, double noundef %res, i32 noundef %numOperands, double* noundef %operands) -; CHECK: define internal double @fwddiffetester(double %x, double %"x'") +; Function Attrs: mustprogress noinline optnone ssp uwtable +declare void @enzymeLogError(i8* noundef %id, double noundef %err) + + +; CHECK: define internal double @fwderrtester(double %x, double %"x'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %[[i0:.+]] = tail call fast double @llvm.cos.f64(double %x) -; CHECK-NEXT: %[[i1:.+]] = fmul fast double %"x'", %x -; CHECK-NEXT: %[[i2:.+]] = fdiv fast double %[[i1]], %[[i0]] -; CHECK-NEXT: %[[i3:.+]] = call fast double @llvm.sin.f64(double %x) -; CHECK-NEXT: %[[i4:.+]] = fneg fast double %[[i3]] -; CHECK-NEXT: %[[i5:.+]] = fmul fast double %[[i2]], %[[i4]] -; CHECK-NEXT: %[[i6:.+]] = call fast double @llvm.fabs.f64(double %[[i5]]) -; CHECK-NEXT: %[[i7:.+]] = bitcast double %[[i0]] to i64 -; CHECK-NEXT: %[[i8:.+]] = xor i64 %[[i7]], 1 -; CHECK-NEXT: %[[i9:.+]] = bitcast i64 %[[i8]] to double -; CHECK-NEXT: %[[i10:.+]] = fsub fast double %[[i0]], %[[i9]] -; CHECK-NEXT: %[[i11:.+]] = call fast double @llvm.fabs.f64(double %[[i10]]) -; CHECK-NEXT: %[[i12:.+]] = call fast double @llvm.maxnum.f64(double %[[i11]], double %[[i6]]) -; CHECK-NEXT: ret double %[[i12]] -; CHECK-NEXT: } \ No newline at end of file +; CHECK-NEXT: %[[i0:.+]] = alloca [1 x double], align 8 +; CHECK-NEXT: %[[i1:.+]] = tail call fast double @llvm.cos.f64(double %x) +; CHECK-NEXT: %[[i4:.+]] = fmul fast double %"x'", %x +; CHECK-NEXT: %[[i5:.+]] = fdiv fast double %[[i4]], %[[i1]] +; CHECK-NEXT: %[[i6:.+]] = call fast double @llvm.sin.f64(double %x) +; CHECK-NEXT: %[[i7:.+]] = fneg fast double %[[i6]] +; CHECK-NEXT: %[[i8:.+]] = fmul fast double %[[i5]], %[[i7]] +; CHECK-NEXT: %[[i9:.+]] = call fast double @llvm.fabs.f64(double %[[i8]]) +; CHECK-NEXT: %[[i10:.+]] = bitcast double %[[i1]] to i64 +; CHECK-NEXT: %[[i11:.+]] = xor i64 %[[i10]], 1 +; CHECK-NEXT: %[[i12:.+]] = bitcast i64 %[[i11]] to double +; CHECK-NEXT: %[[i13:.+]] = fsub fast double %[[i1]], %[[i12]] +; CHECK-NEXT: %[[i14:.+]] = call fast double @llvm.fabs.f64(double %[[i13]]) +; CHECK-NEXT: %[[i15:.+]] = call fast double @llvm.maxnum.f64(double %[[i14]], double %[[i9]]) +; CHECK-NEXT: call void @enzymeLogError(i8* getelementptr inbounds ([11 x i8], [11 x i8]* @1, i32 0, i32 0), double %[[i15]]) +; CHECK-NEXT: %[[i2:.+]] = getelementptr [1 x double], [1 x double]* %[[i0]], i32 0, i32 0 +; CHECK-NEXT: store double %x, double* %[[i2]], align 8 +; CHECK-NEXT: %[[i3:.+]] = getelementptr [1 x double], [1 x double]* %[[i0]], i32 0, i32 0 +; CHECK-NEXT: call void @enzymeLogValue(i8* getelementptr inbounds ([11 x i8], [11 x i8]* @0, i32 0, i32 0), double %1, i32 1, double* %[[i3]]) +; CHECK-NEXT: ret double %[[i15]] +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardError/div.ll b/enzyme/test/Enzyme/ForwardError/div.ll index ea8d9a6ab39b..87af99012ae3 100644 --- a/enzyme/test/Enzyme/ForwardError/div.ll +++ b/enzyme/test/Enzyme/ForwardError/div.ll @@ -18,7 +18,7 @@ entry: declare double @__enzyme_error_estimate(double (double, double)*, ...) -; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'") +; CHECK: define internal double @fwderrtester(double %x, double %"x'", double %y, double %"y'") ; CHECK-NEXT: entry: ; CHECK-NEXT: %[[i0:.+]] = fdiv double %x, %y ; CHECK-NEXT: %[[i1:.+]] = fmul fast double %"x'", %x diff --git a/enzyme/test/Enzyme/ReverseMode/addLog.ll b/enzyme/test/Enzyme/ReverseMode/addLog.ll new file mode 100644 index 000000000000..b2af071cb0d1 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/addLog.ll @@ -0,0 +1,41 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = fadd fast double %x, %y + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_autodiff(double (double, double)* nonnull @tester, double %x, double %y) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double, double)*, ...) + +; Function Attrs: mustprogress noinline optnone ssp uwtable +declare void @enzymeLogValue(i8* noundef %id, double noundef %res, i32 noundef %numOperands, double* noundef %operands) + +; Function Attrs: mustprogress noinline optnone ssp uwtable +declare void @enzymeLogGrad(i8* noundef %id, double noundef %grad) + + +; CHECK: define internal {{(dso_local )?}}{ double, double } @diffetester(double %x, double %y, double %[[differet:.+]]) +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[i0:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[i1:.+]] = fadd fast double %x, %y +; CHECK-NEXT: %[[i2:.+]] = getelementptr [2 x double], [2 x double]* %[[i0]], i32 0, i32 0 +; CHECK-NEXT: store double %x, double* %[[i2]], align 8 +; CHECK-NEXT: %[[i3:.+]] = getelementptr [2 x double], [2 x double]* %[[i0]], i32 0, i32 1 +; CHECK-NEXT: store double %y, double* %[[i3]], align 8 +; CHECK-NEXT: %[[i4:.+]] = getelementptr [2 x double], [2 x double]* %[[i0]], i32 0, i32 0 +; CHECK-NEXT: call void @enzymeLogValue(i8* getelementptr inbounds ([11 x i8], [11 x i8]* @0, i32 0, i32 0), double %[[i1]], i32 2, double* %[[i4]]) +; CHECK-NEXT: call void @enzymeLogGrad(i8* getelementptr inbounds ([11 x i8], [11 x i8]* @1, i32 0, i32 0), double %[[differet]]) +; CHECK-NEXT: %[[i5:.+]] = insertvalue { double, double } undef, double %[[differet]], 0 +; CHECK-NEXT: %[[i6:.+]] = insertvalue { double, double } %[[i5]], double %[[differet]], 1 +; CHECK-NEXT: ret { double, double } %[[i6]] +; CHECK-NEXT: } diff --git a/enzyme/test/Integration/CMakeLists.txt b/enzyme/test/Integration/CMakeLists.txt index f76a45cfda30..237e4acbb3f2 100644 --- a/enzyme/test/Integration/CMakeLists.txt +++ b/enzyme/test/Integration/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(CppSugar) add_subdirectory(ForwardMode) add_subdirectory(ForwardError) add_subdirectory(ForwardModeVector) +add_subdirectory(FPOpt) add_subdirectory(ReverseMode) add_subdirectory(BatchMode) add_subdirectory(Sparse) diff --git a/enzyme/test/Integration/FPOpt/CMakeLists.txt b/enzyme/test/Integration/FPOpt/CMakeLists.txt new file mode 100644 index 000000000000..e3757b9ce962 --- /dev/null +++ b/enzyme/test/Integration/FPOpt/CMakeLists.txt @@ -0,0 +1,12 @@ +# Run regression and unit tests +add_lit_testsuite(check-enzyme-integration-fpopt "Running enzyme FPOpt integration tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} ClangEnzyme-${LLVM_VERSION_MAJOR} + ARGS -v +) + +set_target_properties(check-enzyme-integration-fpopt PROPERTIES FOLDER "Tests") + +#add_lit_testsuites(ENZYME ${CMAKE_CURRENT_SOURCE_DIR} + # DEPENDS ${ENZYME_TEST_DEPS} +#) diff --git a/enzyme/test/Integration/FPOpt/root_solve1.cpp b/enzyme/test/Integration/FPOpt/root_solve1.cpp new file mode 100644 index 000000000000..8cd75f43361b --- /dev/null +++ b/enzyme/test/Integration/FPOpt/root_solve1.cpp @@ -0,0 +1,30 @@ +// RUN: %clang++ -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %fpopt -enzyme-print-herbie -enzyme-print-fpopt -fpopt-target-func-regex=fun -S | %lli - +// RUN: %clang++ -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %fpopt -enzyme-print-herbie -enzyme-print-fpopt -fpopt-target-func-regex=fun -S | %lli - +// RUN: %clang++ -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %fpopt -enzyme-print-herbie -enzyme-print-fpopt -fpopt-target-func-regex=fun -S | %lli - +// RUN: %clang++ -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %fpopt -enzyme-print-herbie -enzyme-print-fpopt -fpopt-target-func-regex=fun -S | %lli - + +#include "../test_utils.h" + +#include + +double fun(double a, double b, double c) { + double discriminant = b * b - 4 * a * c; + double sqrt_discriminant = sqrt(discriminant); + double numerator = -b - sqrt_discriminant; + double result = numerator / (2 * a); + return result; +} + +int main() { + // x^2 - 3x + 2 = 0 --> x1 = 1 (computed), x2 = 2 + double res1 = fun(1, -3, 2); + printf("res1 = %.18e\n", res1); + APPROX_EQ(res1, 1.0, 1e-4); + + // x^2 - 5x + 6 = 0 --> x1 = 2 (computed), x2 = 3 + double res2 = fun(1, -5, 6); + printf("res2 = %.18e\n", res2); + APPROX_EQ(res2, 2.0, 1e-4); + + return 0; +} diff --git a/enzyme/test/Integration/FPOpt/root_solve2.cpp b/enzyme/test/Integration/FPOpt/root_solve2.cpp new file mode 100644 index 000000000000..925c769d0185 --- /dev/null +++ b/enzyme/test/Integration/FPOpt/root_solve2.cpp @@ -0,0 +1,33 @@ +// RUN: %clang++ -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %fpopt -enzyme-print-herbie -enzyme-print-fpopt -fpopt-target-func-regex=fun -S | %lli - +// RUN: %clang++ -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %fpopt -enzyme-print-herbie -enzyme-print-fpopt -fpopt-target-func-regex=fun -S | %lli - +// RUN: %clang++ -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %fpopt -enzyme-print-herbie -enzyme-print-fpopt -fpopt-target-func-regex=fun -S | %lli - +// RUN: %clang++ -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %fpopt -enzyme-print-herbie -enzyme-print-fpopt -fpopt-target-func-regex=fun -S | %lli - + +#include "../test_utils.h" + +#include + +double fun(double a, double b, double c) { + double discriminant = b * b - 4 * a * c; + printf("discriminant = %.18e\n", discriminant); + double sqrt_discriminant = sqrt(discriminant); + printf("sqrt_discriminant = %.18e\n", sqrt_discriminant); + double numerator = -b - sqrt_discriminant; + printf("numerator = %.18e\n", numerator); + double result = numerator / (2 * a); + return result; +} + +int main() { + // x^2 - 3x + 2 = 0 --> x1 = 1 (computed), x2 = 2 + double res1 = fun(1, -3, 2); + printf("res1 = %.18e\n", res1); + APPROX_EQ(res1, 1.0, 1e-4); + + // x^2 - 5x + 6 = 0 --> x1 = 2 (computed), x2 = 3 + double res2 = fun(1, -5, 6); + printf("res2 = %.18e\n", res2); + APPROX_EQ(res2, 2.0, 1e-4); + + return 0; +} diff --git a/enzyme/test/Integration/ForwardError/binops.c b/enzyme/test/Integration/ForwardError/binops1.c similarity index 70% rename from enzyme/test/Integration/ForwardError/binops.c rename to enzyme/test/Integration/ForwardError/binops1.c index 2770060575ce..ec35d8cdff2f 100644 --- a/enzyme/test/Integration/ForwardError/binops.c +++ b/enzyme/test/Integration/ForwardError/binops1.c @@ -11,15 +11,21 @@ double fabs(double); extern double __enzyme_error_estimate(void *, ...); +int valueLogCount = 0; int errorLogCount = 0; -void enzymeLogError(double res, double err, const char *opcodeName, - const char *calleeName, const char *moduleName, - const char *functionName, const char *blockName) { +void enzymeLogValue(const char *id, double res, unsigned numOperands, + double *operands) { + ++valueLogCount; + printf("Id = %s, Res = %.18e\n", id, res); + for (int i = 0; i < numOperands; ++i) { + printf("\tOperand[%d] = %.18e\n", i, operands[i]); + } +} + +void enzymeLogError(const char *id, double err) { ++errorLogCount; - printf("Res = %e, Error = %e, Op = %s, Callee = %s, Module = %s, Function = " - "%s, BasicBlock = %s\n", - res, err, opcodeName, calleeName, moduleName, functionName, blockName); + printf("Id = %s, Err = %.18e\n", id, err); } // An example from https://dl.acm.org/doi/10.1145/3371128 @@ -28,7 +34,7 @@ double fun(double x) { double v2 = 1 - v1; double v3 = x * x; double v4 = v2 / v3; - double v5 = sin(v4); // Inactive -- logger is not invoked. + double v5 = sin(v4); printf("v1 = %.18e, v2 = %.18e, v3 = %.18e, v4 = %.18e, v5 = %.18e\n", v1, v2, v3, v4, v5); @@ -42,5 +48,6 @@ int main() { printf("res = %.18e, abs error = %.18e, rel error = %.18e\n", res, error, fabs(error / res)); APPROX_EQ(error, 2.2222222222e-2, 1e-4); + TEST_EQ(valueLogCount, 4); // TODO: should be 5 TEST_EQ(errorLogCount, 4); } diff --git a/enzyme/test/Integration/ForwardError/binops2.cpp b/enzyme/test/Integration/ForwardError/binops2.cpp new file mode 100644 index 000000000000..546cdc9fe807 --- /dev/null +++ b/enzyme/test/Integration/ForwardError/binops2.cpp @@ -0,0 +1,51 @@ +// RUN: %clang++ -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang++ -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang++ -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang++ -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - + +#include "../test_utils.h" + +#include + +extern double __enzyme_error_estimate(void *, ...); + +int valueLogCount = 0; +int errorLogCount = 0; + +void enzymeLogValue(const char *id, double res, unsigned numOperands, + double *operands) { + ++valueLogCount; + printf("Id = %s, Res = %f\n", id, res); + for (unsigned i = 0; i < numOperands; ++i) { + printf("\tOperand[%u] = %f\n", i, operands[i]); + } +} + +void enzymeLogError(const char *id, double err) { + ++errorLogCount; + printf("Id = %s, Err = %f\n", id, err); +} + +// An example from https://dl.acm.org/doi/10.1145/3371128 +double fun(double x) { + double v1 = cos(x); + double v2 = 1 - v1; + double v3 = x * x; + double v4 = v2 / v3; + double v5 = sin(v4); // Inactive -- logger is not invoked. + + printf("v1 = %f, v2 = %f, v3 = %f, v4 = %f, v5 = %f\n", v1, v2, v3, v4, v5); + + return v4; +} + +int main() { + double res = fun(1e-7); + double error = __enzyme_error_estimate((void *)fun, 1e-7, 0.0); + printf("res = %f, abs error = %f, rel error = %f\n", res, error, + fabs(error / res)); + + APPROX_EQ(error, 2.2222222222e-2, 1e-4); + TEST_EQ(valueLogCount, 4); // TODO: should be 5 + TEST_EQ(errorLogCount, 4); +} diff --git a/enzyme/test/Integration/ForwardError/binops3.cpp b/enzyme/test/Integration/ForwardError/binops3.cpp new file mode 100644 index 000000000000..241c3aea78f5 --- /dev/null +++ b/enzyme/test/Integration/ForwardError/binops3.cpp @@ -0,0 +1,32 @@ +// RUN: %clang++ -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang++ -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang++ -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang++ -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - + +#include "../test_utils.h" + +#include + +extern double __enzyme_error_estimate(void *, ...); + +// An example from https://dl.acm.org/doi/10.1145/3371128 +double fun(double x) { + double v1 = cos(x); + double v2 = 1 - v1; + double v3 = x * x; + double v4 = v2 / v3; + + printf("v1 = %.18e, v2 = %.18e, v3 = %.18e, v4 = %.18e\n", v1, v2, v3, v4); + + return v4; +} + +int main() { + double res = fun(1e-7); + __enzyme_error_estimate((void *)fun, 2e-7, 0.0); + __enzyme_error_estimate((void *)fun, 7e-7, 0.0); + double error = __enzyme_error_estimate((void *)fun, 1e-7, 0.0); + printf("res = %.18e, abs error = %.18e, rel error = %.18e\n", res, error, + fabs(error / res)); + APPROX_EQ(error, 2.2222222222e-2, 1e-4); +} diff --git a/enzyme/test/Integration/ForwardError/fp-logger.h b/enzyme/test/Integration/ForwardError/fp-logger.h new file mode 100644 index 000000000000..3af9e69827e8 --- /dev/null +++ b/enzyme/test/Integration/ForwardError/fp-logger.h @@ -0,0 +1,133 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +class ValueInfo { +public: + double minRes = std::numeric_limits::max(); + double maxRes = std::numeric_limits::lowest(); + std::vector minOperands; + std::vector maxOperands; + unsigned executions = 0; + + void update(double res, const double *operands, unsigned numOperands) { + minRes = std::min(minRes, res); + maxRes = std::max(maxRes, res); + if (minOperands.empty()) { + minOperands.resize(numOperands, std::numeric_limits::max()); + maxOperands.resize(numOperands, std::numeric_limits::lowest()); + } + for (unsigned i = 0; i < numOperands; ++i) { + minOperands[i] = std::min(minOperands[i], operands[i]); + maxOperands[i] = std::max(maxOperands[i], operands[i]); + } + ++executions; + } +}; + +class ErrorInfo { +public: + double minErr = std::numeric_limits::max(); + double maxErr = std::numeric_limits::lowest(); + + void update(double err) { + minErr = std::min(minErr, err); + maxErr = std::max(maxErr, err); + } +}; + +class GradInfo { +public: + double grad = 0.0; + + void update(double grad) { this->grad = grad; } +}; + +class Logger { +private: + std::unordered_map valueInfo; + std::unordered_map errorInfo; + std::unordered_map gradInfo; + +public: + void updateValue(const std::string &id, double res, unsigned numOperands, + const double *operands) { + auto &info = valueInfo.emplace(id, ValueInfo()).first->second; + info.update(res, operands, numOperands); + } + + void updateError(const std::string &id, double err) { + auto &info = errorInfo.emplace(id, ErrorInfo()).first->second; + info.update(err); + } + + void updateGrad(const std::string &id, double grad) { + auto &info = gradInfo.emplace(id, GradInfo()).first->second; + info.update(grad); + } + + void print() const { + std::cout << std::scientific + << std::setprecision(std::numeric_limits::max_digits10); + + for (const auto &pair : valueInfo) { + const auto &id = pair.first; + const auto &info = pair.second; + std::cout << "Value:" << id << "\n"; + std::cout << "\tMinRes = " << info.minRes << "\n"; + std::cout << "\tMaxRes = " << info.maxRes << "\n"; + std::cout << "\tExecutions = " << info.executions << "\n"; + for (unsigned i = 0; i < info.minOperands.size(); ++i) { + std::cout << "\tOperand[" << i << "] = [" << info.minOperands[i] << ", " + << info.maxOperands[i] << "]\n"; + } + } + + for (const auto &pair : errorInfo) { + const auto &id = pair.first; + const auto &info = pair.second; + std::cout << "Error:" << id << "\n"; + std::cout << "\tMinErr = " << info.minErr << "\n"; + std::cout << "\tMaxErr = " << info.maxErr << "\n"; + } + + for (const auto &pair : gradInfo) { + const auto &id = pair.first; + const auto &info = pair.second; + std::cout << "Grad:" << id << "\n"; + std::cout << "\tGrad = " << info.grad << "\n"; + } + } +}; + +static Logger *logger = nullptr; + +void initializeLogger() { logger = new Logger(); } + +void destroyLogger() { + delete logger; + logger = nullptr; +} + +void printLogger() { logger->print(); } + +void enzymeLogError(const char *id, double err) { + assert(logger && "Logger is not initialized"); + logger->updateError(id, err); +} + +void enzymeLogGrad(const char *id, double grad) { + assert(logger && "Logger is not initialized"); + logger->updateGrad(id, grad); +} + +void enzymeLogValue(const char *id, double res, unsigned numOperands, + double *operands) { + assert(logger && "Logger is not initialized"); + logger->updateValue(id, res, numOperands, operands); +} \ No newline at end of file diff --git a/enzyme/test/Integration/ReverseMode/logger.cpp b/enzyme/test/Integration/ReverseMode/logger.cpp new file mode 100644 index 000000000000..1f4deeac194b --- /dev/null +++ b/enzyme/test/Integration/ReverseMode/logger.cpp @@ -0,0 +1,43 @@ +// RUN: %clang++ -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang++ -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang++ -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang++ -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - + +#include "../test_utils.h" +#include + +extern double __enzyme_autodiff(void *, ...); + +int errorLogCount = 0; + +void enzymeLogGrad(double grad, const char *opcodeName, const char *calleeName, + const char *moduleName, const char *functionName, + unsigned blockIdx, unsigned instIdx) { + ++errorLogCount; + printf("Grad = %e, Op = %s, Callee = %s, Module = %s, Function = " + "%s, BlockIdx = %u, InstIdx = %u\n", + grad, opcodeName, calleeName, moduleName, functionName, blockIdx, + instIdx); +} + +double fun(double x) { + double v1 = x * 3; + double v2 = 1 - v1; + double v3 = x * x; + double v4 = v2 / v3; + double v5 = v3 + v4; + + printf("v1 = %.18e, v2 = %.18e, v3 = %.18e, v4 = %.18e, v5 = %.18e\n", v1, v2, + v3, v4, v5); + + return v5; +} + +int main() { + double x = 2.0; + double res = fun(x); + double grad_x = __enzyme_autodiff((void *)fun, x); + printf("res = %.18e, grad = %.18e\n", res, grad_x); + APPROX_EQ(grad_x, 4.5, 1e-4); + TEST_EQ(errorLogCount, 5); +} diff --git a/enzyme/test/lit.site.cfg.py.in b/enzyme/test/lit.site.cfg.py.in index 0cc5e6f28f38..975cd42723ba 100644 --- a/enzyme/test/lit.site.cfg.py.in +++ b/enzyme/test/lit.site.cfg.py.in @@ -95,6 +95,7 @@ config.substitutions.append(('%newLoadEnzyme', newPM)) config.substitutions.append(('%OPloadEnzyme', oldPMOP if int(config.llvm_ver) < 16 else newPMOP)) config.substitutions.append(('%OPnewLoadEnzyme', newPMOP)) config.substitutions.append(('%enzyme', ('-enzyme' if int(config.llvm_ver) < 16 else '-passes="enzyme"'))) +config.substitutions.append(('%fpopt', ('-fp-opt' if int(config.llvm_ver) < 16 else '-passes="fp-opt"'))) config.substitutions.append(('%simplifycfg', ("simplify-cfg" if int(config.llvm_ver) < 11 else "simplifycfg"))) config.substitutions.append(('%loopmssa', ("loop" if int(config.llvm_ver) < 11 else "loop-mssa"))) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index b6848cedcebe..b4ca5feb81eb 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1943,6 +1943,57 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, else os << " gutils->eraseIfUnused(" << origName << ");\n"; + if (intrinsic != MLIRDerivatives) { + os << " if (auto *logFunc = getLogFunction(" << origName + << ".getModule(), \"enzymeLogValue\")) {\n" + << " IRBuilder<> BuilderZ(&" << origName << ");\n" + << " getForwardBuilder(BuilderZ);\n" + << " std::string idStr = getLogIdentifier(" << origName << ");\n" + << " Value *idValue = " + "BuilderZ.CreateGlobalStringPtr(idStr);\n" + << " Value *origValue = " + "BuilderZ.CreateFPExt(gutils->getNewFromOriginal(&" + << origName << "), Type::getDoubleTy(" << origName + << ".getContext()));\n" + << " unsigned numOperands = isa(" << origName + << ") ? cast(" << origName << ").arg_size() : " << origName + << ".getNumOperands();\n" + << " Value *numOperandsValue = " + "ConstantInt::get(Type::getInt32Ty(" + << origName << ".getContext()), numOperands);\n" + << " auto operands = isa(" << origName + << ") ? cast(" << origName << ").args() : " << origName + << ".operands();\n" + << " ArrayType *operandArrayType = " + "ArrayType::get(Type::getDoubleTy(" + << origName << ".getContext()), numOperands);\n" + << " Value *operandArrayValue = " + "IRBuilder<>(gutils->inversionAllocs).CreateAlloca(" + "operandArrayType);\n" + << " for (auto operand : enumerate(operands)) {\n" + << " Value *operandValue = " + "BuilderZ.CreateFPExt(gutils->getNewFromOriginal(operand.value()), " + "Type::getDoubleTy(" + << origName << ".getContext()));\n" + << " Value *ptr = " + "BuilderZ.CreateGEP(operandArrayType, operandArrayValue, " + "{ConstantInt::get(Type::getInt32Ty(" + << origName << ".getContext()), 0), ConstantInt::get(Type::getInt32Ty(" + << origName << ".getContext()), operand.index())});\n" + << " BuilderZ.CreateStore(operandValue, ptr);\n" + << " }\n" + << " Value *operandPtrValue = " + "BuilderZ.CreateGEP(operandArrayType, operandArrayValue, " + "{ConstantInt::get(Type::getInt32Ty(" + << origName << ".getContext()), 0), ConstantInt::get(Type::getInt32Ty(" + << origName << ".getContext()), 0)});\n" + << " CallInst *logCallInst = BuilderZ.CreateCall(logFunc, " + << "{idValue, origValue, numOperandsValue, operandPtrValue});\n" + << " logCallInst->setDebugLoc(gutils->getNewFromOriginal(" + << origName << ".getDebugLoc()));\n" + << " }\n"; + } + if (intrinsic == MLIRDerivatives) { os << " if (gutils->isConstantInstruction(op))\n"; os << " return success();\n"; @@ -2239,87 +2290,19 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " assert(res);\n"; // Insert logging function call (optional) - os << " Function *logFunc = " << origName - << ".getModule()->getFunction(\"enzymeLogError\");\n"; - os << " if (logFunc) {\n" - << " std::string moduleName = " << origName - << ".getModule()->getModuleIdentifier() ;\n" - << " std::string functionName = " << origName - << ".getFunction()->getName().str();\n" - << " std::string blockName = " << origName - << ".getParent()->getName().str();\n" - << " int funcIdx = -1, blockIdx = -1, instIdx = -1;\n" - << " auto funcIt = std::find_if(" << origName - << ".getModule()->begin(), " << origName - << ".getModule()->end(),\n" - " [&](const auto& func) { return &func == " - << origName - << ".getFunction(); });\n" - " if (funcIt != " - << origName - << ".getModule()->end()) {\n" - " funcIdx = " - "std::distance(" - << origName << ".getModule()->begin(), funcIt);\n" - << " }\n" - << " auto blockIt = std::find_if(" << origName - << ".getFunction()->begin(), " << origName - << ".getFunction()->end(),\n" - " [&](const auto& block) { return &block == " - << origName - << ".getParent(); });\n" - " if (blockIt != " - << origName - << ".getFunction()->end()) {\n" - " blockIdx = std::distance(" - << origName << ".getFunction()->begin(), blockIt);\n" - << " }\n" - << " auto instIt = std::find_if(" << origName - << ".getParent()->begin(), " << origName - << ".getParent()->end(),\n" - " [&](const auto& curr) { return &curr == &" - << origName - << "; });\n" - " if (instIt != " - << origName - << ".getParent()->end()) {\n" - " instIdx = std::distance(" - << origName << ".getParent()->begin(), instIt);\n" - << " }\n" - << " Value *origValue = " - "Builder2.CreateFPExt(gutils->getNewFromOriginal(&" - << origName << "), Type::getDoubleTy(" << origName - << ".getContext()));\n" + os << " if (auto *logFunc = getLogFunction(" << origName + << ".getModule(), \"enzymeLogError\")) {\n" + << " std::string idStr = getLogIdentifier(" << origName + << ");\n" + << " Value *idValue = " + "BuilderZ.CreateGlobalStringPtr(idStr);\n" << " Value *errValue = Builder2.CreateFPExt(res, " "Type::getDoubleTy(" << origName << ".getContext()));\n" - << " std::string opcodeName = " << origName - << ".getOpcodeName();\n" - << " std::string calleeName = \"\";\n" - << " if (auto CI = dyn_cast(&" << origName - << ")) {\n" - << " if (Function *fn = CI->getCalledFunction()) {\n" - << " calleeName = fn->getName();\n" - << " } else {\n" - << " calleeName = \"\";\n" - << " }\n" - << " }\n" - << " Value *moduleNameValue = " - "Builder2.CreateGlobalStringPtr(moduleName);\n" - << " Value *functionNameValue = " - "Builder2.CreateGlobalStringPtr(functionName + \" (\" +" - "std::to_string(funcIdx) + \")\");\n" - << " Value *blockNameValue = " - "Builder2.CreateGlobalStringPtr(blockName + \" (\" +" - "std::to_string(blockIdx) + \")\");\n" - << " Value *opcodeNameValue = " - "Builder2.CreateGlobalStringPtr(opcodeName + \" (\" " - "+std::to_string(instIdx) + \")\");\n" - << " Value *calleeNameValue = " - "Builder2.CreateGlobalStringPtr(calleeName);\n" - << " Builder2.CreateCall(logFunc, {origValue, " - "errValue, opcodeNameValue, calleeNameValue, moduleNameValue, " - "functionNameValue, blockNameValue});\n" + << " CallInst *logCallInst = Builder2.CreateCall(logFunc, " + "{idValue, errValue});\n" + << " logCallInst->setDebugLoc(gutils->getNewFromOriginal(" + << origName << ".getDebugLoc()));\n" << " }\n"; os << " setDiffe(&" << origName << ", res, Builder2);\n"; @@ -2341,6 +2324,21 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, emitReverseCommon(os, pattern, tree, intrinsic, origName, argOps); if (intrinsic != MLIRDerivatives) { + os << " if (auto *logFunc = getLogFunction(" << origName + << ".getModule(), \"enzymeLogGrad\")) {\n" + << " std::string idStr = getLogIdentifier(" << origName + << ");\n" + << " Value *idValue = " + "BuilderZ.CreateGlobalStringPtr(idStr);\n" + << " Value *diffValue = Builder2.CreateFPExt(dif, " + "Type::getDoubleTy(" + << origName << ".getContext()));\n" + << " CallInst *logCallInst = Builder2.CreateCall(logFunc, " + "{idValue, diffValue});\n" + << " logCallInst->setDebugLoc(gutils->getNewFromOriginal(" + << origName << ".getDebugLoc()));\n" + << " }\n"; + os << " auto found = gutils->invertedPointers.find(&(" << origName << "));\n"; os << " if (found != gutils->invertedPointers.end() && "