Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cpp interface improvements #1789

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 117 additions & 22 deletions enzyme/Enzyme/Clang/include_utils.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ Return __enzyme_fwddiff(T...);

namespace enzyme {

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wmissing-braces"

struct nodiff{};

template<bool ReturnPrimal = false>
Expand All @@ -45,31 +48,58 @@ namespace enzyme {
template < typename T >
struct Active{
T value;
Active(T &&v) : value(v) {}
Active(const T& v) : value(v) {};
Active(T&& v) : value(v) {};
Active(const Active<T>&) = default;
Active(Active<T>&&) = default;
operator T&() { return value; }
};

template < typename T >
struct Duplicated{
T value;
T shadow;
Duplicated(T &&v, T&& s) : value(v), shadow(s) {}
Duplicated(const T& v, const T& s) : value(v), shadow(s) {};
Duplicated(T&& v, T&& s) : value(v), shadow(s) {};
Duplicated(const Duplicated<T>&) = default;
Duplicated(Duplicated<T>&&) = default;
};

template < typename T >
struct DuplicatedNoNeed{
T value;
T shadow;
DuplicatedNoNeed(T &&v, T&& s) : value(v), shadow(s) {}
DuplicatedNoNeed(const T& v, const T& s) : value(v), shadow(s) {};
DuplicatedNoNeed(T&& v, T&& s) : value(v), shadow(s) {};
DuplicatedNoNeed(const DuplicatedNoNeed<T>&) = default;
DuplicatedNoNeed(DuplicatedNoNeed<T>&&) = default;
};

template < typename T >
struct Const{
T value;
Const(T &&v) : value(v) {}
Const(const T& v) : value(v) {};
Const(T&& v) : value(v) {};
Const(const Const<T>&) = default;
Const(Const<T>&&) = default;
operator T&() { return value; }
};

// CTAD available in C++17 or later
#if __cplusplus >= 201703L
template < typename T >
Active(T) -> Active<T>;

template < typename T >
Const(T) -> Const<T>;

template < typename T >
Duplicated(T,T) -> Duplicated<T>;

template < typename T >
DuplicatedNoNeed(T,T) -> DuplicatedNoNeed<T>;
#endif

template < typename T >
struct type_info {
static constexpr bool is_active = false;
Expand Down Expand Up @@ -189,7 +219,52 @@ namespace enzyme {
return enzyme::tuple<T>{arg.value};
}

template < typename T >
__attribute__((always_inline))
auto primal_args_nt(const enzyme::Duplicated<T> & arg) {
return arg.value;
}

template < typename T >
__attribute__((always_inline))
auto primal_args_nt(const enzyme::DuplicatedNoNeed<T> & arg) {
return arg.value;
}

template < typename T >
__attribute__((always_inline))
auto primal_args_nt(const enzyme::Active<T> & arg) {
return arg.value;
}

template < typename T >
__attribute__((always_inline))
auto primal_args_nt(const enzyme::Const<T> & arg) {
return arg.value;
}

namespace detail {
template < typename RetType, typename ... T >
struct function_type
{
using type = RetType(T...);
};

template<typename function, typename prevfunc>
struct templated_call {

};

template<typename function, typename RT, typename ...T>
struct templated_call<function, RT(T...)> {
static RT wrap(T... args, function* __restrict__ f) {
return (*f)(args...);
}
};




template<typename T>
__attribute__((always_inline))
constexpr decltype(auto) push_return_last(T &&t);
Expand All @@ -211,19 +286,19 @@ namespace enzyme {

template <bool Mode>
struct autodiff_apply<ReverseMode<Mode>> {
template <class return_type, class Tuple, std::size_t... I>
template <class return_type, class Tuple, std::size_t... I, typename... ExtraArgs>
__attribute__((always_inline))
static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>) {
return push_return_last(__enzyme_autodiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))...));
static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>, ExtraArgs... args) {
return push_return_last(__enzyme_autodiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))..., args...));
}
};

template <>
struct autodiff_apply<ForwardMode> {
template <class return_type, class Tuple, std::size_t... I>
template <class return_type, class Tuple, std::size_t... I, typename... ExtraArgs>
__attribute__((always_inline))
static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>) {
return __enzyme_fwddiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))...);
static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>, ExtraArgs... args) {
return __enzyme_fwddiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))..., args...);
}
};

Expand Down Expand Up @@ -313,37 +388,51 @@ namespace enzyme {
__attribute__((always_inline))
auto primal_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
using Tuple = enzyme::tuple< enz_arg_types ... >;
return detail::primal_apply_impl<return_type>(f, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
return detail::primal_apply_impl<return_type>(std::move(f), impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
}

template < typename function, typename ... arg_types>
auto primal_call(function && f, arg_types && ... args) {
return primal_impl<function>(impl::forward<function>(f), enzyme::tuple_cat(primal_args(args)...));
}

template < typename return_type, typename DiffMode, typename function, typename RetActivity, typename ... enz_arg_types >
template < typename return_type, typename DiffMode, typename function, typename functy, typename RetActivity, typename ... enz_arg_types, std::enable_if_t<!std::is_function_v<typename remove_cvref< function >::type>, int> = 0>
__attribute__((always_inline))
auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
using Tuple = enzyme::tuple< enz_arg_types ... >;
return detail::autodiff_apply<DiffMode>::template impl<return_type>((void*)f, detail::ret_global<RetActivity>::value, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
return detail::autodiff_apply<DiffMode>::template impl<return_type>((void*)&detail::templated_call<function, functy>::wrap, detail::ret_global<RetActivity>::value,
impl::forward<Tuple>(arg_tup),
std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{}, &enzyme_const, &f);
}

template < typename return_type, typename DiffMode, typename function, typename functy, typename RetActivity, typename ... enz_arg_types, std::enable_if_t<std::is_function_v<typename remove_cvref< function >::type>, int> = 0>
__attribute__((always_inline))
auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
using Tuple = enzyme::tuple< enz_arg_types ... >;
return detail::autodiff_apply<DiffMode>::template impl<return_type>((void*)static_cast<functy*>(f), detail::ret_global<RetActivity>::value, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
}

template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types>
__attribute__((always_inline))
auto autodiff(function && f, arg_types && ... args) {
using primal_return_type = decltype(f(primal_args_nt(args)...));
using functy = typename detail::function_type<primal_return_type, decltype(primal_args_nt(args))...>::type;
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return autodiff_impl<return_type, DiffMode, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
}

template < typename DiffMode, typename function, typename ... arg_types>
__attribute__((always_inline))
auto autodiff(function && f, arg_types && ... args) {
using primal_return_type = decltype(primal_call<function, arg_types...>(impl::forward<function>(f), impl::forward<arg_types>(args)...));
using primal_return_type = decltype(f(primal_args_nt(args)...));
using functy = typename detail::function_type<primal_return_type, decltype(primal_args_nt(args))...>::type;
using RetActivity = typename detail::default_ret_activity<DiffMode, primal_return_type>::type;
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return autodiff_impl<return_type, DiffMode, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
}
}
#pragma clang diagnostic pop

} // namespace enzyme
}]>;

def : Headers<"/enzymeroot/enzyme/type_traits", [{
Expand Down Expand Up @@ -410,13 +499,17 @@ def : Headers<"/enzymeroot/enzyme/tuple", [{
// constexpr support for std::tuple). Owning the implementation lets
// us add __host__ __device__ annotations to any part of it

#include <cstddef> // for std::size_t
#include <utility> // for std::integer_sequence

#include <enzyme/type_traits>

#define _NOEXCEPT noexcept
namespace enzyme {

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wmissing-braces"

template <int i>
struct Index {};

Expand Down Expand Up @@ -468,10 +561,10 @@ template <typename Tuple>
struct tuple_size;

template <typename... T>
struct tuple_size<tuple<T...>> : std::integral_constant<size_t, sizeof...(T)> {};
struct tuple_size<tuple<T...>> : std::integral_constant<std::size_t, sizeof...(T)> {};

template <typename Tuple>
static constexpr size_t tuple_size_v = tuple_size<Tuple>::value;
static constexpr std::size_t tuple_size_v = tuple_size<Tuple>::value;

template <typename... T>
__attribute__((always_inline))
Expand All @@ -484,7 +577,7 @@ namespace impl {
template <typename index_seq>
struct make_tuple_from_fwd_tuple;

template <size_t... indices>
template <std::size_t... indices>
struct make_tuple_from_fwd_tuple<std::index_sequence<indices...>> {
template <typename FWD_TUPLE>
__attribute__((always_inline))
Expand All @@ -499,12 +592,12 @@ struct concat_with_fwd_tuple;
template < typename Tuple >
using iseq = std::make_index_sequence<tuple_size_v< enzyme::remove_cvref_t< Tuple > > >;

template <size_t... fwd_indices, size_t... indices>
template <std::size_t... fwd_indices, std::size_t... indices>
struct concat_with_fwd_tuple<std::index_sequence<fwd_indices...>, std::index_sequence<indices...>> {
template <typename FWD_TUPLE, typename TUPLE>
__attribute__((always_inline))
static constexpr auto f(FWD_TUPLE&& fwd, TUPLE&& t) {
return forward_as_tuple(get<fwd_indices>(impl::forward<FWD_TUPLE>(fwd))..., get<indices>(impl::forward<TUPLE>(t))...);
return enzyme::forward_as_tuple(get<fwd_indices>(impl::forward<FWD_TUPLE>(fwd))..., get<indices>(impl::forward<TUPLE>(t))...);
}
};

Expand All @@ -528,6 +621,8 @@ constexpr auto tuple_cat(Tuples&&... tuples) {
return impl::tuple_cat(impl::forward<Tuples>(tuples)...);
}

#pragma clang diagnostic pop

} // namespace enzyme
#undef _NOEXCEPT
}]>;
Expand Down
17 changes: 13 additions & 4 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,8 +553,10 @@ static bool ReplaceOriginalCall(IRBuilder<> &Builder, Value *ret,
}
}

if (mode == DerivativeMode::ReverseModePrimal &&
DL.getTypeSizeInBits(retType) >= DL.getTypeSizeInBits(diffretType)) {
if ((mode == DerivativeMode::ReverseModePrimal &&
DL.getTypeSizeInBits(retType) >= DL.getTypeSizeInBits(diffretType)) ||
(mode == DerivativeMode::ForwardMode &&
DL.getTypeSizeInBits(retType) == DL.getTypeSizeInBits(diffretType))) {
IRBuilder<> EB(CI->getFunction()->getEntryBlock().getFirstNonPHI());
auto AL = EB.CreateAlloca(retType);
Builder.CreateStore(diffret, Builder.CreatePointerCast(
Expand All @@ -579,9 +581,12 @@ static bool ReplaceOriginalCall(IRBuilder<> &Builder, Value *ret,
}
}

auto diffretsize = DL.getTypeSizeInBits(diffretType);
auto retsize = DL.getTypeSizeInBits(retType);
EmitFailure("IllegalReturnCast", CI->getDebugLoc(), CI,
"Cannot cast return type of gradient ", *diffretType, *diffret,
", to desired type ", *retType);
" of size ", diffretsize, " bits ", ", to desired type ",
*retType, " of size ", retsize, " bits");
return false;
}

Expand Down Expand Up @@ -1176,10 +1181,14 @@ class EnzymeBase {
if (auto arg = dyn_cast<Instruction>(res)) {
loc = arg->getDebugLoc();
}
auto S = simplifyLoad(res);
if (!S)
S = res;
EmitFailure("IllegalArgCast", loc, CI,
"Cannot cast __enzyme_autodiff primal argument ", i,
", found ", *res, ", type ", *res->getType(),
" - to arg ", truei, " ", *PTy);
" (simplified to ", *S, " ) ", " - to arg ", truei, ", ",
*PTy);
return {};
}
}
Expand Down
Loading
Loading