Skip to content

Commit

Permalink
refactor(expression): clean unnecessary code and fix the comparision …
Browse files Browse the repository at this point in the history
…between different extents types
  • Loading branch information
amitsingh19975 committed Feb 15, 2022
1 parent 2abe215 commit 0ae5de2
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 155 deletions.
162 changes: 71 additions & 91 deletions include/boost/numeric/ublas/tensor/expression_evaluation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,52 +45,61 @@ struct unary_tensor_expression;

namespace boost::numeric::ublas::detail {

template<class T, class E>
template<typename T>
struct is_tensor_type
: std::false_type
{};

template<typename E>
struct is_tensor_type< tensor_core<E> >
: std::true_type
{};

template<class T>
static constexpr bool is_tensor_type_v = is_tensor_type< std::decay_t<T> >::value;

template<typename T>
struct has_tensor_types
: std::integral_constant< bool, same_exp<T,E> >
: is_tensor_type<T>
{};

template<class T, class E>
static constexpr bool has_tensor_types_v = has_tensor_types< std::decay_t<T>, std::decay_t<E> >::value;
template<class T>
static constexpr bool has_tensor_types_v = has_tensor_types< std::decay_t<T> >::value;

template<class T, class D>
struct has_tensor_types<T, tensor_expression<T,D>>
{
static constexpr bool value =
same_exp<T,D> ||
has_tensor_types<T, std::decay_t<D> >::value;
};
struct has_tensor_types< tensor_expression<T,D> >
: has_tensor_types< std::decay_t<D> >
{};

template<class T, class EL, class ER, class OP>
struct has_tensor_types<T, binary_tensor_expression<T,EL,ER,OP>>
{
static constexpr bool value =
same_exp<T,EL> ||
same_exp<T,ER> ||
has_tensor_types<T, std::decay_t<EL> >::value ||
has_tensor_types<T, std::decay_t<ER> >::value;
};
struct has_tensor_types< binary_tensor_expression<T,EL,ER,OP> >
: std::integral_constant< bool, has_tensor_types_v<EL> || has_tensor_types_v<ER> >
{};

template<class T, class E, class OP>
struct has_tensor_types<T, unary_tensor_expression<T,E,OP>>
{
static constexpr bool value =
same_exp<T,E> ||
has_tensor_types<T, std::decay_t<E> >::value;
};
struct has_tensor_types< unary_tensor_expression<T,E,OP> >
: has_tensor_types< std::decay_t<E> >
{};

} // namespace boost::numeric::ublas::detail


namespace boost::numeric::ublas::detail
{


// TODO: remove this place holder for the old ublas expression after we remove the
// support for them.
template<class E>
[[nodiscard]]
constexpr auto& retrieve_extents(ublas_expression<E> const&) noexcept;

/** @brief Retrieves extents of the tensor_core
*
*/
template<class TensorEngine>
[[nodiscard]]
constexpr auto& retrieve_extents(tensor_core<TensorEngine> const& t)
constexpr auto& retrieve_extents(tensor_core<TensorEngine> const& t) noexcept
{
return t.extents();
}
Expand All @@ -103,17 +112,14 @@ constexpr auto& retrieve_extents(tensor_core<TensorEngine> const& t)
*/
template<class T, class D>
[[nodiscard]]
constexpr auto& retrieve_extents(tensor_expression<T,D> const& expr)
constexpr auto& retrieve_extents(tensor_expression<T,D> const& expr) noexcept
{
static_assert(has_tensor_types_v<T,tensor_expression<T,D>>,
static_assert(has_tensor_types_v<tensor_expression<T,D>>,
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");

auto const& cast_expr = expr();

if constexpr ( same_exp<T,D> )
return cast_expr.extents();
else
return retrieve_extents(cast_expr);

return retrieve_extents(cast_expr);
}

// Disable warning for unreachable code for MSVC compiler
Expand All @@ -129,24 +135,24 @@ constexpr auto& retrieve_extents(tensor_expression<T,D> const& expr)
*/
template<class T, class EL, class ER, class OP>
[[nodiscard]]
constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr)
constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr) noexcept
{
static_assert(has_tensor_types_v<T,binary_tensor_expression<T,EL,ER,OP>>,
static_assert(has_tensor_types_v<binary_tensor_expression<T,EL,ER,OP>>,
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");

auto const& lexpr = expr.left_expr();
auto const& rexpr = expr.right_expr();

if constexpr ( same_exp<T,EL> )
return lexpr.extents();
if constexpr ( is_tensor_type_v<EL> )
return retrieve_extents(lexpr);

else if constexpr ( same_exp<T,ER> )
return rexpr.extents();
else if constexpr ( is_tensor_type_v<ER> )
return retrieve_extents(rexpr);

else if constexpr ( has_tensor_types_v<T,EL> )
else if constexpr ( has_tensor_types_v<EL> )
return retrieve_extents(lexpr);

else if constexpr ( has_tensor_types_v<T,ER> )
else if constexpr ( has_tensor_types_v<ER> )
return retrieve_extents(rexpr);
}

Expand All @@ -162,19 +168,15 @@ constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& exp
*/
template<class T, class E, class OP>
[[nodiscard]]
constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr) noexcept
{

static_assert(has_tensor_types_v<T,unary_tensor_expression<T,E,OP>>,
static_assert(has_tensor_types_v<unary_tensor_expression<T,E,OP>>,
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");

auto const& uexpr = expr.expr();

if constexpr ( same_exp<T,E> )
return uexpr.extents();

else if constexpr ( has_tensor_types_v<T,E> )
return retrieve_extents(uexpr);
return retrieve_extents(uexpr);
}

} // namespace boost::numeric::ublas::detail
Expand All @@ -184,89 +186,67 @@ constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)

namespace boost::numeric::ublas::detail {

// TODO: remove this place holder for the old ublas expression after we remove the
// support for them.
template<class E, std::size_t ... es>
[[nodiscard]] inline
constexpr auto all_extents_equal(ublas_expression<E> const&, extents<es...> const&) noexcept
{
return true;
}

template<class EN, std::size_t ... es>
[[nodiscard]] inline
constexpr auto all_extents_equal(tensor_core<EN> const& t, extents<es...> const& e)
constexpr auto all_extents_equal(tensor_core<EN> const& t, extents<es...> const& e) noexcept
{
return ::operator==(e,t.extents());
}

template<class T, class D, std::size_t ... es>
[[nodiscard]]
constexpr auto all_extents_equal(tensor_expression<T,D> const& expr, extents<es...> const& e)
constexpr auto all_extents_equal(tensor_expression<T,D> const& expr, extents<es...> const& e) noexcept
{

static_assert(has_tensor_types_v<T,tensor_expression<T,D>>,
static_assert(has_tensor_types_v<tensor_expression<T,D>>,
"Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors.");

auto const& cast_expr = expr();

using ::operator==;
using ::operator!=;

if constexpr ( same_exp<T,D> )
if( e != cast_expr.extents() )
return false;

if constexpr ( has_tensor_types_v<T,D> )
if ( !all_extents_equal(cast_expr, e))
return false;
if ( !all_extents_equal(cast_expr, e) )
return false;

return true;

}

template<class T, class EL, class ER, class OP, std::size_t... es>
[[nodiscard]]
constexpr auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, extents<es...> const& e)
constexpr auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, extents<es...> const& e) noexcept
{
static_assert(has_tensor_types_v<T,binary_tensor_expression<T,EL,ER,OP>>,
static_assert(has_tensor_types_v<binary_tensor_expression<T,EL,ER,OP>>,
"Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors.");

using ::operator==;
using ::operator!=;

auto const& lexpr = expr.left_expr();
auto const& rexpr = expr.right_expr();

if constexpr ( same_exp<T,EL> )
if(e != lexpr.extents())
return false;

if constexpr ( same_exp<T,ER> )
if(e != rexpr.extents())
return false;

if constexpr ( has_tensor_types_v<T,EL> )
if(!all_extents_equal(lexpr, e))
return false;

if constexpr ( has_tensor_types_v<T,ER> )
if(!all_extents_equal(rexpr, e))
return false;
if( !all_extents_equal(lexpr, e) || !all_extents_equal(rexpr, e) )
return false;

return true;
}


template<class T, class E, class OP, std::size_t... es>
[[nodiscard]]
constexpr auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, extents<es...> const& e)
constexpr auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, extents<es...> const& e) noexcept
{
static_assert(has_tensor_types_v<T,unary_tensor_expression<T,E,OP>>,
static_assert(has_tensor_types_v<unary_tensor_expression<T,E,OP>>,
"Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors.");

using ::operator==;

auto const& uexpr = expr.expr();

if constexpr ( same_exp<T,E> )
if(e != uexpr.extents())
return false;

if constexpr ( has_tensor_types_v<T,E> )
if(!all_extents_equal(uexpr, e))
return false;
if( !all_extents_equal(uexpr, e) )
return false;

return true;
}
Expand Down
Loading

0 comments on commit 0ae5de2

Please sign in to comment.