Skip to content

Commit

Permalink
#2281: Resolve issue with incorrect index generated by StateHolder::getNextID
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Oct 11, 2024
1 parent ed1f880 commit 164328c
Show file tree
Hide file tree
Showing 11 changed files with 33 additions and 50 deletions.
2 changes: 1 addition & 1 deletion src/vt/collective/reduce/allreduce/allreduce_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
#include "vt/configs/types/types_type.h"
#include "vt/collective/reduce/allreduce/type.h"
#include "vt/collective/reduce/scoping/strong_types.h"
#include "vt/configs/types/types_sentinels.h"
#include "vt/objgroup/proxy/proxy_objgroup.h"

#include <unordered_map>
Expand All @@ -56,6 +55,7 @@ namespace vt::collective::reduce::allreduce {

struct Rabenseifner;
struct RecursiveDoubling;

struct AllreduceHolder {
using RabenseifnerProxy = ObjGroupProxyType;
using RecursiveDoublingProxy = ObjGroupProxyType;
Expand Down
9 changes: 0 additions & 9 deletions src/vt/collective/reduce/allreduce/rabenseifner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,4 @@ void Rabenseifner::initializeVrtNode() {
}
}

Rabenseifner::~Rabenseifner() {
if (info_.first == ComponentT::ObjGroup) {
StateHolder::clearAll(detail::StrongObjGroup{info_.second});
AllreduceHolder::remove(detail::StrongObjGroup{info_.second});
} else if(info_.first == ComponentT::Group){
StateHolder::clearAll(detail::StrongGroup{info_.second});
}
}

} // namespace vt::collective::reduce::allreduce
2 changes: 0 additions & 2 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ struct Rabenseifner {

void initializeVrtNode();

~Rabenseifner();

/**
* \brief Set final handler that will be executed with allreduce result
*
Expand Down
9 changes: 0 additions & 9 deletions src/vt/collective/reduce/allreduce/recursive_doubling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,4 @@ void RecursiveDoubling::initializeVrtNode() {
}
}

RecursiveDoubling::~RecursiveDoubling() {
if (info_.first == ComponentT::ObjGroup) {
StateHolder::clearAll(detail::StrongObjGroup{info_.second});
AllreduceHolder::remove(detail::StrongObjGroup{info_.second});
} else if(info_.first == ComponentT::Group){
StateHolder::clearAll(detail::StrongGroup{info_.second});
}
}

} // namespace vt::collective::reduce::allreduce
2 changes: 0 additions & 2 deletions src/vt/collective/reduce/allreduce/recursive_doubling.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ struct RecursiveDoubling {
*/
void initializeVrtNode();

~RecursiveDoubling();

/**
* \brief Execute the final handler callback with the reduced result.
*
Expand Down
2 changes: 0 additions & 2 deletions src/vt/collective/reduce/allreduce/recursive_doubling.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ template <typename DataT, template <typename Arg> class Op>
template <typename DataT, template <typename Arg> class Op>
void RecursiveDoubling::adjustForPowerOfTwoHan(
RecursiveDoublingMsg<DataT>* msg) {
using DataType = DataHandler<DataT>;
auto& state = getState<RecursiveDoublingT, DataT>(info_, msg->id_);
if (not state.value_assigned_) {
if (not state.initialized_) {
Expand Down Expand Up @@ -311,7 +310,6 @@ RecursiveDoubling::reduceIterHandler(RecursiveDoublingMsg<DataT>* msg) {

template <typename DataT, template <typename Arg> class Op>
void RecursiveDoubling::reduceIterHan(RecursiveDoublingMsg<DataT>* msg) {
using DataType = DataHandler<DataT>;
auto& state = getState<RecursiveDoublingT, DataT>(info_, msg->id_);

if (not state.value_assigned_) {
Expand Down
34 changes: 22 additions & 12 deletions src/vt/collective/reduce/allreduce/state_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ size_t
getNextIdImpl(StateHolder::StatesVec& states, size_t idx) {
size_t id = u64empty;

vt_debug_print(
terse, allreduce, "getNextIdImpl idx={} size={} \n", idx, states.size());

for (; idx < states.size(); ++idx) {
auto& state = states.at(idx);
if (not state or not state->active_) {
Expand All @@ -64,28 +67,35 @@ getNextIdImpl(StateHolder::StatesVec& states, size_t idx) {
id = states.size();
}


return id;
}

size_t StateHolder::getNextID(detail::StrongVrtProxy proxy) {
auto& states = active_coll_states_[proxy.get()];
auto& [idx, states] = active_coll_states_[proxy.get()];

auto current_idx = getNextIdImpl(states, idx);
idx = current_idx + 1;

collection_idx_ = getNextIdImpl(states, collection_idx_);
return collection_idx_;
return current_idx;
}

size_t StateHolder::getNextID(detail::StrongObjGroup proxy) {
auto& states = active_obj_states_[proxy.get()];
auto& [idx, states] = active_obj_states_[proxy.get()];

objgroup_idx_ = getNextIdImpl(states, objgroup_idx_);
return objgroup_idx_;
auto current_idx = getNextIdImpl(states, idx);
idx = current_idx + 1;

return current_idx;
}

size_t StateHolder::getNextID(detail::StrongGroup group) {
auto& states = active_grp_states_[group.get()];
auto& [idx, states] = active_grp_states_[group.get()];

auto current_idx = getNextIdImpl(states, idx);

group_idx_ = getNextIdImpl(states, group_idx_);
return group_idx_;
idx = current_idx + 1;
return current_idx;
}

static inline void
Expand All @@ -101,19 +111,19 @@ clearSingleImpl(StateHolder::StatesVec& states, size_t idx) {
}

void StateHolder::clearSingle(detail::StrongVrtProxy proxy, size_t idx) {
auto& states = active_coll_states_[proxy.get()];
auto& [_, states] = active_coll_states_[proxy.get()];

clearSingleImpl(states, idx);
}

void StateHolder::clearSingle(detail::StrongObjGroup proxy, size_t idx) {
auto& states = active_obj_states_[proxy.get()];
auto& [_, states] = active_obj_states_[proxy.get()];

clearSingleImpl(states, idx);
}

void StateHolder::clearSingle(detail::StrongGroup group, size_t idx) {
auto& states = active_grp_states_[group.get()];
auto& [_, states] = active_grp_states_[group.get()];

clearSingleImpl(states, idx);
}
Expand Down
11 changes: 4 additions & 7 deletions src/vt/collective/reduce/allreduce/state_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ namespace vt::collective::reduce::allreduce {

struct StateHolder {
using StatesVec = std::vector<std::unique_ptr<StateBase>>;
using StatesInfo = std::pair<size_t, StatesVec>;

template <
typename ReducerT, typename DataT,
Expand Down Expand Up @@ -86,17 +87,13 @@ struct StateHolder {
static void clearAll(detail::StrongGroup group);

private:
static inline size_t collection_idx_ = 0;
static inline size_t objgroup_idx_ = 0;
static inline size_t group_idx_ = 0;

static inline std::unordered_map<VirtualProxyType, StatesVec>
static inline std::unordered_map<VirtualProxyType, StatesInfo>
active_coll_states_ = {};

static inline std::unordered_map<ObjGroupProxyType, StatesVec>
static inline std::unordered_map<ObjGroupProxyType, StatesInfo>
active_obj_states_ = {};

static inline std::unordered_map<GroupType, StatesVec> active_grp_states_ =
static inline std::unordered_map<GroupType, StatesInfo> active_grp_states_ =
{};
};

Expand Down
3 changes: 1 addition & 2 deletions src/vt/collective/reduce/allreduce/state_holder.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
//@HEADER
*/

#include "vt/collective/reduce/allreduce/state.h"
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_IMPL_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_IMPL_H

Expand Down Expand Up @@ -80,7 +79,7 @@ template <
typename Scalar = typename DataHandler<DataT>::Scalar, typename ProxyT,
typename MapT>
static auto& getStateImpl(ProxyT proxy, MapT& states_map, size_t idx) {
auto& states = states_map[proxy.get()];
auto& [_, states] = states_map[proxy.get()];
auto const num_states = states.size();

if (idx >= num_states || num_states == 0) {
Expand Down
6 changes: 5 additions & 1 deletion src/vt/objgroup/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
#include "vt/collective/reduce/allreduce/type.h"
#include "vt/collective/reduce/allreduce/helpers.h"
#include "vt/collective/reduce/scoping/strong_types.h"
#include "vt/collective/reduce/allreduce/state_holder.h"
#include "vt/collective/reduce/allreduce/allreduce_holder.h"
#include "vt/pipe/pipe_manager.h"

#include <utility>
Expand Down Expand Up @@ -147,6 +147,10 @@ void ObjGroupManager::destroyCollective(ProxyType<ObjT> proxy) {
if (label_iter != labels_.end()) {
labels_.erase(label_iter);
}

vt::reduce::allreduce::AllreduceHolder::remove(
vt::reduce::detail::StrongObjGroup{proxy.getProxy()}
);
}

template <typename ObjT>
Expand Down
3 changes: 0 additions & 3 deletions src/vt/vrt/collection/holders/typeless_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,6 @@ void TypelessHolder::destroyCollection(VirtualProxyType const proxy) {
}
}

vt::collective::reduce::allreduce::StateHolder::clearAll(
vt::collective::reduce::detail::StrongVrtProxy{proxy});

vt::collective::reduce::allreduce::AllreduceHolder::remove(
vt::collective::reduce::detail::StrongVrtProxy{proxy}
);
Expand Down

0 comments on commit 164328c

Please sign in to comment.