Skip to content

Commit

Permalink
#1830: lb: extract helper methods in TemperedLB
Browse files Browse the repository at this point in the history
- extract helper methods and add overloads for TemperedWMin
- remove redundant code
- remove unused member variable
  • Loading branch information
cz4rs committed Feb 28, 2023
1 parent 893ca59 commit 1002ab1
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 18 deletions.
33 changes: 16 additions & 17 deletions src/vt/vrt/collection/balance/temperedlb/temperedlb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
*/

#include "vt/config.h"
#include "vt/configs/types/types_sentinels.h"
#include "vt/configs/types/types_type.h"
#include "vt/timing/timing.h"
#include "vt/vrt/collection/balance/baselb/baselb.h"
#include "vt/vrt/collection/balance/model/load_model.h"
Expand Down Expand Up @@ -474,6 +476,12 @@ void TemperedLB::runLB(TimeType total_load) {
}
}

void TemperedLB::clearDataStructures() {
underloaded_.clear();
load_info_.clear();
is_overloaded_ = is_underloaded_ = false;
}

void TemperedLB::doLBStages(TimeType start_imb) {
decltype(this->cur_objs_) best_objs;
LoadType best_load = 0;
Expand All @@ -483,11 +491,7 @@ void TemperedLB::doLBStages(TimeType start_imb) {
auto this_node = theContext()->getNode();

for (trial_ = 0; trial_ < num_trials_; ++trial_) {
// Clear out data structures
selected_.clear();
underloaded_.clear();
load_info_.clear();
is_overloaded_ = is_underloaded_ = false;
clearDataStructures();

TimeType best_imb_this_trial = start_imb + 10;

Expand All @@ -504,11 +508,7 @@ void TemperedLB::doLBStages(TimeType start_imb) {
}
this_new_load_ = this_load;
} else {
// Clear out data structures from previous iteration
selected_.clear();
underloaded_.clear();
load_info_.clear();
is_overloaded_ = is_underloaded_ = false;
clearDataStructures();
}

vt_debug_print(
Expand Down Expand Up @@ -667,7 +667,7 @@ void TemperedLB::informAsync() {
vtAssert(k_max_ > 0, "Number of rounds (k) must be greater than zero");

auto const this_node = theContext()->getNode();
if (is_underloaded_) {
if (canPropagate()) {
underloaded_.insert(this_node);
}

Expand All @@ -682,7 +682,7 @@ void TemperedLB::informAsync() {
auto propagate_epoch = theTerm()->makeEpochCollective("TemperedLB: informAsync");

// Underloaded start the round
if (is_underloaded_) {
if (canPropagate()) {
uint8_t k_cur_async = 0;
propagateRound(k_cur_async, false, propagate_epoch);
}
Expand Down Expand Up @@ -718,11 +718,11 @@ void TemperedLB::informSync() {
vtAssert(k_max_ > 0, "Number of rounds (k) must be greater than zero");

auto const this_node = theContext()->getNode();
if (is_underloaded_) {
if (canPropagate()) {
underloaded_.insert(this_node);
}

auto propagate_this_round = is_underloaded_;
auto propagate_this_round = canPropagate();
propagate_next_round_ = false;
new_underloaded_ = underloaded_;
new_load_info_ = load_info_;
Expand Down Expand Up @@ -793,8 +793,7 @@ void TemperedLB::propagateRound(uint8_t k_cur, bool sync, EpochType epoch) {
gen_propagate_.seed(seed_());
}

auto& selected = selected_;
selected = underloaded_;
auto& selected = underloaded_;
if (selected.find(this_node) == selected.end()) {
selected.insert(this_node);
}
Expand Down Expand Up @@ -1203,7 +1202,7 @@ void TemperedLB::decide() {

int n_transfers = 0, n_rejected = 0;

if (is_overloaded_) {
if (canMigrate()) {
std::vector<NodeType> under = makeUnderloaded();
std::unordered_map<NodeType, ObjsType> migrate_objs;

Expand Down
10 changes: 9 additions & 1 deletion src/vt/vrt/collection/balance/temperedlb/temperedlb.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ struct TemperedLB : BaseLB {
void informSync();
void decide();
void migrate();
void clearDataStructures();

virtual bool canMigrate() const { return is_overloaded_; }
/**
* \brief Decides whether the rank can initiate information propagation stage
*
* TemperedLB restricts this to underloaded ranks
*/
virtual bool canPropagate() const { return is_underloaded_; }

void propagateRound(uint8_t k_cur_async, bool sync, EpochType epoch = no_epoch);
void propagateIncomingAsync(LoadMsgAsync* msg);
Expand Down Expand Up @@ -164,7 +173,6 @@ struct TemperedLB : BaseLB {
objgroup::proxy::Proxy<TemperedLB> proxy_ = {};
bool is_overloaded_ = false;
bool is_underloaded_ = false;
std::unordered_set<NodeType> selected_ = {};
std::unordered_set<NodeType> underloaded_ = {};
std::unordered_set<NodeType> new_underloaded_ = {};
std::unordered_map<ObjIDType, TimeType> cur_objs_ = {};
Expand Down
5 changes: 5 additions & 0 deletions src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ struct TemperedWMin : TemperedLB {
protected:
TimeType getModeledValue(const elm::ElementIDStruct& obj) override;

/**
* All ranks are allowed to initiate the information propagation stage
*/
bool canPropagate() const override { return true; }

private:
std::shared_ptr<balance::LoadModel> total_work_model_ = nullptr;
balance::LoadModel* load_model_ptr = nullptr;
Expand Down

0 comments on commit 1002ab1

Please sign in to comment.