Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1129 improve api for subphases #1805

Draft
wants to merge 6 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions examples/collection/lb_iter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,13 @@ struct IterCol : vt::Collection<IterCol, vt::Index1D> {
struct IterMsg : vt::CollectionMessage<IterCol> {
IterMsg() = default;
IterMsg(
int64_t const in_work_amt, int64_t const in_iter, int64_t const subphase
int64_t const in_work_amt, int64_t const in_iter
)
: iter_(in_iter), work_amt_(in_work_amt), subphase_(subphase)
: iter_(in_iter), work_amt_(in_work_amt)
{ }

int64_t iter_ = 0;
int64_t work_amt_ = 0;
int64_t subphase_ = 0;
};

void iterWork(IterMsg* msg);
Expand All @@ -77,7 +76,6 @@ struct IterCol : vt::Collection<IterCol, vt::Index1D> {
static double weight = 1.0f;

void IterCol::iterWork(IterMsg* msg) {
this->lb_data_.setSubPhase(msg->subphase_);
double val = 0.1f;
double val2 = 0.4f * msg->work_amt_;
auto const idx = getIndex().x();
Expand Down Expand Up @@ -130,14 +128,14 @@ int main(int argc, char** argv) {
for (int i = 0; i < num_iter; i++) {
auto cur_time = vt::timing::getCurrentTime();

vt::runInEpochCollective([=]{
proxy.broadcastCollective<IterCol::IterMsg,&IterCol::iterWork>(10, i, 0);
vt::runSubphaseCollective([=]{
proxy.broadcastCollective<IterCol::IterMsg,&IterCol::iterWork>(10, i);
});
vt::runInEpochCollective([=]{
proxy.broadcastCollective<IterCol::IterMsg,&IterCol::iterWork>(5, i, 1);
vt::runSubphaseCollective([=]{
proxy.broadcastCollective<IterCol::IterMsg,&IterCol::iterWork>(5, i);
});
vt::runInEpochCollective([=]{
proxy.broadcastCollective<IterCol::IterMsg,&IterCol::iterWork>(15, i, 2);
vt::runSubphaseCollective([=]{
proxy.broadcastCollective<IterCol::IterMsg,&IterCol::iterWork>(15, i);
});

auto total_time = vt::timing::getCurrentTime() - cur_time;
Expand Down
64 changes: 28 additions & 36 deletions src/vt/elm/elm_lb_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#define INCLUDED_VT_ELM_ELM_LB_DATA_CC

#include "vt/elm/elm_lb_data.h"
#include "vt/phase/phase_manager.h"

#include "vt/config.h"

Expand Down Expand Up @@ -87,17 +88,21 @@ void ElementLBData::sendToEntity(
}

void ElementLBData::sendComm(elm::CommKey key, double bytes) {
phase_comm_[cur_phase_][key].sendMsg(bytes);
subphase_comm_[cur_phase_].resize(cur_subphase_ + 1);
subphase_comm_[cur_phase_].at(cur_subphase_)[key].sendMsg(bytes);
auto cur_phase = thePhase()->getCurrentPhase();
auto cur_subphase = thePhase()->getCurrentSubphase();
phase_comm_[cur_phase][key].sendMsg(bytes);
subphase_comm_[cur_phase].resize(cur_subphase + 1);
subphase_comm_[cur_phase].at(cur_subphase)[key].sendMsg(bytes);
}

void ElementLBData::recvComm(
elm::CommKey key, double bytes
) {
phase_comm_[cur_phase_][key].receiveMsg(bytes);
subphase_comm_[cur_phase_].resize(cur_subphase_ + 1);
subphase_comm_[cur_phase_].at(cur_subphase_)[key].receiveMsg(bytes);
auto cur_phase = thePhase()->getCurrentPhase();
auto cur_subphase = thePhase()->getCurrentSubphase();
phase_comm_[cur_phase][key].receiveMsg(bytes);
subphase_comm_[cur_phase].resize(cur_subphase + 1);
subphase_comm_[cur_phase].at(cur_subphase)[key].receiveMsg(bytes);
}

void ElementLBData::recvObjData(
Expand Down Expand Up @@ -125,42 +130,34 @@ void ElementLBData::recvToNode(
}

void ElementLBData::addTime(TimeTypeWrapper const& time) {
phase_timings_[cur_phase_] += time.seconds();
auto cur_phase = thePhase()->getCurrentPhase();
phase_timings_[cur_phase] += time.seconds();

subphase_timings_[cur_phase_].resize(cur_subphase_ + 1);
subphase_timings_[cur_phase_].at(cur_subphase_) += time.seconds();
auto cur_subphase = thePhase()->getCurrentSubphase();
subphase_timings_[cur_phase].resize(cur_subphase + 1);
subphase_timings_[cur_phase].at(cur_subphase) += time.seconds();

vt_debug_print(
verbose,lb,
"ElementLBData: addTime: time={}, cur_load={}\n",
time,
TimeTypeWrapper(phase_timings_[cur_phase_])
TimeTypeWrapper(phase_timings_[cur_phase])
);
}

void ElementLBData::updatePhase(PhaseType const& inc) {
void ElementLBData::updatePhase(PhaseType const& cur_phase) {
vt_debug_print(
verbose, lb,
"ElementLBData: updatePhase: cur_phase_={}, inc={}\n",
cur_phase_, inc
"ElementLBData: updatePhase: new_phase={}\n",
cur_phase
);

cur_phase_ += inc;

// Access all table entries for current phase, to ensure presence even
// if they're left empty
phase_timings_[cur_phase_];
subphase_timings_[cur_phase_];
phase_comm_[cur_phase_];
subphase_comm_[cur_phase_];
}

void ElementLBData::resetPhase() {
cur_phase_ = fst_lb_phase;
}

PhaseType ElementLBData::getPhase() const {
return cur_phase_;
phase_timings_[cur_phase];
subphase_timings_[cur_phase];
phase_comm_[cur_phase];
subphase_comm_[cur_phase];
}

TimeType ElementLBData::getLoad(PhaseType const& phase) const {
Expand Down Expand Up @@ -199,6 +196,8 @@ TimeType ElementLBData::getLoad(PhaseType phase, SubphaseType subphase) const {
}

// Return the per-subphase timings recorded for `phase`, creating an empty
// entry for that phase if none exists yet.
//
// The vector is padded so that it holds at least one slot for every subphase
// up to and including the one currently reported by the phase manager. The
// padding only ever grows the vector: an unconditional resize() would discard
// already-recorded timings whenever the stored vector is longer than the
// current subphase count (for instance when an earlier phase is queried after
// the subphase counter has been reset for a new phase).
std::vector<TimeType> const& ElementLBData::getSubphaseTimes(PhaseType phase) {
  auto const min_size =
    static_cast<std::size_t>(thePhase()->getCurrentSubphase()) + 1;

  auto& times = subphase_timings_[phase];
  if (times.size() < min_size) {
    times.resize(min_size);
  }
  return times;
}

Expand All @@ -216,6 +215,8 @@ ElementLBData::getComm(PhaseType const& phase) {
}

std::vector<CommMapType> const& ElementLBData::getSubphaseComm(PhaseType phase) {
auto cur_subphase = thePhase()->getCurrentSubphase();
subphase_comm_[phase].resize(cur_subphase + 1);
auto const& subphase_comm = subphase_comm_[phase];

vt_debug_print(
Expand All @@ -227,15 +228,6 @@ std::vector<CommMapType> const& ElementLBData::getSubphaseComm(PhaseType phase)
return subphase_comm;
}

void ElementLBData::setSubPhase(SubphaseType subphase) {
vtAssert(subphase < no_subphase, "subphase must be less than sentinel");
cur_subphase_ = subphase;
}

SubphaseType ElementLBData::getSubPhase() const {
return cur_subphase_;
}

void ElementLBData::releaseLBDataFromUnneededPhases(PhaseType phase, unsigned int look_back) {
if (phase >= look_back) {
phase_timings_.erase(phase - look_back);
Expand Down
10 changes: 1 addition & 9 deletions src/vt/elm/elm_lb_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,13 @@ struct ElementLBData {
NodeType to, ElementIDStruct from_perm,
double bytes, bool bcast
);
void updatePhase(PhaseType const& inc = 1);
void resetPhase();
PhaseType getPhase() const;
void updatePhase(PhaseType const& phase);
TimeType getLoad(PhaseType const& phase) const;
TimeType getLoad(PhaseType phase, SubphaseType subphase) const;

CommMapType const& getComm(PhaseType const& phase);
std::vector<CommMapType> const& getSubphaseComm(PhaseType phase);
std::vector<TimeType> const& getSubphaseTimes(PhaseType phase);
void setSubPhase(SubphaseType subphase);
SubphaseType getSubPhase() const;

// these are just for unit testing
std::size_t getLoadPhaseCount() const;
Expand All @@ -103,10 +99,8 @@ struct ElementLBData {
void serialize(Serializer& s) {
s | cur_time_started_;
s | cur_time_;
s | cur_phase_;
s | phase_timings_;
s | phase_comm_;
s | cur_subphase_;
s | subphase_timings_;
s | subphase_comm_;
}
Expand All @@ -125,11 +119,9 @@ struct ElementLBData {
protected:
bool cur_time_started_ = false;
TimeType cur_time_ = 0.0;
PhaseType cur_phase_ = fst_lb_phase;
std::unordered_map<PhaseType, TimeType> phase_timings_ = {};
std::unordered_map<PhaseType, CommMapType> phase_comm_ = {};

SubphaseType cur_subphase_ = 0;
std::unordered_map<PhaseType, std::vector<TimeType>> subphase_timings_ = {};
std::unordered_map<PhaseType, std::vector<CommMapType>> subphase_comm_ = {};
};
Expand Down
4 changes: 3 additions & 1 deletion src/vt/messaging/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,10 @@ void ActiveMessenger::startup() {
#if vt_check_enabled(lblite)
// Hook to collect LB data about objgroups
thePhase()->registerHookCollective(phase::PhaseHook::End, [this]{
auto const phase = thePhase()->getCurrentPhase();
theNodeLBData()->addNodeLBData(
bare_handler_dummy_elm_id_for_lb_data_, &bare_handler_lb_data_, nullptr
bare_handler_dummy_elm_id_for_lb_data_, &bare_handler_lb_data_, nullptr,
phase
);
});
#endif
Expand Down
49 changes: 49 additions & 0 deletions src/vt/messaging/collection_chain_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,55 @@ class CollectionChainSet final {
return nextStepCollective("", step_action);
}

/**
 * \brief Queue the next collective step for each index that is added to the
 * CollectionChainSet on each node, closing a subphase once it completes.
 *
 * Meant for steps with internal recursive communication and global
 * inter-dependence. A global (on the communicator), collective epoch is
 * created to track all the causally related messages and collectively wait
 * for termination of all of the recursive sends. The subphase is advanced at
 * termination.
 *
 * \param[in] label Label for the epoch created for debugging
 * \param[in] step_action the next step to execute, returning a \c PendingSend
 */
void nextStepCollectiveSubphase(
  std::string const& label, std::function<PendingSend(Index)> step_action
) {
  auto subphase_epoch = theTerm()->makeEpochCollective(label);

  // Attach the subphase bump to the epoch before any step is enqueued
  theTerm()->addActionEpoch(subphase_epoch, [=]{
    thePhase()->advanceSubphase();
  });

  vt::theMsg()->pushEpoch(subphase_epoch);

  // Run the step for every index held in chains_ and append the resulting
  // send to that index's chain
  for (auto& chain_entry : chains_) {
    chain_entry.second.add(subphase_epoch, step_action(chain_entry.first));
  }

  vt::theMsg()->popEpoch(subphase_epoch);
  theTerm()->finishedEpoch(subphase_epoch);
}

/**
 * \brief Unlabeled form of the collective subphase step for each index that
 * is added to the CollectionChainSet on each node.
 *
 * Meant for steps with internal recursive communication and global
 * inter-dependence. Delegates to the labelled overload with an empty label,
 * which creates a global (on the communicator), collective epoch to track all
 * the causally related messages and advances the subphase at termination.
 *
 * \param[in] step_action the next step to execute, returning a \c PendingSend
 */
void nextStepCollectiveSubphase(
  std::function<PendingSend(Index)> step_action
) {
  nextStepCollectiveSubphase("", step_action);
}

/**
* \brief The next collective step of both CollectionChainSets
* to execute over all shared indices of the CollectionChainSets over all
Expand Down
5 changes: 4 additions & 1 deletion src/vt/objgroup/manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ void ObjGroupManager::startup() {
#if vt_check_enabled(lblite)
// Hook to collect LB data about objgroups
thePhase()->registerHookCollective(phase::PhaseHook::End, []{
auto const phase = thePhase()->getCurrentPhase();
auto& objs = theObjGroup()->objs_;
for (auto&& obj : objs) {
auto holder = obj.second.get();
Expand All @@ -68,7 +69,9 @@ void ObjGroupManager::startup() {
auto proxy = elm::ElmIDBits::getObjGroupProxy(elm_id.id, false);
vtAssertExpr(proxy == obj.first);
theNodeLBData()->registerObjGroupInfo(elm_id, obj.first);
theNodeLBData()->addNodeLBData(elm_id, &holder->getLBData(), nullptr);
theNodeLBData()->addNodeLBData(
elm_id, &holder->getLBData(), nullptr, phase
);
}
}
});
Expand Down
1 change: 1 addition & 0 deletions src/vt/phase/phase_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ void PhaseManager::nextPhaseCollective() {
runHooks(PhaseHook::EndPostMigration);

cur_phase_++;
cur_subphase_ = 0;

vt_debug_print(
normal, phase,
Expand Down
14 changes: 14 additions & 0 deletions src/vt/phase/phase_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,18 @@ struct PhaseManager : runtime::component::Component<PhaseManager> {
*/
PhaseType getCurrentPhase() const { return cur_phase_; }

/**
 * \brief Read the subphase counter held by the phase manager
 *
 * \return the value of the current subphase
 */
SubphaseType getCurrentSubphase() const {
  return cur_subphase_;
}

/**
 * \brief Move on to the next subphase by incrementing the subphase counter
 */
void advanceSubphase() {
  ++cur_subphase_;
}

/**
* \brief Collectively register a phase hook that triggers depending on the
* type of hook
Expand Down Expand Up @@ -200,6 +212,7 @@ struct PhaseManager : runtime::component::Component<PhaseManager> {
template <typename SerializerT>
void serialize(SerializerT& s) {
s | cur_phase_
| cur_subphase_
| proxy_
| collective_hooks_
| rooted_hooks_
Expand All @@ -213,6 +226,7 @@ struct PhaseManager : runtime::component::Component<PhaseManager> {

private:
PhaseType cur_phase_ = 0; /**< Current phase */
SubphaseType cur_subphase_ = 0; /**< Current subphase */
ObjGroupProxyType proxy_ = no_obj_group; /**< Objgroup proxy */
HookIDMapType collective_hooks_; /**< Collective registered hooks */
HookIDMapType rooted_hooks_; /**< Rooted registered hooks */
Expand Down
6 changes: 6 additions & 0 deletions src/vt/scheduler/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ void runInEpochCollective(Callable&& fn);
template <typename Callable>
void runInEpochCollective(std::string const& label, Callable&& fn);

/**
 * \brief Run \c fn as a collective subphase without naming its epoch
 *
 * Forwards to the labelled overload using the placeholder label "UNLABELED".
 *
 * \param[in] fn the callable, perfectly forwarded to the labelled overload
 */
template <typename Callable>
void runSubphaseCollective(Callable&& fn);

/**
 * \brief Run \c fn inside a new collective epoch and advance the subphase
 *
 * Creates a collective epoch named \c label, attaches an epoch action that
 * calls \c thePhase()->advanceSubphase(), and then hands the epoch and \c fn
 * to \c runInEpoch.
 *
 * \param[in] label label given to the collective epoch
 * \param[in] fn the callable forwarded to \c runInEpoch
 */
template <typename Callable>
void runSubphaseCollective(std::string const& label, Callable&& fn);

namespace messaging {

template <typename T>
Expand Down
15 changes: 15 additions & 0 deletions src/vt/scheduler/scheduler.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "vt/config.h"
#include "vt/messaging/active.h"
#include "vt/termination/termination.h"
#include "vt/phase/phase_manager.h"

namespace vt {

Expand All @@ -72,6 +73,20 @@ void runInEpochCollective(std::string const& label, Callable&& fn) {
runInEpoch(ep, std::forward<Callable>(fn));
}

// Unlabeled convenience form: delegates to the labelled overload with the
// placeholder epoch label "UNLABELED".
template <typename Callable>
void runSubphaseCollective(Callable&& fn) {
  std::string const default_label = "UNLABELED";
  runSubphaseCollective(default_label, std::forward<Callable>(fn));
}

// Runs `fn` within a freshly created collective epoch named `label`. An epoch
// action that advances the phase manager's subphase is attached to the epoch
// before `fn` is handed to runInEpoch.
template <typename Callable>
void runSubphaseCollective(std::string const& label, Callable&& fn) {
  auto subphase_epoch = theTerm()->makeEpochCollective(label);

  theTerm()->addActionEpoch(
    subphase_epoch, [=]{ thePhase()->advanceSubphase(); }
  );

  runInEpoch(subphase_epoch, std::forward<Callable>(fn));
}

template <typename Callable>
void runInEpochRooted(Callable&& fn) {
runInEpochRooted("UNLABELED", std::forward<Callable>(fn));
Expand Down
Loading