diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td index c9a50b3712a6..2bb3ce8003fc 100644 --- a/enzyme/Enzyme/Clang/include_utils.td +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -28,6 +28,9 @@ Return __enzyme_fwddiff(T...); namespace enzyme { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wmissing-braces" + struct nodiff{}; template @@ -45,7 +48,10 @@ namespace enzyme { template < typename T > struct Active{ T value; - Active(T &&v) : value(v) {} + Active(const T& v) : value(v) {}; + Active(T&& v) : value(v) {}; + Active(const Active&) = default; + Active(Active&&) = default; operator T&() { return value; } }; @@ -53,23 +59,47 @@ namespace enzyme { struct Duplicated{ T value; T shadow; - Duplicated(T &&v, T&& s) : value(v), shadow(s) {} + Duplicated(const T& v, const T& s) : value(v), shadow(s) {}; + Duplicated(T&& v, T&& s) : value(v), shadow(s) {}; + Duplicated(const Duplicated&) = default; + Duplicated(Duplicated&&) = default; }; template < typename T > struct DuplicatedNoNeed{ T value; T shadow; - DuplicatedNoNeed(T &&v, T&& s) : value(v), shadow(s) {} + DuplicatedNoNeed(const T& v, const T& s) : value(v), shadow(s) {}; + DuplicatedNoNeed(T&& v, T&& s) : value(v), shadow(s) {}; + DuplicatedNoNeed(const DuplicatedNoNeed&) = default; + DuplicatedNoNeed(DuplicatedNoNeed&&) = default; }; template < typename T > struct Const{ T value; - Const(T &&v) : value(v) {} + Const(const T& v) : value(v) {}; + Const(T&& v) : value(v) {}; + Const(const Const&) = default; + Const(Const&&) = default; operator T&() { return value; } }; + // CTAD available in C++17 or later + #if __cplusplus >= 201703L + template < typename T > + Active(T) -> Active; + + template < typename T > + Const(T) -> Const; + + template < typename T > + Duplicated(T,T) -> Duplicated; + + template < typename T > + DuplicatedNoNeed(T,T) -> DuplicatedNoNeed; + #endif + template < typename T > struct type_info { static constexpr bool is_active = false; @@ -189,7 +219,52 @@ namespace enzyme { return enzyme::tuple{arg.value}; } + template < typename T > + __attribute__((always_inline)) + auto primal_args_nt(const enzyme::Duplicated & arg) { + return arg.value; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args_nt(const enzyme::DuplicatedNoNeed & arg) { + return arg.value; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args_nt(const enzyme::Active & arg) { + return arg.value; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args_nt(const enzyme::Const & arg) { + return arg.value; + } + namespace detail { + template < typename RetType, typename ... T > + struct function_type + { + using type = RetType(T...); + }; + + template + struct templated_call { + + }; + + template + struct templated_call { + static RT wrap(T... args, function* __restrict__ f) { + return (*f)(args...); + } + }; + + + + template __attribute__((always_inline)) constexpr decltype(auto) push_return_last(T &&t); @@ -211,19 +286,19 @@ namespace enzyme { template struct autodiff_apply> { - template + template __attribute__((always_inline)) - static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence) { - return push_return_last(__enzyme_autodiff(f, ret_attr, enzyme::get(impl::forward(t))...)); + static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence, ExtraArgs... args) { + return push_return_last(__enzyme_autodiff(f, ret_attr, enzyme::get(impl::forward(t))..., args...)); } }; template <> struct autodiff_apply { - template + template __attribute__((always_inline)) - static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence) { - return __enzyme_fwddiff(f, ret_attr, enzyme::get(impl::forward(t))...); + static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence, ExtraArgs... args) { + return __enzyme_fwddiff(f, ret_attr, enzyme::get(impl::forward(t))..., args...); } }; @@ -313,7 +388,7 @@ namespace enzyme { __attribute__((always_inline)) auto primal_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { using Tuple = enzyme::tuple< enz_arg_types ... >; - return detail::primal_apply_impl(f, impl::forward(arg_tup), std::make_index_sequence>{}); + return detail::primal_apply_impl(std::move(f), impl::forward(arg_tup), std::make_index_sequence>{}); } template < typename function, typename ... arg_types> @@ -321,29 +396,43 @@ namespace enzyme { return primal_impl(impl::forward(f), enzyme::tuple_cat(primal_args(args)...)); } - template < typename return_type, typename DiffMode, typename function, typename RetActivity, typename ... enz_arg_types > + template < typename return_type, typename DiffMode, typename function, typename functy, typename RetActivity, typename ... enz_arg_types, std::enable_if_t::type>, int> = 0> __attribute__((always_inline)) auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { using Tuple = enzyme::tuple< enz_arg_types ... >; - return detail::autodiff_apply::template impl((void*)f, detail::ret_global::value, impl::forward(arg_tup), std::make_index_sequence>{}); + return detail::autodiff_apply::template impl((void*)&detail::templated_call::wrap, detail::ret_global::value, + impl::forward(arg_tup), + std::make_index_sequence>{}, &enzyme_const, &f); + } + + template < typename return_type, typename DiffMode, typename function, typename functy, typename RetActivity, typename ... enz_arg_types, std::enable_if_t::type>, int> = 0> + __attribute__((always_inline)) + auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + using Tuple = enzyme::tuple< enz_arg_types ... >; + return detail::autodiff_apply::template impl((void*)static_cast(f), detail::ret_global::value, impl::forward(arg_tup), std::make_index_sequence>{}); } template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types> __attribute__((always_inline)) auto autodiff(function && f, arg_types && ... args) { + using primal_return_type = decltype(f(primal_args_nt(args)...)); + using functy = typename detail::function_type::type; using return_type = typename autodiff_return::type; - return autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + return autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); } template < typename DiffMode, typename function, typename ... arg_types> __attribute__((always_inline)) auto autodiff(function && f, arg_types && ... args) { - using primal_return_type = decltype(primal_call(impl::forward(f), impl::forward(args)...)); + using primal_return_type = decltype(f(primal_args_nt(args)...)); + using functy = typename detail::function_type::type; using RetActivity = typename detail::default_ret_activity::type; using return_type = typename autodiff_return::type; - return autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + return autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); } -} +#pragma clang diagnostic pop + +} // namespace enzyme }]>; def : Headers<"/enzymeroot/enzyme/type_traits", [{ @@ -410,6 +499,7 @@ def : Headers<"/enzymeroot/enzyme/tuple", [{ // constexpr support for std::tuple). Owning the implementation lets // us add __host__ __device__ annotations to any part of it +#include // for std::size_t #include // for std::integer_sequence #include @@ -417,6 +507,9 @@ def : Headers<"/enzymeroot/enzyme/tuple", [{ #define _NOEXCEPT noexcept namespace enzyme { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wmissing-braces" + template struct Index {}; @@ -468,10 +561,10 @@ template struct tuple_size; template -struct tuple_size> : std::integral_constant {}; +struct tuple_size> : std::integral_constant {}; template -static constexpr size_t tuple_size_v = tuple_size::value; +static constexpr std::size_t tuple_size_v = tuple_size::value; template __attribute__((always_inline)) @@ -484,7 +577,7 @@ namespace impl { template struct make_tuple_from_fwd_tuple; -template +template struct make_tuple_from_fwd_tuple> { template __attribute__((always_inline)) @@ -499,12 +592,12 @@ struct concat_with_fwd_tuple; template < typename Tuple > using iseq = std::make_index_sequence > >; -template +template struct concat_with_fwd_tuple, std::index_sequence> { template __attribute__((always_inline)) static constexpr auto f(FWD_TUPLE&& fwd, TUPLE&& t) { - return forward_as_tuple(get(impl::forward(fwd))..., get(impl::forward(t))...); + return enzyme::forward_as_tuple(get(impl::forward(fwd))..., get(impl::forward(t))...); } }; @@ -528,6 +621,8 @@ constexpr auto tuple_cat(Tuples&&... tuples) { return impl::tuple_cat(impl::forward(tuples)...); } +#pragma clang diagnostic pop + } // namespace enzyme #undef _NOEXCEPT }]>; diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index dbc31ba9c7eb..9d72ddcbe8c3 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -566,8 +566,10 @@ static bool ReplaceOriginalCall(IRBuilder<> &Builder, Value *ret, } } - if (mode == DerivativeMode::ReverseModePrimal && - DL.getTypeSizeInBits(retType) >= DL.getTypeSizeInBits(diffretType)) { + if ((mode == DerivativeMode::ReverseModePrimal && + DL.getTypeSizeInBits(retType) >= DL.getTypeSizeInBits(diffretType)) || + (mode == DerivativeMode::ForwardMode && + DL.getTypeSizeInBits(retType) == DL.getTypeSizeInBits(diffretType))) { IRBuilder<> EB(CI->getFunction()->getEntryBlock().getFirstNonPHI()); auto AL = EB.CreateAlloca(retType); Builder.CreateStore(diffret, Builder.CreatePointerCast( @@ -592,9 +594,12 @@ static bool ReplaceOriginalCall(IRBuilder<> &Builder, Value *ret, } } + auto diffretsize = DL.getTypeSizeInBits(diffretType); + auto retsize = DL.getTypeSizeInBits(retType); EmitFailure("IllegalReturnCast", CI->getDebugLoc(), CI, "Cannot cast return type of gradient ", *diffretType, *diffret, - ", to desired type ", *retType); + " of size ", diffretsize, " bits ", ", to desired type ", + *retType, " of size ", retsize, " bits"); return false; } @@ -1190,10 +1195,14 @@ class EnzymeBase { if (auto arg = dyn_cast(res)) { loc = arg->getDebugLoc(); } + auto S = simplifyLoad(res); + if (!S) + S = res; EmitFailure("IllegalArgCast", loc, CI, "Cannot cast __enzyme_autodiff primal argument ", i, ", found ", *res, ", type ", *res->getType(), - " - to arg ", truei, " ", *PTy); + " (simplified to ", *S, " ) ", " - to arg ", truei, ", ", + *PTy); return {}; } } diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 7b8a53adb35f..476595de7e1b 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -23,6 +23,7 @@ // //===----------------------------------------------------------------------===// #include "Utils.h" +#include "GradientUtils.h" #include "TypeAnalysis/TypeAnalysis.h" #if LLVM_VERSION_MAJOR >= 16 @@ -2345,9 +2346,10 @@ findAllUsersOf(Value *AI) { // Given a pointer, find all values of size `valSz` which could be loaded from // that pointer when indexed at offset. If it is impossible to guarantee that // the set contains all such values, set legal to false -SmallVector getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, - size_t valSz, bool &legal) { - SmallVector options; +SmallVector, 1> +getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, size_t valSz, + bool &legal) { + SmallVector, 1> options; auto todo = findAllUsersOf(ptr0); std::set> seen; @@ -2389,8 +2391,9 @@ SmallVector getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, if (offset + valSz <= suboff) continue; - if (valSz == storeSz) { - options.push_back(SI->getValueOperand()); + if (valSz <= storeSz) { + assert(offset >= suboff); + options.emplace_back(SI->getValueOperand(), offset - suboff); continue; } } @@ -2410,7 +2413,9 @@ SmallVector getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, legal = false; return options; } - for (auto subPtr : subPtrs) { + for (auto &&[subPtr, subOff] : subPtrs) { + if (subOff != 0) + return options; for (const auto &pair3 : findAllUsersOf(subPtr)) { todo.emplace_back(pair3); } @@ -2459,11 +2464,11 @@ SmallVector getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, } // Perform mem2reg/sroa to identify the innermost value being represented. -Value *simplifyLoad(Value *V, size_t valSz) { +Value *simplifyLoad(Value *V, size_t valSz, size_t preOffset) { if (auto LI = dyn_cast(V)) { if (valSz == 0) { auto &DL = LI->getParent()->getParent()->getParent()->getDataLayout(); - valSz = (DL.getTypeStoreSizeInBits(LI->getType()) + 7) / 8; + valSz = (DL.getTypeSizeInBits(LI->getType()) + 7) / 8; } Value *ptr = LI->getPointerOperand(); @@ -2476,6 +2481,7 @@ Value *simplifyLoad(Value *V, size_t valSz) { if (!AI) { return nullptr; } + offset += preOffset; bool legal = true; auto opts = getAllLoadedValuesFrom(AI, offset, valSz, legal); @@ -2484,8 +2490,8 @@ Value *simplifyLoad(Value *V, size_t valSz) { return nullptr; } std::set res; - for (auto opt : opts) { - Value *v2 = simplifyLoad(opt, valSz); + for (auto &&[opt, startOff] : opts) { + Value *v2 = simplifyLoad(opt, valSz, startOff); if (v2) res.insert(v2); else @@ -2498,19 +2504,43 @@ Value *simplifyLoad(Value *V, size_t valSz) { return retval; } if (auto EVI = dyn_cast(V)) { - bool allZero = true; - for (auto idx : EVI->getIndices()) { - if (idx != 0) - allZero = false; - } - if (valSz == 0) { - auto &DL = EVI->getParent()->getParent()->getParent()->getDataLayout(); - valSz = (DL.getTypeStoreSizeInBits(EVI->getType()) + 7) / 8; - } - if (allZero) - if (auto LI = dyn_cast(EVI->getAggregateOperand())) { - return simplifyLoad(LI, valSz); + IRBuilder<> B(EVI); + auto em = + GradientUtils::extractMeta(B, EVI->getAggregateOperand(), + EVI->getIndices(), "", /*fallback*/ false); + if (em != nullptr) { + if (auto SL2 = simplifyLoad(em, valSz)) + em = SL2; + return em; + } + if (auto LI = dyn_cast(EVI->getAggregateOperand())) { + auto offset = preOffset; + + auto &DL = LI->getParent()->getParent()->getParent()->getDataLayout(); + SmallVector vec; + vec.push_back(ConstantInt::get(Type::getInt64Ty(EVI->getContext()), 0)); + for (auto ind : EVI->getIndices()) { + vec.push_back( + ConstantInt::get(Type::getInt32Ty(EVI->getContext()), ind)); + } + auto ud = UndefValue::get( + PointerType::getUnqual(EVI->getOperand(0)->getType())); + auto g2 = + GetElementPtrInst::Create(EVI->getOperand(0)->getType(), ud, vec); + APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); + g2->accumulateConstantOffset(DL, ai); + // Using destructor rather than eraseFromParent + // as g2 has no parent + delete g2; + + offset += (size_t)ai.getLimitedValue(); + + if (valSz == 0) { + auto &DL = EVI->getParent()->getParent()->getParent()->getDataLayout(); + valSz = (DL.getTypeSizeInBits(EVI->getType()) + 7) / 8; } + return simplifyLoad(LI, valSz, offset); + } } return nullptr; } diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index e72acd090cc4..bd427a49f57b 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1254,7 +1254,8 @@ void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, llvm::Function *GetFunctionFromValue(llvm::Value *fn); -llvm::Value *simplifyLoad(llvm::Value *LI, size_t valSz = 0); +llvm::Value *simplifyLoad(llvm::Value *LI, size_t valSz = 0, + size_t preOffset = 0); static inline bool shouldDisableNoWrite(const llvm::CallInst *CI) { auto F = getFunctionFromCall(CI); diff --git a/enzyme/test/Enzyme/ReverseMode/sroacall.ll b/enzyme/test/Enzyme/ReverseMode/sroacall.ll new file mode 100644 index 000000000000..b9fbe4e1fd47 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/sroacall.ll @@ -0,0 +1,54 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s + +%57 = type { i32*, i32, float } + +@enzyme_out = external dso_local global i32, align 4 + +define dso_local i32 @main() { +bb: + + %i768 = alloca %57, align 8 + + + %i750 = alloca %57, align 8 + + %i2661 = getelementptr inbounds %57, %57* %i750, i32 0, i32 1 + %i2671 = load i32, i32* @enzyme_out, align 4 + store i32 %i2671, i32* %i2661, align 8 + + + %i2672 = bitcast %57* %i750 to i8* + %i2673 = getelementptr inbounds i8, i8* %i2672, i64 12 + %i2674 = bitcast i8* %i2673 to float* + store float 0.0, float* %i2674, align 4 + + %i2687 = bitcast %57* %i750 to { i32*, i64 }* + %i2688 = load { i32*, i64 }, { i32*, i64 }* %i2687, align 8 + + + %i2692 = extractvalue { i32*, i64 } %i2688, 1 + + + + %i2711 = bitcast %57* %i768 to { i32*, i64 }* + %i2712 = getelementptr inbounds { i32*, i64 }, { i32*, i64 }* %i2711, i32 0, i32 1 + store i64 %i2692, i64* %i2712, align 8 + + + %i2717 = getelementptr inbounds %57, %57* %i768, i32 0, i32 2 + + %i2719 = load float, float* %i2717, align 4 + + %i2720 = call float (...) @_Z17__enzyme_autodiffIN6enzyme5tupleIJNS1_IJfEEEEEEJPvPiS5_ifS5_P8overloadEET_DpT0_(i8* bitcast (float (float)* @_ZN6enzyme6detail14templated_callI8overloadFffEE4wrapEfPS2_ to i8*), metadata !"enzyme_out", float %i2719) + ret i32 0 +} + +define linkonce_odr dso_local float @_ZN6enzyme6detail14templated_callI8overloadFffEE4wrapEfPS2_(float %arg) { +bb: + ret float %arg +} + +declare dso_local float @_Z17__enzyme_autodiffIN6enzyme5tupleIJNS1_IJfEEEEEEJPvPiS5_ifS5_P8overloadEET_DpT0_(...) + +; CHECK: call { float } @diffe_ZN6enzyme6detail14templated_callI8overloadFffEE4wrapEfPS2_(float %i2719, float 1.000000e+00) diff --git a/enzyme/test/Integration/CMakeLists.txt b/enzyme/test/Integration/CMakeLists.txt index 5e491ba50cd2..f76a45cfda30 100644 --- a/enzyme/test/Integration/CMakeLists.txt +++ b/enzyme/test/Integration/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(CppSugar) add_subdirectory(ForwardMode) add_subdirectory(ForwardError) add_subdirectory(ForwardModeVector) diff --git a/enzyme/test/Integration/CppSugar/CMakeLists.txt b/enzyme/test/Integration/CppSugar/CMakeLists.txt new file mode 100644 index 000000000000..afc4875c1dd4 --- /dev/null +++ b/enzyme/test/Integration/CppSugar/CMakeLists.txt @@ -0,0 +1,12 @@ +# Run regression and unit tests +add_lit_testsuite(check-enzyme-integration-sugar "Running enzyme c++ sugar integration tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} ClangEnzyme-${LLVM_VERSION_MAJOR} + ARGS -v +) + +set_target_properties(check-enzyme-integration-sugar PROPERTIES FOLDER "Tests") + +# add_lit_testsuites(ENZYME ${CMAKE_CURRENT_SOURCE_DIR} +# DEPENDS ${ENZYME_TEST_DEPS} +# ) diff --git a/enzyme/test/Integration/CppSugar/forward_mode.cpp b/enzyme/test/Integration/CppSugar/forward_mode.cpp new file mode 100644 index 000000000000..fdc470b01b63 --- /dev/null +++ b/enzyme/test/Integration/CppSugar/forward_mode.cpp @@ -0,0 +1,25 @@ +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi + +#include "../test_utils.h" + +#include + +double square(double x) { return x * x; } + +double dsquare(double x) { + return enzyme::get<0>(enzyme::autodiff(square, enzyme::Duplicated{x, 1.0})); + //return enzyme::get<0>(enzyme::forward(square, enzyme::Duplicated{x, 1.0})); +} + +int main() { + for(double i=1; i<5; i++) { + APPROX_EQ(dsquare(i), 2 * i, 1e-10); + } +} diff --git a/enzyme/test/Integration/CppSugar/gh_issue_1785.cpp b/enzyme/test/Integration/CppSugar/gh_issue_1785.cpp new file mode 100644 index 000000000000..57de82e38cab --- /dev/null +++ b/enzyme/test/Integration/CppSugar/gh_issue_1785.cpp @@ -0,0 +1,38 @@ +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi + +// XFAIL: * + +#include "../test_utils.h" + +#include +#include + +int main() { + auto elasticity_kernel = [](const std::vector &dudxi, + const std::vector &J, + const double &w) + { + auto r = dudxi; + return r; + }; + + std::vector dudxi(4), s_dudxi(4), J(4); + double w = 1.0; + + enzyme::get<0> + (enzyme::autodiff>> + (+elasticity_kernel, + enzyme::Duplicated *>{&dudxi, &s_dudxi}, + enzyme::Const *>{&J}, + enzyme::Const{&w})); + + return 0; +} \ No newline at end of file diff --git a/enzyme/test/Integration/ReverseMode/sugar.cpp b/enzyme/test/Integration/ReverseMode/sugar.cpp index a57a994e1ca7..419c196d511c 100644 --- a/enzyme/test/Integration/ReverseMode/sugar.cpp +++ b/enzyme/test/Integration/ReverseMode/sugar.cpp @@ -7,6 +7,8 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi + #include "../test_utils.h" #include @@ -15,6 +17,11 @@ double foo(double x, double y) { return x * y; } double square(double x) { return x * x; } +struct overload { + double operator()(double x) { return x * 10; } + float operator()(float x) { return x * 2; } +}; + struct pair { double x; double y; @@ -81,6 +88,20 @@ int main() { APPROX_EQ(prim, 2.7*3.1, 1e-10); } + { + auto y = enzyme::autodiff(overload{}, enzyme::Active(3.1)); + auto y1 = enzyme::get<0>(enzyme::get<0>(y)); + printf("dmul %f\n", y1); + APPROX_EQ(y1, 10, 1e-10); + } + + { + auto y = enzyme::autodiff(overload{}, enzyme::Active(3.1f)); + auto y1 = enzyme::get<0>(enzyme::get<0>(y)); + printf("dmul %f\n", y1); + APPROX_EQ(y1, 2, 1e-10); + } + { auto &&[z1, z2] = __enzyme_autodiff((void*)foo, enzyme_out, 3.1, enzyme_out, 2.7); printf("dmul2 %f %f\n", z1, z2);