Skip to content

Commit

Permalink
#2240: Fix issues with handlers being executed and payload not being …
Browse files Browse the repository at this point in the history
…initialized
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent 3b9d697 commit 27438ea
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 43 deletions.
102 changes: 68 additions & 34 deletions src/vt/collective/reduce/allreduce/rabenseifner.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ Rabenseifner<DataT, Op, ObjT, finalHandler>::Rabenseifner(
vrt_node_ = this_node_ - nprocs_rem_;
}

vt_debug_print(terse, allreduce, "Rabenseifner constructor\n");
initialize(generateNewId(), std::forward<Args>(data)...);
}

Expand Down Expand Up @@ -165,9 +164,9 @@ template <
typename DataT, template <typename Arg> class Op, typename ObjT, auto finalHandler
>
void Rabenseifner<DataT, Op, ObjT, finalHandler>::executeFinalHan(size_t id) {
// theCB()->makeSend<finalHandler>(parent_proxy_[this_node_]).sendTuple(std::make_tuple(val_));
auto& state = states_.at(id);
vt_debug_print(terse, allreduce, "Rabenseifner executing final handler ID = {}\n", id);

parent_proxy_[this_node_].template invoke<finalHandler>(state.val_);
state.completed_ = true;
}
Expand All @@ -176,7 +175,6 @@ template <
typename DataT, template <typename Arg> class Op, typename ObjT,
auto finalHandler>
void Rabenseifner<DataT, Op, ObjT, finalHandler>::allreduce(size_t id) {
vt_debug_print(terse, allreduce, "Rabenseifner allreduce is_part_of_adjustment_group_ = {}\n", is_part_of_adjustment_group_);
if (is_part_of_adjustment_group_) {
adjustForPowerOfTwo(id);
} else {
Expand All @@ -193,7 +191,7 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwo(size_t id)
auto const partner = is_even_ ? this_node_ + 1 : this_node_ - 1;

vt_debug_print(
terse, allreduce, "Rabenseifner::adjustForPowerOfTwo: To Node {} ID = {}\n", partner, id
terse, allreduce, "Rabenseifner AdjustInitial (To {}): ID = {}\n", partner, id
);

if (is_even_) {
Expand Down Expand Up @@ -223,15 +221,23 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwoRightHalf(

auto& state = states_[msg->id_];

if(not state.initialized_){
initializeState(msg->id_);
if (state.val_.empty()) {
if (not state.initialized_) {
vt_debug_print(
verbose, allreduce,
"Rabenseifner AdjustRightHalf (From {}): State not initialized ID {}!\n",
theContext()->getFromNodeCurrentTask(), msg->id_
);

initializeState(msg->id_);
}
state.right_adjust_message_ = promoteMsg(msg);

return;
}

vt_debug_print(
terse, allreduce, "Rabenseifner::adjustForPowerOfTwoRightHalf: From Node {} ID = {}\n",
terse, allreduce, "Rabenseifner AdjustRightHalf (From {}): ID = {}\n",
theContext()->getFromNodeCurrentTask(), msg->id_
);

Expand All @@ -252,15 +258,22 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwoLeftHalf(
AllreduceRbnRawMsg<Scalar>* msg) {

auto& state = states_[msg->id_];
if(not state.initialized_){
initializeState(msg->id_);
if (state.val_.empty()) {
if (not state.initialized_) {
vt_debug_print(
verbose, allreduce,
"Rabenseifner AdjustLeftHalf (From {}): State not initialized ID {}!\n",
theContext()->getFromNodeCurrentTask(), msg->id_);

initializeState(msg->id_);
}
state.left_adjust_message_ = promoteMsg(msg);

return;
}

vt_debug_print(
terse, allreduce, "Rabenseifner::adjustForPowerOfTwoLeftHalf: From Node {} ID = {}\n",
terse, allreduce, "Rabenseifner AdjustLeftHalf (From {}): ID = {}\n",
theContext()->getFromNodeCurrentTask(), msg->id_
);

Expand All @@ -276,7 +289,7 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwoFinalPart(
AllreduceRbnRawMsg<Scalar>* msg) {

vt_debug_print(
terse, allreduce, "Rabenseifner::adjustForPowerOfTwoFinalPart: From Node {} ID = {}\n",
terse, allreduce, "Rabenseifner AdjustFinal (From {}): ID = {}\n",
theContext()->getFromNodeCurrentTask(), msg->id_
);

Expand All @@ -295,7 +308,7 @@ template <
typename DataT, template <typename Arg> class Op, typename ObjT, auto finalHandler
>
bool Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterAllMessagesReceived(size_t id) {
auto& state = states_.at(id);
auto const& state = states_.at(id);

return std::all_of(
state.scatter_steps_recv_.cbegin(), state.scatter_steps_recv_.cbegin() + state.scatter_step_,
Expand All @@ -306,15 +319,15 @@ template <
typename DataT, template <typename Arg> class Op, typename ObjT, auto finalHandler
>
bool Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterIsDone(size_t id) {
auto& state = states_.at(id);
auto const& state = states_.at(id);
return (state.scatter_step_ == num_steps_) and (state.scatter_num_recv_ == num_steps_);
}

template <
typename DataT, template <typename Arg> class Op, typename ObjT, auto finalHandler
>
bool Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterIsReady(size_t id) {
auto& state = states_.at(id);
auto const& state = states_.at(id);
return ((is_part_of_adjustment_group_ and state.finished_adjustment_part_) and
state.scatter_step_ == 0) or
((state.scatter_mask_ < nprocs_pof2_) and scatterAllMessagesReceived(id));
Expand All @@ -326,12 +339,20 @@ template <
void Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterTryReduce(
size_t id, int32_t step) {
auto& state = states_.at(id);
if (
(step < state.scatter_step_) and not state.scatter_steps_reduced_[step] and

auto do_reduce = (step < state.scatter_step_) and
not state.scatter_steps_reduced_[step] and
state.scatter_steps_recv_[step] and
std::all_of(
state.scatter_steps_reduced_.cbegin(), state.scatter_steps_reduced_.cbegin() + step,
[](auto const val) { return val; })) {
std::all_of(state.scatter_steps_reduced_.cbegin(),
state.scatter_steps_reduced_.cbegin() + step,
[](auto const val) { return val; });

vt_debug_print(
verbose, allreduce, "Rabenseifner ScatterTryReduce (Step = {} ID = {}): {}\n",
step, id, do_reduce
);

if (do_reduce) {
auto& in_msg = state.scatter_messages_.at(step);
auto& in_val = in_msg->val_;
for (uint32_t i = 0; i < in_msg->size_; i++) {
Expand All @@ -356,15 +377,16 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterReduceIter(size_t id) {

vt_debug_print(
terse, allreduce,
"Rabenseifner Scatter (Send step {}): To Node {} starting with idx = {} and "
"Rabenseifner Scatter (Send step {} to {}): Starting with idx = {} and "
"count "
"{} ID = {}\n",
state.scatter_step_, dest, state.s_index_[state.scatter_step_],
state.s_count_[state.scatter_step_], id
);

proxy_[dest].template send<&Rabenseifner::scatterReduceIterHandler>(
state.val_.data() + state.s_index_[state.scatter_step_], state.s_count_[state.scatter_step_], id, state.scatter_step_
state.val_.data() + state.s_index_[state.scatter_step_],
state.s_count_[state.scatter_step_], id, state.scatter_step_
);

state.scatter_mask_ <<= 1;
Expand All @@ -387,15 +409,35 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterReduceIterHandler(
AllreduceRbnRawMsg<Scalar>* msg) {
auto& state = states_[msg->id_];

if(not state.initialized_){
initializeState(msg->id_);
if (state.val_.empty()) {
if (not state.initialized_) {
vt_debug_print(
verbose, allreduce,
"Rabenseifner Scatter (Recv step {} from {}): State not initialized "
"for ID = "
"{}!\n",
msg->step_, theContext()->getFromNodeCurrentTask(), msg->id_);
initializeState(msg->id_);
}

state.scatter_messages_[msg->step_] = promoteMsg(msg);
state.scatter_steps_recv_[msg->step_] = true;
state.scatter_num_recv_++;

return;
}

vt_debug_print(
terse, allreduce,
"Rabenseifner Scatter (Recv step {} from {}): initialized = {} "
"scatter_mask_= {} nprocs_pof2_ = {}: scatterAllMessagesReceived() = {} "
"state.finished_adjustment_part_ = {}"
"idx = {} ID = {}\n",
msg->step_, theContext()->getFromNodeCurrentTask(), state.initialized_,
state.scatter_mask_, nprocs_pof2_, scatterAllMessagesReceived(msg->id_),
state.finished_adjustment_part_, state.r_index_[msg->step_], msg->id_
);

state.scatter_messages_[msg->step_] = promoteMsg(msg);
state.scatter_steps_recv_[msg->step_] = true;
state.scatter_num_recv_++;
Expand All @@ -406,14 +448,6 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterReduceIterHandler(

scatterTryReduce(msg->id_, msg->step_);

vt_debug_print(
terse, allreduce,
"Rabenseifner Scatter (Recv step {}): scatter_mask_= {} nprocs_pof2_ = {}: "
"idx = {} from {} ID = {}\n",
msg->step_, state.scatter_mask_, nprocs_pof2_, state.r_index_[msg->step_],
theContext()->getFromNodeCurrentTask(), msg->id_
);

if ((state.scatter_mask_ < nprocs_pof2_) and scatterAllMessagesReceived(msg->id_)) {
scatterReduceIter(msg->id_);
} else if (scatterIsDone(msg->id_)) {
Expand Down Expand Up @@ -516,9 +550,9 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::gatherIterHandler(
AllreduceRbnRawMsg<Scalar>* msg) {
auto& state = states_.at(msg->id_);
vt_debug_print(
terse, allreduce, "Rabenseifner Gather (step {}): Received idx = {} from {} ID = {}\n",
msg->step_, state.s_index_[msg->step_],
theContext()->getFromNodeCurrentTask(), msg->id_
terse, allreduce, "Rabenseifner Gather (Recv step {} from {}): idx = {} ID = {}\n",
msg->step_, theContext()->getFromNodeCurrentTask(), state.s_index_[msg->step_],
msg->id_
);

state.gather_messages_[msg->step_] = promoteMsg(msg);
Expand Down
31 changes: 22 additions & 9 deletions src/vt/collective/reduce/allreduce/recursive_doubling.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::initialize(
state.val_ = DataT{std::forward<Args>(data)...};

vt_debug_print(
terse, allreduce, "RecursiveDoubling Initialize: size {} ID {}\n", DataType::size(state.val_), id
terse, allreduce, "RecursiveDoubling Initialize: size {} ID {}\n",
DataType::size(state.val_), id
);
}

Expand All @@ -100,6 +101,9 @@ template <
auto finalHandler>
void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::initializeState(size_t id){
auto& state = states_[id];

vt_debug_print(terse, allreduce, "RecursiveDoubling initializing state for ID = {}\n", id);

state.messages_.resize(num_steps_, nullptr);
state.steps_recv_.resize(num_steps_, false);
state.steps_reduced_.resize(num_steps_, false);
Expand Down Expand Up @@ -129,8 +133,8 @@ void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwo(size_
auto& state = states_.at(id);
if (is_part_of_adjustment_group_ and not is_even_) {
vt_debug_print(
terse, allreduce, "RecursiveDoubling Part1: Sending to Node {} ID ={} \n", this_node_,
this_node_ - 1, id
terse, allreduce, "RecursiveDoubling AdjustInitial (To {}): ID = {} \n",
this_node_, this_node_ - 1, id
);

proxy_[this_node_ - 1]
Expand All @@ -148,8 +152,10 @@ void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::
adjustForPowerOfTwoHandler(AllreduceDblRawMsg<DataT>* msg) {

auto& state = states_[msg->id_];
if(not state.initialized_) {
initializeState(msg->id_);
if (DataType::size(state.val_) == 0) {
if (not state.initialized_) {
initializeState(msg->id_);
}
state.adjust_message_ = promoteMsg(msg);

return;
Expand Down Expand Up @@ -240,8 +246,13 @@ void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::tryReduce(size_t id, int3
[](const auto val) { return val; });

vt_debug_print(
terse, allreduce, "RecursiveDoubling Part2 (Reduce step {}): state.step_ = {} state.steps_reduced_[step] = {} state.steps_recv_[step] = {} all_msgs_received = {} ID = {} \n",
step, state.step_, static_cast<bool>(state.steps_reduced_[step]), static_cast<bool>(state.steps_recv_[step]), all_msgs_received, id);
terse, allreduce,
"RecursiveDoubling Part2 (Reduce step {}): state.step_ = {} "
"state.steps_reduced_[step] = {} state.steps_recv_[step] = {} "
"all_msgs_received = {} ID = {} \n",
step, state.step_, static_cast<bool>(state.steps_reduced_[step]),
static_cast<bool>(state.steps_recv_[step]), all_msgs_received, id
);

if (
(step < state.step_) and not state.steps_reduced_[step] and
Expand All @@ -259,8 +270,10 @@ void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::reduceIterHandler(
AllreduceDblRawMsg<DataT>* msg) {
auto& state = states_[msg->id_];

if(not state.initialized_){
initializeState(msg->id_);
if (DataType::size(state.val_) == 0) {
if (not state.initialized_) {
initializeState(msg->id_);
}
state.messages_.at(msg->step_) = promoteMsg(msg);
state.steps_recv_[msg->step_] = true;

Expand Down

0 comments on commit 27438ea

Please sign in to comment.