Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reapply #1295 #1337

Merged
merged 4 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 25 additions & 40 deletions libs/pika/async_cuda/include/pika/async_cuda/then_with_stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,19 +249,21 @@ namespace pika::cuda::experimental::then_with_stream_detail {
}

template <typename... Ts>
auto set_value(Ts&&... ts) noexcept
auto set_value(Ts&&... ts) && noexcept
-> decltype(PIKA_INVOKE(PIKA_MOVE(f), op_state.sched, stream.value(), ts...),
void())
{
auto r = PIKA_MOVE(*this);
pika::detail::try_catch_exception_ptr(
[&]() mutable {
using ts_element_type = std::tuple<std::decay_t<Ts>...>;
op_state.ts.template emplace<ts_element_type>(PIKA_FORWARD(Ts, ts)...);
[[maybe_unused]] auto& t = std::get<ts_element_type>(op_state.ts);
r.op_state.ts.template emplace<ts_element_type>(
PIKA_FORWARD(Ts, ts)...);
[[maybe_unused]] auto& t = std::get<ts_element_type>(r.op_state.ts);

if (!op_state.stream)
if (!r.op_state.stream)
{
op_state.stream.emplace(op_state.sched.get_next_stream());
r.op_state.stream.emplace(r.op_state.sched.get_next_stream());
}

// If the next receiver is also a
Expand All @@ -272,11 +274,11 @@ namespace pika::cuda::experimental::then_with_stream_detail {
if constexpr (is_then_with_cuda_stream_receiver<
std::decay_t<Receiver>>::value)
{
if (op_state.sched == op_state.receiver.op_state.sched)
if (r.op_state.sched == r.op_state.receiver.op_state.sched)
{
PIKA_ASSERT(op_state.stream);
PIKA_ASSERT(!op_state.receiver.op_state.stream);
op_state.receiver.op_state.stream = op_state.stream;
PIKA_ASSERT(r.op_state.stream);
PIKA_ASSERT(!r.op_state.receiver.op_state.stream);
r.op_state.receiver.op_state.stream = r.op_state.stream;

successor_uses_same_stream = true;
}
Expand All @@ -290,8 +292,8 @@ namespace pika::cuda::experimental::then_with_stream_detail {
{
std::apply(
[&](auto&... ts) mutable {
PIKA_INVOKE(PIKA_MOVE(op_state.f), op_state.sched,
op_state.stream.value(), ts...);
PIKA_INVOKE(PIKA_MOVE(r.op_state.f), r.op_state.sched,
r.op_state.stream.value(), ts...);
},
t);

Expand All @@ -307,14 +309,14 @@ namespace pika::cuda::experimental::then_with_stream_detail {
// stream when a
// non-then_with_cuda_stream receiver is
// connected.
set_value_immediate_void(op_state);
set_value_immediate_void(r.op_state);
}
else
{
// When the streams are different, we
// add a callback which will call
// set_value on the receiver.
set_value_event_callback_void(op_state);
set_value_event_callback_void(r.op_state);
}
}
else
Expand All @@ -323,16 +325,16 @@ namespace pika::cuda::experimental::then_with_stream_detail {
// then_with_cuda_stream_receiver, we add a
// callback which will call set_value on the
// receiver.
set_value_event_callback_void(op_state);
set_value_event_callback_void(r.op_state);
}
}
else
{
std::apply(
[&](auto&... ts) mutable {
op_state.result.template emplace<invoke_result_type>(
PIKA_INVOKE(PIKA_MOVE(op_state.f), op_state.sched,
op_state.stream.value(), ts...));
r.op_state.result.template emplace<invoke_result_type>(
PIKA_INVOKE(PIKA_MOVE(r.op_state.f), r.op_state.sched,
r.op_state.stream.value(), ts...));
},
t);

Expand All @@ -348,15 +350,16 @@ namespace pika::cuda::experimental::then_with_stream_detail {
// stream when a
// non-then_with_cuda_stream receiver is
// connected.
set_value_immediate_non_void<invoke_result_type>(op_state);
set_value_immediate_non_void<invoke_result_type>(
r.op_state);
}
else
{
// When the streams are different, we
// add a callback which will call
// set_value on the receiver.
set_value_event_callback_non_void<invoke_result_type>(
op_state);
r.op_state);
}
}
else
Expand All @@ -365,13 +368,14 @@ namespace pika::cuda::experimental::then_with_stream_detail {
// then_with_cuda_stream_receiver, we add a
// callback which will call set_value on the
// receiver.
set_value_event_callback_non_void<invoke_result_type>(op_state);
set_value_event_callback_non_void<invoke_result_type>(
r.op_state);
}
}
},
[&](std::exception_ptr ep) mutable {
pika::execution::experimental::set_error(
PIKA_MOVE(op_state.receiver), PIKA_MOVE(ep));
PIKA_MOVE(r.op_state.receiver), PIKA_MOVE(ep));
});
}

Expand All @@ -383,25 +387,6 @@ namespace pika::cuda::experimental::then_with_stream_detail {
}
};

// This should be a hidden friend in then_with_cuda_stream_receiver.
// However, nvcc does not know how to compile it with some argument
// types ("error: no instance of overloaded function std::forward
// matches the argument list").
template <typename... Ts>
friend auto tag_invoke(pika::execution::experimental::set_value_t,
then_with_cuda_stream_receiver&& r, Ts&&... ts) noexcept
-> decltype(r.set_value(PIKA_FORWARD(Ts, ts)...))
{
// nvcc fails to compile this with std::forward<Ts>(ts)... or
// static_cast<Ts&&>(ts)... so we explicitly use
// static_cast<decltype(ts)>(ts)... as a workaround.
#if defined(PIKA_HAVE_CUDA)
r.set_value(static_cast<decltype(ts)&&>(ts)...);
#else
r.set_value(PIKA_FORWARD(Ts, ts)...);
#endif
}

using operation_state_type =
pika::execution::experimental::connect_result_t<std::decay_t<Sender>,
then_with_cuda_stream_receiver>;
Expand Down
4 changes: 2 additions & 2 deletions libs/pika/async_mpi/include/pika/async_mpi/dispatch_mpi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ namespace pika::mpi::experimental::detail {
// otherwise return the request by passing it to set_value
template <typename... Ts,
typename = std::enable_if_t<is_mpi_request_invocable_v<F, Ts...>>>
friend constexpr void
tag_invoke(ex::set_value_t, dispatch_mpi_receiver r, Ts&&... ts) noexcept
constexpr void set_value(Ts&&... ts) && noexcept
{
auto r = PIKA_MOVE(*this);
pika::detail::try_catch_exception_ptr(
[&]() mutable {
using invoke_result_type = mpi_request_invoke_result_t<F, Ts...>;
Expand Down
5 changes: 3 additions & 2 deletions libs/pika/async_mpi/include/pika/async_mpi/trigger_mpi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ namespace pika::mpi::experimental::detail {

// Receive the MPI request and set a callback to be
// triggered when the MPI request completes.
friend constexpr void tag_invoke(
ex::set_value_t, trigger_mpi_receiver r, MPI_Request request) noexcept
constexpr void set_value(MPI_Request request) && noexcept
{
auto r = PIKA_MOVE(*this);

// early exit check
if (request == MPI_REQUEST_NULL)
{
Expand Down
23 changes: 5 additions & 18 deletions libs/pika/execution/include/pika/execution/algorithms/bulk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,33 +95,20 @@ namespace pika::bulk_detail {
}

template <typename... Ts>
void set_value(Ts&&... ts)
void set_value(Ts&&... ts) && noexcept
{
auto r = PIKA_MOVE(*this);
pika::detail::try_catch_exception_ptr(
[&]() {
for (auto const& s : shape) { PIKA_INVOKE(f, s, ts...); }
for (auto const& s : r.shape) { PIKA_INVOKE(r.f, s, ts...); }
pika::execution::experimental::set_value(
PIKA_MOVE(receiver), PIKA_FORWARD(Ts, ts)...);
PIKA_MOVE(r.receiver), PIKA_FORWARD(Ts, ts)...);
},
[&](std::exception_ptr ep) {
pika::execution::experimental::set_error(
PIKA_MOVE(receiver), PIKA_MOVE(ep));
PIKA_MOVE(r.receiver), PIKA_MOVE(ep));
});
}

template <typename... Ts>
friend auto tag_invoke(
pika::execution::experimental::set_value_t, bulk_receiver&& r, Ts&&... ts) noexcept
-> decltype(pika::execution::experimental::set_value(
std::declval<std::decay_t<Receiver>&&>(), PIKA_FORWARD(Ts, ts)...),
void())
{
// set_value is in a member function only because of a
// compiler bug in GCC 7. When the body of set_value is
// inlined here compilation fails with an internal compiler
// error.
r.set_value(PIKA_FORWARD(Ts, ts)...);
}
};

template <typename Receiver>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ namespace pika::drop_op_state_detail {
};

template <typename... Ts>
friend void tag_invoke(pika::execution::experimental::set_value_t,
drop_op_state_receiver_type r, Ts&&... ts) noexcept
void set_value(Ts&&... ts) && noexcept
{
auto r = PIKA_MOVE(*this);

PIKA_ASSERT(r.op_state != nullptr);
PIKA_ASSERT(r.op_state->op_state.has_value());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ namespace pika::drop_value_detail {
}

template <typename... Ts>
friend void tag_invoke(pika::execution::experimental::set_value_t,
drop_value_receiver_type&& r, Ts&&...) noexcept
void set_value(Ts&&...) && noexcept
{
auto r = PIKA_MOVE(*this);
pika::execution::experimental::set_value(PIKA_MOVE(r.receiver));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,14 @@ namespace pika::ensure_started_detail {
#endif

template <typename... Ts>
friend auto tag_invoke(pika::execution::experimental::set_value_t,
ensure_started_receiver r, Ts&&... ts) noexcept
auto set_value(Ts&&... ts) && noexcept
-> decltype(std::declval<
pika::detail::variant<pika::detail::monostate, value_type>>()
.template emplace<value_type>(
std::make_tuple<>(PIKA_FORWARD(Ts, ts)...)),
void())
{
auto r = PIKA_MOVE(*this);
r.state->v.template emplace<value_type>(
std::make_tuple<>(PIKA_FORWARD(Ts, ts)...));
r.state->set_predecessor_done();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ namespace pika::let_error_detail {
template <typename... Ts,
typename = std::enable_if_t<std::is_invocable_v<
pika::execution::experimental::set_value_t, Receiver&&, Ts...>>>
friend void tag_invoke(pika::execution::experimental::set_value_t,
let_error_predecessor_receiver&& r, Ts&&... ts) noexcept
void set_value(Ts&&... ts) && noexcept
{
auto r = PIKA_MOVE(*this);
pika::execution::experimental::set_value(
PIKA_MOVE(r.receiver), PIKA_FORWARD(Ts, ts)...);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,37 +231,27 @@ namespace pika::let_value_detail {
pika::detail::monostate>;

template <typename... Ts>
void set_value(Ts&&... ts)
auto set_value(Ts&&... ts) && noexcept
-> decltype(std::declval<predecessor_ts_type>()
.template emplace<std::tuple<std::decay_t<Ts>...>>(
PIKA_FORWARD(Ts, ts)...),
void())
{
auto r = PIKA_MOVE(*this);
pika::detail::try_catch_exception_ptr(
[&]() {
op_state.predecessor_ts
r.op_state.predecessor_ts
.template emplace<std::tuple<std::decay_t<Ts>...>>(
PIKA_FORWARD(Ts, ts)...);
pika::detail::visit(
set_value_visitor{PIKA_MOVE(receiver), PIKA_MOVE(f), op_state},
op_state.predecessor_ts);
pika::detail::visit(set_value_visitor{PIKA_MOVE(r.receiver),
PIKA_MOVE(r.f), r.op_state},
r.op_state.predecessor_ts);
},
[&](std::exception_ptr ep) {
pika::execution::experimental::set_error(
PIKA_MOVE(receiver), PIKA_MOVE(ep));
PIKA_MOVE(r.receiver), PIKA_MOVE(ep));
});
}

template <typename... Ts>
friend auto tag_invoke(pika::execution::experimental::set_value_t,
let_value_predecessor_receiver&& r, Ts&&... ts) noexcept
-> decltype(std::declval<predecessor_ts_type>()
.template emplace<std::tuple<std::decay_t<Ts>...>>(
PIKA_FORWARD(Ts, ts)...),
void())
{
// set_value is in a member function only because of a
// compiler bug in GCC 7. When the body of set_value is
// inlined here compilation fails with an internal
// compiler error.
r.set_value(PIKA_FORWARD(Ts, ts)...);
}
};

template <typename PredecessorSender_, typename Receiver_, typename F_>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ namespace pika {
};

template <typename... Ts>
friend void tag_invoke(pika::execution::experimental::set_value_t,
require_started_receiver_type r, Ts&&... ts) noexcept
void set_value(Ts&&... ts) && noexcept
{
auto r = PIKA_MOVE(*this);
PIKA_ASSERT(r.op_state != nullptr);
pika::execution::experimental::set_value(
PIKA_MOVE(r.op_state->receiver), PIKA_FORWARD(Ts, ts)...);
Expand Down Expand Up @@ -381,8 +381,7 @@ namespace pika {

s.connected = true;
return
{
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
{ // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
*std::exchange(s.sender, std::nullopt), PIKA_FORWARD(Receiver, receiver)
#if defined(PIKA_DETAIL_HAVE_REQUIRE_STARTED_MODE)
,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,21 +181,14 @@ namespace pika::schedule_from_detail {
pika::detail::monostate>;

template <typename... Ts>
friend auto tag_invoke(pika::execution::experimental::set_value_t,
predecessor_sender_receiver&& r, Ts&&... ts) noexcept
auto set_value(Ts&&... ts) && noexcept
-> decltype(std::declval<value_type>()
.template emplace<std::tuple<std::decay_t<Ts>...>>(
PIKA_FORWARD(Ts, ts)...),
std::forward<Ts>(ts)...),
void())
{
// nvcc fails to compile this with std::forward<Ts>(ts)...
// or static_cast<Ts&&>(ts)... so we explicitly use
// static_cast<decltype(ts)>(ts)... as a workaround.
# if defined(PIKA_HAVE_CUDA)
r.op_state.set_value_predecessor_sender(static_cast<decltype(ts)&&>(ts)...);
# else
r.op_state.set_value_predecessor_sender(PIKA_FORWARD(Ts, ts)...);
# endif
auto r = std::move(*this);
r.op_state.set_value_predecessor_sender(std::forward<Ts>(ts)...);
}
};

Expand Down Expand Up @@ -252,9 +245,9 @@ namespace pika::schedule_from_detail {
r.op_state.set_stopped_scheduler_sender();
}

friend void tag_invoke(pika::execution::experimental::set_value_t,
scheduler_sender_receiver&& r) noexcept
void set_value() && noexcept
{
auto r = std::move(*this);
r.op_state.set_value_scheduler_sender();
}
};
Expand Down
Loading
Loading