Skip to content

Commit

Permalink
#2240: Make sure the order of reduce operations is correct
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Apr 11, 2024
1 parent 7534af9 commit a3d10f6
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 193 deletions.
166 changes: 51 additions & 115 deletions src/vt/collective/reduce/allreduce/distance_doubling.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
#include "vt/context/context.h"
#include "vt/messaging/message/message.h"
#include "vt/objgroup/proxy/proxy_objgroup.h"
#include "vt/configs/error/config_assert.h"
#include "vt/messaging/message/smart_ptr.h"

#include <tuple>
#include <cstdint>
Expand Down Expand Up @@ -85,16 +87,20 @@ struct AllreduceDblMsg
int32_t step_ = {};
};

template <typename DataT>
template <typename DataT, typename Op, typename ObjT, auto finalHandler>
struct DistanceDoubling {
void initialize(
const DataT& data, vt::objgroup::proxy::Proxy<DistanceDoubling> proxy,
vt::objgroup::proxy::Proxy<ObjT> parentProxy,
uint32_t num_nodes) {
this_node_ = vt::theContext()->getNode();
is_even_ = this_node_ % 2 == 0;
val_ = data;
proxy_ = proxy;
parentProxy_ = parentProxy;
num_steps_ = static_cast<int32_t>(log2(num_nodes));
messages.resize(num_steps_, nullptr);

nprocs_pof2_ = 1 << num_steps_;
nprocs_rem_ = num_nodes - nprocs_pof2_;
is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_);
Expand All @@ -117,24 +123,32 @@ struct DistanceDoubling {
}

void partOne() {
if (is_part_of_adjustment_group_ and is_even_) {
proxy_[this_node_ + 1].template send<&DistanceDoubling::partOneHandler>(
if (not nprocs_rem_) {
// we're running on power of 2 number of nodes, proceed to second step
partTwo();
} else if (is_part_of_adjustment_group_ and not is_even_) {
proxy_[this_node_ - 1].template send<&DistanceDoubling::partOneHandler>(
val_);
}
}

void partOneHandler(AllreduceDblMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[i] += msg->val_[i];
}
Op(val_, msg->val_);
// for (int i = 0; i < msg->val_.size(); i++) {
// val_[i] += msg->val_[i];
// }

partTwo();
}

bool isValid() { return (vrt_node_ != -1) and (step_ < num_steps_); }
bool isReady() {
return std::all_of(
steps_recv_.cbegin(), steps_recv_.cbegin() + step_,
[](const auto val) { return val; });
}
void partTwo() {
if (
vrt_node_ == -1 or (step_ >= num_steps_) or
(not std::all_of(
steps_recv_.cbegin(), steps_recv_.cbegin() + step_,
[](const auto val) { return val; }))) {
if (not isValid() or not isReady()) {
return;
}

Expand All @@ -144,153 +158,73 @@ struct DistanceDoubling {
fmt::print(
"[{}] Part2 Step {}: Sending to Node {} \n", this_node_, step_, dest);
}
if (step_) {
for (int i = 0; i < val_.size(); ++i) {
val_[i] += messages.at(step_ - 1)->val_[i];
}
}

proxy_[dest].template send<&DistanceDoubling::partTwoHandler>(val_, step_);

mask_ <<= 1;
num_send_++;
steps_sent_[step_] = true;
step_++;

if (std::all_of(
steps_recv_.cbegin(), steps_recv_.cbegin() + step_,
[](const auto val) { return val; })) {
if (isReady()) {
partTwo();
}
}

void partTwoHandler(AllreduceDblMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[i] += msg->val_[i];
}
messages.at(msg->step_) = promoteMsg(msg);

if constexpr (isdebug) {
std::string data(128, 0x0);
std::string data(1024, 0x0);
for (auto val : msg->val_) {
data.append(fmt::format("{} ", val));
}
fmt::print(
"[{}] Part2 Step {} mask_= {} nprocs_pof2_ = {}: Received data ({}) "
"idx = {} from {}\n",
"from {}\n",
this_node_, msg->step_, mask_, nprocs_pof2_, data,
theContext()->getFromNodeCurrentTask());
}
steps_recv_[msg->step_] = true;
num_recv_++;
if (mask_ < nprocs_pof2_) {
if (std::all_of(
steps_recv_.cbegin(), steps_recv_.cbegin() + step_,
[](const auto val) { return val; })) {
if (isReady()) {
partTwo();
}
} else {
// step_ = num_steps_ - 1;
// mask_ = nprocs_pof2_ >> 1;
// partThree();
}
}

void partThree() {
if (
vrt_node_ == -1 or
(not std::all_of(
steps_recv_.cbegin() + step_ + 1, steps_recv_.cend(),
[](const auto val) { return val; }))) {
return;
}

if (not startedPartThree_) {
step_ = num_steps_ - 1;
mask_ = nprocs_pof2_ >> 1;
num_send_ = 0;
num_recv_ = 0;
startedPartThree_ = true;
std::fill(steps_sent_.begin(), steps_sent_.end(), false);
std::fill(steps_recv_.begin(), steps_recv_.end(), false);
}

auto vdest = vrt_node_ ^ mask_;
auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_;

if constexpr (isdebug) {
std::string data(1024, 0x0);

for (auto val : val_) {
data.append(fmt::format("{} ", val));
}

fmt::print(
"[{}] Part3 Step {}: Sending to Node {} data={} \n", this_node_, step_,
dest, data);
}
proxy_[dest].template send<&DistanceDoubling::partThreeHandler>(
val_, step_);

steps_sent_[step_] = true;
num_send_++;
mask_ >>= 1;
step_--;
if (
step_ >= 0 and
std::all_of(
steps_recv_.cbegin() + step_ + 1, steps_recv_.cend(),
[](const auto val) { return val; })) {
partThree();
}
}

void partThreeHandler(AllreduceDblMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[i] = msg->val_[i];
}

if (not startedPartThree_) {
step_ = num_steps_ - 1;
mask_ = nprocs_pof2_ >> 1;
num_send_ = 0;
num_recv_ = 0;
startedPartThree_ = true;
std::fill(steps_sent_.begin(), steps_sent_.end(), false);
std::fill(steps_recv_.begin(), steps_recv_.end(), false);
}

num_recv_++;
if constexpr (isdebug) {
std::string data(128, 0x0);
for (auto val : msg->val_) {
data.append(fmt::format("{} ", val));
}
fmt::print(
"[{}] Part3 Step {}: Received data ({}) from {}\n", this_node_,
msg->step_, data, theContext()->getFromNodeCurrentTask());
}

steps_recv_[msg->step_] = true;

if (
mask_ > 0 and
((step_ == num_steps_ - 1) or
std::all_of(
steps_recv_.cbegin() + step_ + 1, steps_recv_.cend(),
[](const auto val) { return val; }))) {
partThree();
}
}

void partFour() {
if (is_part_of_adjustment_group_ and is_even_) {
if constexpr (isdebug) {
fmt::print(
"[{}] Part4 : Sending to Node {} \n", this_node_, this_node_ + 1);
}
proxy_[this_node_ + 1].template send<&DistanceDoubling::partFourHandler>(
proxy_[this_node_ + 1].template send<&DistanceDoubling::partThreeHandler>(
val_);
}
}

void partFourHandler(AllreduceDblMsg<DataT>* msg) { val_ = msg->val_; }
void partThreeHandler(AllreduceDblMsg<DataT>* msg) { val_ = msg->val_; }
void finalPart() {
if (vrt_node_ != -1) {
for (int i = 0; i < val_.size(); ++i) {
val_[i] += messages.at(step_ - 1)->val_[i];
}
}

parentProxy_[this_node_] .template invoke<finalHandler>(val_);
}

NodeType this_node_ = {};
bool is_even_ = false;
vt::objgroup::proxy::Proxy<DistanceDoubling> proxy_ = {};
vt::objgroup::proxy::Proxy<ObjT> parentProxy_ = {};
DataT val_ = {};
NodeType vrt_node_ = {};
bool is_part_of_adjustment_group_ = false;
Expand All @@ -309,6 +243,8 @@ struct DistanceDoubling {

std::vector<bool> steps_recv_ = {};
std::vector<bool> steps_sent_ = {};

std::vector<MsgSharedPtr<AllreduceDblMsg<DataT>>> messages = {};
};

} // namespace vt::collective::reduce::allreduce
Expand Down
18 changes: 3 additions & 15 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

namespace vt::collective::reduce::allreduce {

constexpr bool debug = true;
constexpr bool debug = false;

template <typename DataT>
struct AllreduceRbnMsg
Expand Down Expand Up @@ -140,23 +140,13 @@ struct Rabenseifner {
}
}

expected_send_ = num_steps_;
expected_recv_ = num_steps_;
steps_sent_.resize(num_steps_, false);
steps_recv_.resize(num_steps_, false);

if constexpr (debug) {
std::string str(1024, 0x0);
for (int i = 0; i < num_steps_; ++i) {
str.append(fmt::format(
"Step{}: send_idx = {} send_count = {} recieve_idx = {} "
"recieve_count "
"= {}\n",
i, s_index_[i], s_count_[i], r_index_[i], r_count_[i]));
}
fmt::print(
"[{}] Initialize with size = {} num_steps {} \n {}", this_node_,
w_size_, num_steps_, str);
"[{}] Initialize with size = {} num_steps {} \n", this_node_,
w_size_, num_steps_);
}
}

Expand Down Expand Up @@ -372,9 +362,7 @@ struct Rabenseifner {
size_t w_size_ = {};
int32_t step_ = 0;
int32_t num_send_ = 0;
int32_t expected_send_ = 0;
int32_t num_recv_ = 0;
int32_t expected_recv_ = 0;

std::vector<bool> steps_recv_ = {};
std::vector<bool> steps_sent_ = {};
Expand Down
Loading

0 comments on commit a3d10f6

Please sign in to comment.