Skip to content

Commit

Permalink
cpp interface improvements (#1789)
Browse files Browse the repository at this point in the history
* suppress braces warning, add some new tests

* size_t -> std::size_t

* reorganizing

* Strengthen and fix c++ sugar

---------

Co-authored-by: William S. Moses <[email protected]>
  • Loading branch information
samuelpmish and wsmoses authored Apr 2, 2024
1 parent 5b2f04c commit 57d07ed
Show file tree
Hide file tree
Showing 10 changed files with 335 additions and 49 deletions.
139 changes: 117 additions & 22 deletions enzyme/Enzyme/Clang/include_utils.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ Return __enzyme_fwddiff(T...);

namespace enzyme {

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wmissing-braces"

struct nodiff{};

template<bool ReturnPrimal = false>
Expand All @@ -45,31 +48,58 @@ namespace enzyme {
template < typename T >
struct Active{
T value;
Active(T &&v) : value(v) {}
Active(const T& v) : value(v) {};
Active(T&& v) : value(v) {};
Active(const Active<T>&) = default;
Active(Active<T>&&) = default;
operator T&() { return value; }
};

template < typename T >
struct Duplicated{
T value;
T shadow;
Duplicated(T &&v, T&& s) : value(v), shadow(s) {}
Duplicated(const T& v, const T& s) : value(v), shadow(s) {};
Duplicated(T&& v, T&& s) : value(v), shadow(s) {};
Duplicated(const Duplicated<T>&) = default;
Duplicated(Duplicated<T>&&) = default;
};

template < typename T >
struct DuplicatedNoNeed{
T value;
T shadow;
DuplicatedNoNeed(T &&v, T&& s) : value(v), shadow(s) {}
DuplicatedNoNeed(const T& v, const T& s) : value(v), shadow(s) {};
DuplicatedNoNeed(T&& v, T&& s) : value(v), shadow(s) {};
DuplicatedNoNeed(const DuplicatedNoNeed<T>&) = default;
DuplicatedNoNeed(DuplicatedNoNeed<T>&&) = default;
};

template < typename T >
struct Const{
T value;
Const(T &&v) : value(v) {}
Const(const T& v) : value(v) {};
Const(T&& v) : value(v) {};
Const(const Const<T>&) = default;
Const(Const<T>&&) = default;
operator T&() { return value; }
};

// CTAD available in C++17 or later
#if __cplusplus >= 201703L
template < typename T >
Active(T) -> Active<T>;

template < typename T >
Const(T) -> Const<T>;

template < typename T >
Duplicated(T,T) -> Duplicated<T>;

template < typename T >
DuplicatedNoNeed(T,T) -> DuplicatedNoNeed<T>;
#endif

template < typename T >
struct type_info {
static constexpr bool is_active = false;
Expand Down Expand Up @@ -189,7 +219,52 @@ namespace enzyme {
return enzyme::tuple<T>{arg.value};
}

template < typename T >
__attribute__((always_inline))
auto primal_args_nt(const enzyme::Duplicated<T> & arg) {
return arg.value;
}

template < typename T >
__attribute__((always_inline))
auto primal_args_nt(const enzyme::DuplicatedNoNeed<T> & arg) {
return arg.value;
}

template < typename T >
__attribute__((always_inline))
auto primal_args_nt(const enzyme::Active<T> & arg) {
return arg.value;
}

template < typename T >
__attribute__((always_inline))
auto primal_args_nt(const enzyme::Const<T> & arg) {
return arg.value;
}

namespace detail {
template < typename RetType, typename ... T >
struct function_type
{
using type = RetType(T...);
};

template<typename function, typename prevfunc>
struct templated_call {

};

template<typename function, typename RT, typename ...T>
struct templated_call<function, RT(T...)> {
static RT wrap(T... args, function* __restrict__ f) {
return (*f)(args...);
}
};




template<typename T>
__attribute__((always_inline))
constexpr decltype(auto) push_return_last(T &&t);
Expand All @@ -211,19 +286,19 @@ namespace enzyme {

template <bool Mode>
struct autodiff_apply<ReverseMode<Mode>> {
template <class return_type, class Tuple, std::size_t... I>
template <class return_type, class Tuple, std::size_t... I, typename... ExtraArgs>
__attribute__((always_inline))
static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>) {
return push_return_last(__enzyme_autodiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))...));
static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>, ExtraArgs... args) {
return push_return_last(__enzyme_autodiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))..., args...));
}
};

template <>
struct autodiff_apply<ForwardMode> {
template <class return_type, class Tuple, std::size_t... I>
template <class return_type, class Tuple, std::size_t... I, typename... ExtraArgs>
__attribute__((always_inline))
static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>) {
return __enzyme_fwddiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))...);
static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>, ExtraArgs... args) {
return __enzyme_fwddiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))..., args...);
}
};

Expand Down Expand Up @@ -313,37 +388,51 @@ namespace enzyme {
__attribute__((always_inline))
auto primal_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
using Tuple = enzyme::tuple< enz_arg_types ... >;
return detail::primal_apply_impl<return_type>(f, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
return detail::primal_apply_impl<return_type>(std::move(f), impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
}

template < typename function, typename ... arg_types>
auto primal_call(function && f, arg_types && ... args) {
return primal_impl<function>(impl::forward<function>(f), enzyme::tuple_cat(primal_args(args)...));
}

template < typename return_type, typename DiffMode, typename function, typename RetActivity, typename ... enz_arg_types >
template < typename return_type, typename DiffMode, typename function, typename functy, typename RetActivity, typename ... enz_arg_types, std::enable_if_t<!std::is_function_v<typename remove_cvref< function >::type>, int> = 0>
__attribute__((always_inline))
auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
using Tuple = enzyme::tuple< enz_arg_types ... >;
return detail::autodiff_apply<DiffMode>::template impl<return_type>((void*)f, detail::ret_global<RetActivity>::value, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
return detail::autodiff_apply<DiffMode>::template impl<return_type>((void*)&detail::templated_call<function, functy>::wrap, detail::ret_global<RetActivity>::value,
impl::forward<Tuple>(arg_tup),
std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{}, &enzyme_const, &f);
}

template < typename return_type, typename DiffMode, typename function, typename functy, typename RetActivity, typename ... enz_arg_types, std::enable_if_t<std::is_function_v<typename remove_cvref< function >::type>, int> = 0>
__attribute__((always_inline))
auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
using Tuple = enzyme::tuple< enz_arg_types ... >;
return detail::autodiff_apply<DiffMode>::template impl<return_type>((void*)static_cast<functy*>(f), detail::ret_global<RetActivity>::value, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
}

template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types>
__attribute__((always_inline))
auto autodiff(function && f, arg_types && ... args) {
using primal_return_type = decltype(f(primal_args_nt(args)...));
using functy = typename detail::function_type<primal_return_type, decltype(primal_args_nt(args))...>::type;
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return autodiff_impl<return_type, DiffMode, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
}

template < typename DiffMode, typename function, typename ... arg_types>
__attribute__((always_inline))
auto autodiff(function && f, arg_types && ... args) {
using primal_return_type = decltype(primal_call<function, arg_types...>(impl::forward<function>(f), impl::forward<arg_types>(args)...));
using primal_return_type = decltype(f(primal_args_nt(args)...));
using functy = typename detail::function_type<primal_return_type, decltype(primal_args_nt(args))...>::type;
using RetActivity = typename detail::default_ret_activity<DiffMode, primal_return_type>::type;
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return autodiff_impl<return_type, DiffMode, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
}
}
#pragma clang diagnostic pop

} // namespace enzyme
}]>;

def : Headers<"/enzymeroot/enzyme/type_traits", [{
Expand Down Expand Up @@ -410,13 +499,17 @@ def : Headers<"/enzymeroot/enzyme/tuple", [{
// constexpr support for std::tuple). Owning the implementation lets
// us add __host__ __device__ annotations to any part of it

#include <cstddef> // for std::size_t
#include <utility> // for std::integer_sequence

#include <enzyme/type_traits>

#define _NOEXCEPT noexcept
namespace enzyme {

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wmissing-braces"

template <int i>
struct Index {};

Expand Down Expand Up @@ -468,10 +561,10 @@ template <typename Tuple>
struct tuple_size;

template <typename... T>
struct tuple_size<tuple<T...>> : std::integral_constant<size_t, sizeof...(T)> {};
struct tuple_size<tuple<T...>> : std::integral_constant<std::size_t, sizeof...(T)> {};

template <typename Tuple>
static constexpr size_t tuple_size_v = tuple_size<Tuple>::value;
static constexpr std::size_t tuple_size_v = tuple_size<Tuple>::value;

template <typename... T>
__attribute__((always_inline))
Expand All @@ -484,7 +577,7 @@ namespace impl {
template <typename index_seq>
struct make_tuple_from_fwd_tuple;

template <size_t... indices>
template <std::size_t... indices>
struct make_tuple_from_fwd_tuple<std::index_sequence<indices...>> {
template <typename FWD_TUPLE>
__attribute__((always_inline))
Expand All @@ -499,12 +592,12 @@ struct concat_with_fwd_tuple;
template < typename Tuple >
using iseq = std::make_index_sequence<tuple_size_v< enzyme::remove_cvref_t< Tuple > > >;

template <size_t... fwd_indices, size_t... indices>
template <std::size_t... fwd_indices, std::size_t... indices>
struct concat_with_fwd_tuple<std::index_sequence<fwd_indices...>, std::index_sequence<indices...>> {
template <typename FWD_TUPLE, typename TUPLE>
__attribute__((always_inline))
static constexpr auto f(FWD_TUPLE&& fwd, TUPLE&& t) {
return forward_as_tuple(get<fwd_indices>(impl::forward<FWD_TUPLE>(fwd))..., get<indices>(impl::forward<TUPLE>(t))...);
return enzyme::forward_as_tuple(get<fwd_indices>(impl::forward<FWD_TUPLE>(fwd))..., get<indices>(impl::forward<TUPLE>(t))...);
}
};

Expand All @@ -528,6 +621,8 @@ constexpr auto tuple_cat(Tuples&&... tuples) {
return impl::tuple_cat(impl::forward<Tuples>(tuples)...);
}

#pragma clang diagnostic pop

} // namespace enzyme
#undef _NOEXCEPT
}]>;
Expand Down
17 changes: 13 additions & 4 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,10 @@ static bool ReplaceOriginalCall(IRBuilder<> &Builder, Value *ret,
}
}

if (mode == DerivativeMode::ReverseModePrimal &&
DL.getTypeSizeInBits(retType) >= DL.getTypeSizeInBits(diffretType)) {
if ((mode == DerivativeMode::ReverseModePrimal &&
DL.getTypeSizeInBits(retType) >= DL.getTypeSizeInBits(diffretType)) ||
(mode == DerivativeMode::ForwardMode &&
DL.getTypeSizeInBits(retType) == DL.getTypeSizeInBits(diffretType))) {
IRBuilder<> EB(CI->getFunction()->getEntryBlock().getFirstNonPHI());
auto AL = EB.CreateAlloca(retType);
Builder.CreateStore(diffret, Builder.CreatePointerCast(
Expand All @@ -592,9 +594,12 @@ static bool ReplaceOriginalCall(IRBuilder<> &Builder, Value *ret,
}
}

auto diffretsize = DL.getTypeSizeInBits(diffretType);
auto retsize = DL.getTypeSizeInBits(retType);
EmitFailure("IllegalReturnCast", CI->getDebugLoc(), CI,
"Cannot cast return type of gradient ", *diffretType, *diffret,
", to desired type ", *retType);
" of size ", diffretsize, " bits ", ", to desired type ",
*retType, " of size ", retsize, " bits");
return false;
}

Expand Down Expand Up @@ -1190,10 +1195,14 @@ class EnzymeBase {
if (auto arg = dyn_cast<Instruction>(res)) {
loc = arg->getDebugLoc();
}
auto S = simplifyLoad(res);
if (!S)
S = res;
EmitFailure("IllegalArgCast", loc, CI,
"Cannot cast __enzyme_autodiff primal argument ", i,
", found ", *res, ", type ", *res->getType(),
" - to arg ", truei, " ", *PTy);
" (simplified to ", *S, " ) ", " - to arg ", truei, ", ",
*PTy);
return {};
}
}
Expand Down
Loading

0 comments on commit 57d07ed

Please sign in to comment.