Skip to content

Commit

Permalink
#2240: Working Recursive doubling
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent 19df2f7 commit 831168f
Show file tree
Hide file tree
Showing 6 changed files with 318 additions and 206 deletions.
13 changes: 8 additions & 5 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,13 @@ struct AllreduceRbnMsg
int32_t step_ = {};
};

template <typename DataT>
template <
typename DataT, template <typename Arg> class Op, typename ObjT,
auto finalHandler>
struct Rabenseifner {
void initialize(
const DataT& data, vt::objgroup::proxy::Proxy<Rabenseifner> proxy,
uint32_t num_nodes) {
vt::objgroup::proxy::Proxy<ObjT> parentProxy, uint32_t num_nodes) {
this_node_ = vt::theContext()->getNode();
is_even_ = this_node_ % 2 == 0;
val_ = data;
Expand Down Expand Up @@ -145,8 +147,8 @@ struct Rabenseifner {

if constexpr (debug) {
fmt::print(
"[{}] Initialize with size = {} num_steps {} \n", this_node_,
w_size_, num_steps_);
"[{}] Initialize with size = {} num_steps {} \n", this_node_, w_size_,
num_steps_);
}
}

Expand Down Expand Up @@ -186,7 +188,7 @@ struct Rabenseifner {
val_[(val_.size() / 2) + i] = msg->val_[i];
}

// partTwo();
partTwo();
}

void partTwo() {
Expand Down Expand Up @@ -350,6 +352,7 @@ struct Rabenseifner {
NodeType this_node_ = {};
bool is_even_ = false;
vt::objgroup::proxy::Proxy<Rabenseifner> proxy_ = {};
vt::objgroup::proxy::Proxy<ObjT> parentProxy_ = {};
DataT val_ = {};
NodeType vrt_node_ = {};
bool is_part_of_adjustment_group_ = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,22 +87,23 @@ struct AllreduceDblMsg
int32_t step_ = {};
};

template <typename DataT, typename Op, typename ObjT, auto finalHandler>
template <
typename DataT, template <typename Arg> class Op, typename ObjT,
auto finalHandler>
struct DistanceDoubling {
void initialize(
const DataT& data, vt::objgroup::proxy::Proxy<DistanceDoubling> proxy,
vt::objgroup::proxy::Proxy<ObjT> parentProxy,
uint32_t num_nodes) {
template <typename... Args>
DistanceDoubling(NodeType num_nodes, Args&&... args)
: val_(std::forward<Args>(args)...),
num_nodes_(num_nodes) { }

void initialize() {
this_node_ = vt::theContext()->getNode();
is_even_ = this_node_ % 2 == 0;
val_ = data;
proxy_ = proxy;
parentProxy_ = parentProxy;
num_steps_ = static_cast<int32_t>(log2(num_nodes));
num_steps_ = static_cast<int32_t>(log2(num_nodes_));
messages.resize(num_steps_, nullptr);

nprocs_pof2_ = 1 << num_steps_;
nprocs_rem_ = num_nodes - nprocs_pof2_;
nprocs_rem_ = num_nodes_ - nprocs_pof2_;
is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_);
if (is_part_of_adjustment_group_) {
if (is_even_) {
Expand All @@ -114,41 +115,76 @@ struct DistanceDoubling {
vrt_node_ = this_node_ - nprocs_rem_;
}

w_size_ = data.size();

expected_send_ = num_steps_;
expected_recv_ = num_steps_;
steps_sent_.resize(num_steps_, false);
steps_recv_.resize(num_steps_, false);
steps_reduced_.resize(num_steps_, false);

initialized_ = true;
}

void allreduce(
vt::objgroup::proxy::Proxy<DistanceDoubling> proxy,
vt::objgroup::proxy::Proxy<ObjT> parentProxy) {
if (not initialized_) {
initialize();
}

proxy_ = proxy;
parent_proxy_ = parentProxy;

if (nprocs_rem_) {
adjustForPowerOfTwo();
} else {
reduceIter();
}
}

void partOne() {
if (not nprocs_rem_) {
// we're running on power of 2 number of nodes, proceed to second step
partTwo();
} else if (is_part_of_adjustment_group_ and not is_even_) {
proxy_[this_node_ - 1].template send<&DistanceDoubling::partOneHandler>(
val_);
void adjustForPowerOfTwo() {
if (is_part_of_adjustment_group_ and not is_even_) {
if constexpr (isdebug) {
fmt::print(
"[{}] Part1: Sending to Node {} \n", this_node_, this_node_ - 1);
}

proxy_[this_node_ - 1]
.template send<&DistanceDoubling::adjustForPowerOfTwoHandler>(val_);
}
}

void partOneHandler(AllreduceDblMsg<DataT>* msg) {
Op(val_, msg->val_);
// for (int i = 0; i < msg->val_.size(); i++) {
// val_[i] += msg->val_[i];
// }
void adjustForPowerOfTwoHandler(AllreduceDblMsg<DataT>* msg) {
if constexpr (isdebug) {
std::string data(1024, 0x0);
for (auto val : msg->val_) {
data.append(fmt::format("{} ", val));
}
fmt::print(
"[{}] Part1 Handler initialized_ = {}: Received data ({}) "
"from {}\n",
this_node_, initialized_, data, theContext()->getFromNodeCurrentTask());
}

Op<DataT>()(val_, msg->val_);

partTwo();
finished_adjustment_part_ = true;

reduceIter();
}

bool done() { return step_ == num_steps_ and allMessagesReceived(); }
bool isValid() { return (vrt_node_ != -1) and (step_ < num_steps_); }
bool isReady() {
bool allMessagesReceived() {
return std::all_of(
steps_recv_.cbegin(), steps_recv_.cbegin() + step_,
[](const auto val) { return val; });
}
void partTwo() {
if (not isValid() or not isReady()) {
bool isReady() {
return (is_part_of_adjustment_group_ and finished_adjustment_part_) and
step_ == 0 or
allMessagesReceived();
}

void reduceIter() {
// Ensure we have received all necessary messages
if (not isReady()) {
return;
}

Expand All @@ -158,91 +194,122 @@ struct DistanceDoubling {
fmt::print(
"[{}] Part2 Step {}: Sending to Node {} \n", this_node_, step_, dest);
}
if (step_) {
for (int i = 0; i < val_.size(); ++i) {
val_[i] += messages.at(step_ - 1)->val_[i];
}
}

proxy_[dest].template send<&DistanceDoubling::partTwoHandler>(val_, step_);
proxy_[dest].template send<&DistanceDoubling::reduceIterHandler>(
val_, step_);

mask_ <<= 1;
num_send_++;
steps_sent_[step_] = true;
step_++;

if (isReady()) {
partTwo();
tryReduce(step_ - 1);

if (done()) {
finalPart();
} else if (isReady()) {
reduceIter();
}
}

void partTwoHandler(AllreduceDblMsg<DataT>* msg) {
messages.at(msg->step_) = promoteMsg(msg);
void tryReduce(int32_t step) {
if (
(step < step_) and not steps_reduced_[step] and steps_recv_[step] and
std::all_of(
steps_reduced_.cbegin(), steps_reduced_.cbegin() + step,
[](const auto val) { return val; })) {
Op<DataT>()(val_, messages.at(step)->val_);
steps_reduced_[step] = true;
}
}

void reduceIterHandler(AllreduceDblMsg<DataT>* msg) {
if constexpr (isdebug) {
std::string data(1024, 0x0);
for (auto val : msg->val_) {
data.append(fmt::format("{} ", val));
}
fmt::print(
"[{}] Part2 Step {} mask_= {} nprocs_pof2_ = {}: Received data ({}) "
"[{}] Part2 Step {} initialized_ = {} mask_= {} nprocs_pof2_ = {}: "
"Received data ({}) "
"from {}\n",
this_node_, msg->step_, mask_, nprocs_pof2_, data,
this_node_, msg->step_, initialized_, mask_, nprocs_pof2_, data,
theContext()->getFromNodeCurrentTask());
}
steps_recv_[msg->step_] = true;
num_recv_++;
if (mask_ < nprocs_pof2_) {
if (isReady()) {
partTwo();

// Special case when we receive step 2 message before step 1 is done on this node
if (not finished_adjustment_part_) {
if (not initialized_) {
initialize();
}

messages.at(msg->step_) = promoteMsg(msg);
steps_recv_[msg->step_] = true;

return;
}

messages.at(msg->step_) = promoteMsg(msg);
steps_recv_[msg->step_] = true;

tryReduce(msg->step_);

if ((mask_ < nprocs_pof2_) and isReady()) {
reduceIter();

} else if (done()) {
finalPart();
}
}

void partThree() {
void sendToExcludedNodes() {
if (is_part_of_adjustment_group_ and is_even_) {
if constexpr (isdebug) {
fmt::print(
"[{}] Part4 : Sending to Node {} \n", this_node_, this_node_ + 1);
"[{}] Part3 : Sending to Node {} \n", this_node_, this_node_ + 1);
}
proxy_[this_node_ + 1].template send<&DistanceDoubling::partThreeHandler>(
val_);
proxy_[this_node_ + 1]
.template send<&DistanceDoubling::sendToExcludedNodesHandler>(val_);
}
}

void partThreeHandler(AllreduceDblMsg<DataT>* msg) { val_ = msg->val_; }
void sendToExcludedNodesHandler(AllreduceDblMsg<DataT>* msg) {
val_ = msg->val_;

parent_proxy_[this_node_].template invoke<finalHandler>(val_);
}

void finalPart() {
if (vrt_node_ != -1) {
for (int i = 0; i < val_.size(); ++i) {
val_[i] += messages.at(step_ - 1)->val_[i];
}
if (completed_) {
return;
}

if (nprocs_rem_) {
sendToExcludedNodes();
}

parentProxy_[this_node_] .template invoke<finalHandler>(val_);
parent_proxy_[this_node_].template invoke<finalHandler>(val_);
completed_ = true;
}

NodeType this_node_ = {};
uint32_t num_nodes_ = {};
bool is_even_ = false;
vt::objgroup::proxy::Proxy<DistanceDoubling> proxy_ = {};
vt::objgroup::proxy::Proxy<ObjT> parentProxy_ = {};
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};
DataT val_ = {};
NodeType vrt_node_ = {};
bool initialized_ = false;
bool is_part_of_adjustment_group_ = false;
bool finished_adjustment_part_ = false;
int32_t num_steps_ = {};
int32_t nprocs_pof2_ = {};
int32_t nprocs_rem_ = {};
int32_t mask_ = 1;
bool startedPartThree_ = false;

size_t w_size_ = {};
int32_t step_ = 0;
int32_t num_send_ = 0;
int32_t expected_send_ = 0;
int32_t num_recv_ = 0;
int32_t expected_recv_ = 0;
bool completed_ = false;

std::vector<bool> steps_recv_ = {};
std::vector<bool> steps_sent_ = {};
std::vector<bool> steps_reduced_ = {};

std::vector<MsgSharedPtr<AllreduceDblMsg<DataT>>> messages = {};
};
Expand Down
Loading

0 comments on commit 831168f

Please sign in to comment.