Skip to content

Commit

Permalink
refactor(expression): clean unnecessary code and fix the comparision …
Browse files Browse the repository at this point in the history
…between different extents types
  • Loading branch information
amitsingh19975 committed Feb 15, 2022
1 parent 2abe215 commit 51674d5
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 164 deletions.
6 changes: 3 additions & 3 deletions include/boost/numeric/ublas/tensor/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ protected :
constexpr tensor_expression(tensor_expression&&) noexcept = default;
explicit tensor_expression() = default;

/// @brief This the only way to access the protected move constructor of other expressions.
/// @brief This is the only way to access the protected move constructor of other expressions.
template<class, class> friend struct tensor_expression;
};

Expand Down Expand Up @@ -198,7 +198,7 @@ struct binary_tensor_expression
*/
constexpr binary_tensor_expression(binary_tensor_expression&& l) noexcept = default;

/// @brief This the only way to access the protected move constructor of other expressions.
/// @brief This is the only way to access the protected move constructor of other expressions.
template<class, class, class> friend struct unary_tensor_expression;
template<class, class, class, class> friend struct binary_tensor_expression;

Expand Down Expand Up @@ -271,7 +271,7 @@ struct unary_tensor_expression
*/
constexpr unary_tensor_expression(unary_tensor_expression&& l) noexcept = default;

/// @brief This the only way to access the protected move constructor of other expressions.
/// @brief This is the only way to access the protected move constructor of other expressions.
template<class, class, class> friend struct unary_tensor_expression;
template<class, class, class, class> friend struct binary_tensor_expression;

Expand Down
165 changes: 68 additions & 97 deletions include/boost/numeric/ublas/tensor/expression_evaluation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,52 +45,61 @@ struct unary_tensor_expression;

namespace boost::numeric::ublas::detail {

template<class T, class E>
template<typename T>
struct is_tensor_type
: std::false_type
{};

template<typename E>
struct is_tensor_type< tensor_core<E> >
: std::true_type
{};

template<class T>
static constexpr bool is_tensor_type_v = is_tensor_type< std::decay_t<T> >::value;

template<typename T>
struct has_tensor_types
: std::integral_constant< bool, same_exp<T,E> >
: is_tensor_type<T>
{};

template<class T, class E>
static constexpr bool has_tensor_types_v = has_tensor_types< std::decay_t<T>, std::decay_t<E> >::value;
template<class T>
static constexpr bool has_tensor_types_v = has_tensor_types< std::decay_t<T> >::value;

template<class T, class D>
struct has_tensor_types<T, tensor_expression<T,D>>
{
static constexpr bool value =
same_exp<T,D> ||
has_tensor_types<T, std::decay_t<D> >::value;
};
struct has_tensor_types< tensor_expression<T,D> >
: has_tensor_types< std::decay_t<D> >
{};

template<class T, class EL, class ER, class OP>
struct has_tensor_types<T, binary_tensor_expression<T,EL,ER,OP>>
{
static constexpr bool value =
same_exp<T,EL> ||
same_exp<T,ER> ||
has_tensor_types<T, std::decay_t<EL> >::value ||
has_tensor_types<T, std::decay_t<ER> >::value;
};
struct has_tensor_types< binary_tensor_expression<T,EL,ER,OP> >
: std::integral_constant< bool, has_tensor_types_v<EL> || has_tensor_types_v<ER> >
{};

template<class T, class E, class OP>
struct has_tensor_types<T, unary_tensor_expression<T,E,OP>>
{
static constexpr bool value =
same_exp<T,E> ||
has_tensor_types<T, std::decay_t<E> >::value;
};
struct has_tensor_types< unary_tensor_expression<T,E,OP> >
: has_tensor_types< std::decay_t<E> >
{};

} // namespace boost::numeric::ublas::detail


namespace boost::numeric::ublas::detail
{


// TODO: remove this place holder for the old ublas expression after we remove the
// support for them.
template<class E>
[[nodiscard]]
constexpr auto& retrieve_extents([[maybe_unused]] ublas_expression<E> const& /*unused*/) noexcept;

/** @brief Retrieves extents of the tensor_core
*
*/
template<class TensorEngine>
[[nodiscard]]
constexpr auto& retrieve_extents(tensor_core<TensorEngine> const& t)
constexpr auto& retrieve_extents(tensor_core<TensorEngine> const& t) noexcept
{
return t.extents();
}
Expand All @@ -103,17 +112,14 @@ constexpr auto& retrieve_extents(tensor_core<TensorEngine> const& t)
*/
template<class T, class D>
[[nodiscard]]
constexpr auto& retrieve_extents(tensor_expression<T,D> const& expr)
constexpr auto& retrieve_extents(tensor_expression<T,D> const& expr) noexcept
{
static_assert(has_tensor_types_v<T,tensor_expression<T,D>>,
static_assert(has_tensor_types_v<tensor_expression<T,D>>,
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");

auto const& cast_expr = expr();

if constexpr ( same_exp<T,D> )
return cast_expr.extents();
else
return retrieve_extents(cast_expr);

return retrieve_extents(cast_expr);
}

// Disable warning for unreachable code for MSVC compiler
Expand All @@ -129,24 +135,24 @@ constexpr auto& retrieve_extents(tensor_expression<T,D> const& expr)
*/
template<class T, class EL, class ER, class OP>
[[nodiscard]]
constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr)
constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr) noexcept
{
static_assert(has_tensor_types_v<T,binary_tensor_expression<T,EL,ER,OP>>,
static_assert(has_tensor_types_v<binary_tensor_expression<T,EL,ER,OP>>,
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");

auto const& lexpr = expr.left_expr();
auto const& rexpr = expr.right_expr();

if constexpr ( same_exp<T,EL> )
if constexpr ( is_tensor_type_v<EL> )
return lexpr.extents();

else if constexpr ( same_exp<T,ER> )
else if constexpr ( is_tensor_type_v<ER> )
return rexpr.extents();

else if constexpr ( has_tensor_types_v<T,EL> )
else if constexpr ( has_tensor_types_v<EL>)
return retrieve_extents(lexpr);

else if constexpr ( has_tensor_types_v<T,ER> )
else if constexpr ( has_tensor_types_v<ER>)
return retrieve_extents(rexpr);
}

Expand All @@ -162,19 +168,15 @@ constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& exp
*/
template<class T, class E, class OP>
[[nodiscard]]
constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr) noexcept
{

static_assert(has_tensor_types_v<T,unary_tensor_expression<T,E,OP>>,
static_assert(has_tensor_types_v<unary_tensor_expression<T,E,OP>>,
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");

auto const& uexpr = expr.expr();

if constexpr ( same_exp<T,E> )
return uexpr.extents();

else if constexpr ( has_tensor_types_v<T,E> )
return retrieve_extents(uexpr);
return retrieve_extents(uexpr);
}

} // namespace boost::numeric::ublas::detail
Expand All @@ -184,91 +186,60 @@ constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)

namespace boost::numeric::ublas::detail {

// TODO: remove this place holder for the old ublas expression after we remove the
// support for them.
template<class E, std::size_t ... es>
[[nodiscard]] inline
constexpr auto all_extents_equal([[maybe_unused]] ublas_expression<E> const& /*unused*/, [[maybe_unused]] extents<es...> const& /*unused*/) noexcept
{
return true;
}

template<class EN, std::size_t ... es>
[[nodiscard]] inline
constexpr auto all_extents_equal(tensor_core<EN> const& t, extents<es...> const& e)
constexpr auto all_extents_equal(tensor_core<EN> const& t, extents<es...> const& e) noexcept
{
return ::operator==(e,t.extents());
}

template<class T, class D, std::size_t ... es>
[[nodiscard]]
constexpr auto all_extents_equal(tensor_expression<T,D> const& expr, extents<es...> const& e)
constexpr auto all_extents_equal(tensor_expression<T,D> const& expr, extents<es...> const& e) noexcept
{

static_assert(has_tensor_types_v<T,tensor_expression<T,D>>,
static_assert(has_tensor_types_v<tensor_expression<T,D>>,
"Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors.");

auto const& cast_expr = expr();

using ::operator==;
using ::operator!=;

if constexpr ( same_exp<T,D> )
if( e != cast_expr.extents() )
return false;

if constexpr ( has_tensor_types_v<T,D> )
if ( !all_extents_equal(cast_expr, e))
return false;

return true;

return all_extents_equal(cast_expr, e);
}

template<class T, class EL, class ER, class OP, std::size_t... es>
[[nodiscard]]
constexpr auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, extents<es...> const& e)
constexpr auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, extents<es...> const& e) noexcept
{
static_assert(has_tensor_types_v<T,binary_tensor_expression<T,EL,ER,OP>>,
static_assert(has_tensor_types_v<binary_tensor_expression<T,EL,ER,OP>>,
"Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors.");

using ::operator==;
using ::operator!=;

auto const& lexpr = expr.left_expr();
auto const& rexpr = expr.right_expr();

if constexpr ( same_exp<T,EL> )
if(e != lexpr.extents())
return false;

if constexpr ( same_exp<T,ER> )
if(e != rexpr.extents())
return false;

if constexpr ( has_tensor_types_v<T,EL> )
if(!all_extents_equal(lexpr, e))
return false;

if constexpr ( has_tensor_types_v<T,ER> )
if(!all_extents_equal(rexpr, e))
return false;

return true;
return all_extents_equal(lexpr, e) &&
all_extents_equal(rexpr, e) ;
}


template<class T, class E, class OP, std::size_t... es>
[[nodiscard]]
constexpr auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, extents<es...> const& e)
constexpr auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, extents<es...> const& e) noexcept
{
static_assert(has_tensor_types_v<T,unary_tensor_expression<T,E,OP>>,
static_assert(has_tensor_types_v<unary_tensor_expression<T,E,OP>>,
"Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors.");

using ::operator==;

auto const& uexpr = expr.expr();

if constexpr ( same_exp<T,E> )
if(e != uexpr.extents())
return false;

if constexpr ( has_tensor_types_v<T,E> )
if(!all_extents_equal(uexpr, e))
return false;

return true;
return all_extents_equal(uexpr, e);
}

} // namespace boost::numeric::ublas::detail
Expand Down
Loading

0 comments on commit 51674d5

Please sign in to comment.