Skip to content

Commit

Permalink
Make mapped functions work with nda::Array as well as nda::Scalar obj…
Browse files Browse the repository at this point in the history
…ects
  • Loading branch information
Thoemi09 authored and Wentzell committed Nov 4, 2024
1 parent 4269109 commit 3e5f225
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 198 deletions.
2 changes: 1 addition & 1 deletion c++/nda/blas/tools.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ namespace nda::blas {

/// Specialization of nda::blas::is_conj_array_expr for the conjugate lazy expressions.
template <MemoryArray A>
static constexpr bool is_conj_array_expr<expr_call<conj_f, A>> = true;
static constexpr bool is_conj_array_expr<expr_call<detail::conj_f, A>> = true;

// Specialization of nda::blas::is_conj_array_expr for cvref types.
template <typename A>
Expand Down
14 changes: 14 additions & 0 deletions c++/nda/map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ namespace nda {
EXPECTS(((as.shape() == a0.shape()) && ...)); // same shape
return {f, {std::forward<A0>(a0), std::forward<As>(as)...}};
}

/**
* @brief Function call operator that returns the result of the callable object applied to the scalar arguments.
*
* @tparam T0 First nda::Scalar argument type.
* @tparam Ts Rest of the nda::Scalar argument types.
* @param t0 First nda::Scalar argument.
* @param ts Rest of the nda::Scalar arguments.
* @return Result of the functor applied to the scalar arguments.
*/
template <Scalar T0, Scalar... Ts>
auto operator()(T0 a0, Ts... as) const {
return f(a0, as...);
}
};

/**
Expand Down
133 changes: 47 additions & 86 deletions c++/nda/mapped_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,116 +38,77 @@ namespace nda {
* @{
*/

/**
* @brief Get the real part of a scalar.
*
* @tparam T Scalar type.
* @param t Scalar value.
* @return Real part of the scalar.
*/
template <typename T>
auto real(T t)
requires(nda::is_scalar_v<T>)
{
if constexpr (is_complex_v<T>) {
return std::real(t);
} else {
return t;
namespace detail {

// Get the real part of a scalar.
template <typename T>
auto real(T t)
requires(nda::is_scalar_v<T>)
{
if constexpr (is_complex_v<T>) {
return std::real(t);
} else {
return t;
}
}
}

/**
* @brief Get the complex conjugate of a scalar.
*
* @tparam T Scalar type.
* @param t Scalar value.
* @return The given scalar if it is not complex, otherwise its complex conjugate.
*/
template <typename T>
auto conj(T t)
requires(nda::is_scalar_v<T>)
{
if constexpr (is_complex_v<T>) {
return std::conj(t);
} else {
return t;
// Get the complex conjugate of a scalar.
template <typename T>
auto conj(T t)
requires(nda::is_scalar_v<T>)
{
if constexpr (is_complex_v<T>) {
return std::conj(t);
} else {
return t;
}
}
}

/**
* @brief Get the squared absolute value of a double.
*
* @param x Double value.
* @return Squared absolute value of the given double.
*/
inline double abs2(double x) { return x * x; }
// Get the squared absolute value of a double.
inline double abs2(double x) { return x * x; }

/**
* @brief Get the squared absolute value of a std::complex<double>.
*
* @param z std::complex<double> value.
* @return Squared absolute value of the given complex number.
*/
inline double abs2(std::complex<double> z) { return (conj(z) * z).real(); }
// Get the squared absolute value of a std::complex<double>.
inline double abs2(std::complex<double> z) { return (conj(z) * z).real(); }

/**
* @brief Check if a std::complex<double> is NaN.
*
* @param z std::complex<double> value.
* @return True if either the real or imaginary part of the given complex number is `NaN`, false otherwise.
*/
inline bool isnan(std::complex<double> const &z) { return std::isnan(z.real()) or std::isnan(z.imag()); }
// Check if a std::complex<double> is NaN.
inline bool isnan(std::complex<double> const &z) { return std::isnan(z.real()) or std::isnan(z.imag()); }

/**
* @brief Calculate the integer power of an integer.
*
* @tparam T Integer type.
* @param x Base value.
* @param n Exponent value.
* @return The result of the base raised to the power of the exponent.
*/
template <typename T>
T pow(T x, int n)
requires(std::is_integral_v<T>)
{
T r = 1;
for (int i = 0; i < n; ++i) r *= x;
return r;
}
// Functor for nda::detail::conj.
struct conj_f {
auto operator()(auto const &x) const { return conj(x); };
};

} // namespace detail

/**
* @brief Lazy, coefficient-wise power function for nda::Array types.
* @brief Function pow for nda::ArrayOrScalar types (lazy and coefficient-wise for nda::Array types).
*
* @tparam A nda::Array type.
* @param a nda::Array object.
* @tparam A nda::ArrayOrScalar type..
* @param a nda::ArrayOrScalar object.
* @param p Exponent value.
* @return A lazy nda::expr_call object.
* @return A lazy nda::expr_call object (nda::Array) or the result of `std::pow` applied to the object (nda::Scalar).
*/
template <Array A>
template <ArrayOrScalar A>
auto pow(A &&a, double p) {
return nda::map([p](auto const &x) {
using std::pow;
return pow(x, p);
})(std::forward<A>(a));
}

/// Wrapper for nda::conj.
struct conj_f {
/// Function call operator that forwards the call to nda::conj.
auto operator()(auto const &x) const { return conj(x); };
};

/**
* @brief Lazy, coefficient-wise complex conjugate function for nda::Array types.
* @brief Function conj for nda::ArrayOrScalar types (lazy and coefficient-wise for nda::Array types with a complex
* value type).
*
* @tparam A nda::Array type.
* @param a nda::Array object.
* @return A lazy nda::expr_call object if the array is complex valued, otherwise the array itself.
* @tparam A nda::ArrayOrScalar type..
* @param a nda::ArrayOrScalar object.
* @return A lazy nda::expr_call object (nda::Array and complex valued), the forwarded input object (nda::Array and
* not complex valued) or the complex conjugate of the scalar input.
*/
template <Array A>
template <ArrayOrScalar A>
decltype(auto) conj(A &&a) {
if constexpr (is_complex_v<get_value_t<A>>)
return nda::map(conj_f{})(std::forward<A>(a));
return nda::map(detail::conj_f{})(std::forward<A>(a));
else
return std::forward<A>(a);
}
Expand Down
Loading

0 comments on commit 3e5f225

Please sign in to comment.