Skip to content

Commit

Permalink
Merge pull request #2210 from DARMA-tasking/2094-allow-bcast-to-selec…
Browse files Browse the repository at this point in the history
…ted-nodes

2094: Allow ObjGroup broadcast to selected nodes
  • Loading branch information
lifflander authored Nov 8, 2023
2 parents dca2a02 + 7fb28eb commit 09159cc
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 9 deletions.
44 changes: 36 additions & 8 deletions examples/hello_world/objgroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,48 @@ struct MyObjGroup {
int main(int argc, char** argv) {
vt::initialize(argc, argv);

vt::NodeType this_node = vt::theContext()->getNode();
vt::NodeType num_nodes = vt::theContext()->getNumNodes();
const auto this_node = vt::theContext()->getNode();
const auto num_nodes = vt::theContext()->getNumNodes();

auto proxy = vt::theObjGroup()->makeCollective<MyObjGroup>(
"examples_hello_world"
);
auto proxy =
vt::theObjGroup()->makeCollective<MyObjGroup>("examples_hello_world");

// Create group of odd nodes and multicast to them (from root node)
vt::theGroup()->newGroupCollective(
this_node % 2, [proxy, this_node](::vt::GroupType type) {
if (this_node == 0) {
proxy.multicast<&MyObjGroup::handler>(type, 122, 244);
}
});

vt::theCollective()->barrier();

if (this_node == 0) {
proxy[0].send<&MyObjGroup::handler>(5,10);
// Send to object 0
proxy[0].send<&MyObjGroup::handler>(5, 10);
if (num_nodes > 1) {
proxy[1].send<&MyObjGroup::handler>(10,20);
// Send to object 1
proxy[1].send<&MyObjGroup::handler>(10, 20);
}
proxy.broadcast<&MyObjGroup::handler>(400,500);

// Broadcast to all nodes
proxy.broadcast<&MyObjGroup::handler>(400, 500);

using namespace ::vt::group::region;

// Create list of nodes and multicast to them
List::ListType range;
for (vt::NodeType node = 0; node < num_nodes; ++node) {
if (node % 2 == 0) {
range.push_back(node);
}
}

proxy.multicast<&MyObjGroup::handler>(
std::make_unique<List>(range), 20, 40
);
}
vt::theCollective()->barrier();

vt::finalize();

Expand Down
10 changes: 10 additions & 0 deletions src/vt/group/group_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,16 @@ void GroupManager::deleteGroupCollective(GroupType group_id) {
}
}

std::optional<GroupType>
GroupManager::GetTempGroupForRange(const region::Region::ListType& range) {
const auto it = temporary_groups_.find(range);
if (it != temporary_groups_.end()) {
return it->second;
}

return {};
}

bool GroupManager::inGroup(GroupType const group) {
auto iter = local_collective_group_info_.find(group);
vtAssert(iter != local_collective_group_info_.end(), "Must exist");
Expand Down
19 changes: 19 additions & 0 deletions src/vt/group/group_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
#include <unordered_map>
#include <cstdlib>
#include <functional>
#include <optional>

#include <mpi.h>

Expand Down Expand Up @@ -120,6 +121,22 @@ struct GroupManager : runtime::component::Component<GroupManager> {
*/
void setupDefaultGroup();

/**
* \internal \brief Cache group created by multicast. This allows for reusing the same group.
*
* \param[in] range list of nodes that are part of given group
* \param[in] group group to cache
*/
void AddNewTempGroup(const region::Region::ListType& range, GroupType group) {
temporary_groups_[range] = group;
}

/**
* \internal \brief Return (if any) group associated with given range of nodes
*/
std::optional<GroupType>
GetTempGroupForRange(const region::Region::ListType& range);

/**
* \brief Create a new rooted group.
*
Expand Down Expand Up @@ -431,6 +448,8 @@ struct GroupManager : runtime::component::Component<GroupManager> {
ActionContainerType continuation_actions_ = {};
ActionListType cleanup_actions_ = {};
CollectiveScopeType collective_scope_;
std::unordered_map<region::Region::ListType, GroupType, region::ListHash>
temporary_groups_ = {};
};

/**
Expand Down
13 changes: 12 additions & 1 deletion src/vt/group/region/group_region.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ struct Region {
using ListType = std::vector<BoundType>;
using ApplyFnType = std::function<void(RegionUPtrType)>;

virtual ~Region(){}
virtual ~Region() = default;
virtual SizeType getSize() const = 0;
virtual void sort() = 0;
virtual bool contains(NodeType const& node) = 0;
Expand All @@ -77,6 +77,17 @@ struct Region {
virtual void splitN(int nsplits, ApplyFnType apply) const = 0;
};

struct ListHash {
size_t operator()(const Region::ListType& v) const {
std::hash<Region::BoundType> hasher;
size_t seed = 0;
for (const auto i : v) {
seed ^= hasher(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};

struct List;
struct Range;
struct ShallowList;
Expand Down
21 changes: 21 additions & 0 deletions src/vt/objgroup/proxy/proxy_objgroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
#include "vt/rdmahandle/handle_set.fwd.h"
#include "vt/messaging/pending_send.h"
#include "vt/utils/fntraits/fntraits.h"
#include "vt/group/region/group_list.h"

namespace vt { namespace objgroup { namespace proxy {

Expand Down Expand Up @@ -159,6 +160,26 @@ struct Proxy {
template <auto fn, typename... Args>
PendingSendType broadcast(Args&&... args) const;

/**
* \brief Multicast a message to nodes that are part of given group to be delivered to the local object
* instance
*
* \param[in] type group to multicast
* \param[in] args args to pass to the message constructor
*/
template <auto fn, typename... Args>
PendingSendType multicast(GroupType type, Args&&... args) const;

/**
* \brief Multicast a message to nodes specified by the region to be delivered to the local object
* instance
*
* \param[in] nodes region of nodes to multicast to
* \param[in] args args to pass to the message constructor
*/
template <auto fn, typename... Args>
PendingSendType multicast(group::region::Region::RegionUPtrType&& nodes, Args&&... args) const;

/**
* \brief All-reduce back to this objgroup. Performs a reduction using
* operator `Op` followed by a broadcast to `f` with the result.
Expand Down
46 changes: 46 additions & 0 deletions src/vt/objgroup/proxy/proxy_objgroup.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#define INCLUDED_VT_OBJGROUP_PROXY_PROXY_OBJGROUP_IMPL_H

#include "vt/config.h"
#include "vt/group/group_manager.h"
#include "vt/objgroup/common.h"
#include "vt/objgroup/proxy/proxy_objgroup.h"
#include "vt/objgroup/manager.h"
Expand Down Expand Up @@ -108,6 +109,51 @@ Proxy<ObjT>::broadcast(Params&&... params) const {
return typename Proxy<ObjT>::PendingSendType{std::nullptr_t{}};
}

template <typename ObjT>
template <auto f, typename... Params>
typename Proxy<ObjT>::PendingSendType
Proxy<ObjT>::multicast(GroupType type, Params&&... params) const{
using MsgT = typename ObjFuncTraits<decltype(f)>::MsgT;
if constexpr (std::is_same_v<MsgT, NoMsg>) {
using Tuple = typename ObjFuncTraits<decltype(f)>::TupleType;
using SendMsgT = messaging::ParamMsg<Tuple>;
auto msg = vt::makeMessage<SendMsgT>(std::forward<Params>(params)...);
vt::envelopeSetGroup(msg->env, type);
auto const ctrl = proxy::ObjGroupProxy::getID(proxy_);
auto const han = auto_registry::makeAutoHandlerObjGroupParam<
ObjT, decltype(f), f, SendMsgT
>(ctrl);
return theObjGroup()->broadcast(msg, han);
} else {
auto msg = makeMessage<MsgT>(std::forward<Params>(params)...);
vt::envelopeSetGroup(msg->env, type);
return broadcastMsg<MsgT, f>(msg);
}

// Silence nvcc warning (no longer needed for CUDA 11.7 and up)
return typename Proxy<ObjT>::PendingSendType{std::nullptr_t{}};
}

template <typename ObjT>
template <auto f, typename... Params>
typename Proxy<ObjT>::PendingSendType Proxy<ObjT>::multicast(
group::region::Region::RegionUPtrType&& nodes, Params&&... params) const {
vtAssert(
not dynamic_cast<group::region::ShallowList*>(nodes.get()),
"multicast: range of nodes is not supported for ShallowList!"
);

nodes->sort();
auto& range = nodes->makeList();

auto groupID = theGroup()->GetTempGroupForRange(range);
if (!groupID.has_value()) {
groupID = theGroup()->newGroup(std::move(nodes), [](GroupType type) {});
theGroup()->AddNewTempGroup(range, groupID.value());
}

return multicast<f>(groupID.value(), std::forward<Params>(params)...);
}

template <typename ObjT>
template <
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/objgroup/test_objgroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,32 @@ TEST_F(TestObjGroup, test_proxy_invoke) {
EXPECT_EQ(proxy.get()->recv_, 3);
}

TEST_F(TestObjGroup, test_proxy_multicast) {
using namespace ::vt::group::region;
auto const this_node = theContext()->getNode();
auto const num_nodes = theContext()->getNumNodes();

auto proxy =
vt::theObjGroup()->makeCollective<MyObjA>("test_proxy_multicast");

vt::runInEpochCollective([this_node, num_nodes, proxy] {
if (this_node == 0) {
// Create list of nodes and multicast to them
List::ListType range;
for (vt::NodeType node = 0; node < num_nodes; ++node) {
if (node % 2 == 0) {
range.push_back(node);
}
}

proxy.multicast<&MyObjA::handler>(std::make_unique<List>(range));
}
});

const auto expected = this_node % 2 == 0 ? 1 : 0;
EXPECT_EQ(proxy.get()->recv_, expected);
}

TEST_F(TestObjGroup, test_pending_send) {
auto my_node = vt::theContext()->getNode();
// create a proxy to a object group
Expand Down

0 comments on commit 09159cc

Please sign in to comment.