Skip to content

Commit

Permalink
Make have_to_want with templates
Browse files Browse the repository at this point in the history
  • Loading branch information
alkino committed Dec 20, 2023
1 parent c483af8 commit 50749cb
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 62 deletions.
90 changes: 41 additions & 49 deletions src/nrniv/have2want.cpp → src/nrniv/have2want.hpp
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
#pragma once

#include <numeric>
#include <vector>

/*
To be included by a file that desires rendezvous rank exchange functionality.
Need to define HAVEWANT_t, HAVEWANT_alltoallv, and HAVEWANT2Int
Need to define HAVEWANT_alltoallv, and HAVEWANT2Int
The latter is a map or unordered_map.
E.g. std::unordered_map<size_t, int>
*/

#ifdef have2want_cpp
#error "This implementation can only be included once"
// The static function names used to involve a macro name (NrnHash) but now,
// with the use of std::..., it may be the case this could be included
// multiple times or even transformed into a template.
#endif

#define have2want_cpp

/*
A rank owns a set of HAVEWANT_t keys and wants information associated with
a set of HAVEWANT_t keys owned by unknown ranks. Owners do not know which
A rank owns a set of T keys and wants information associated with
a set of T keys owned by unknown ranks. Owners do not know which
ranks want their information. Ranks that want info do not know which ranks
own that info.
Expand All @@ -39,21 +34,16 @@ The rendezvous_rank function is used to parallelize this computation
and minimize memory usage so that no single rank ever needs to know all keys.
*/

#ifndef HAVEWANT_t
#define HAVEWANT_t int
#endif

// round robin default rendezvous rank function
static int default_rendezvous(HAVEWANT_t key) {
template <typename T>
int default_rendezvous(const T& key) {
return key % nrnmpi_numprocs;
}

static int* cnt2displ(int* cnt) {
static int* cnt2displ(const int* cnt) {
int* displ = new int[nrnmpi_numprocs + 1];
displ[0] = 0;
for (int i = 0; i < nrnmpi_numprocs; ++i) {
displ[i + 1] = displ[i] + cnt[i];
}
std::partial_sum(cnt, cnt + nrnmpi_numprocs, displ + 1);
return displ;
}

Expand All @@ -63,15 +53,16 @@ static int* srccnt2destcnt(int* srccnt) {
return destcnt;
}

static void rendezvous_rank_get(HAVEWANT_t* data,
template <typename T>
static void rendezvous_rank_get(T* data,
int size,
HAVEWANT_t*& sdata,
T*& sdata,
int*& scnt,
int*& sdispl,
HAVEWANT_t*& rdata,
T*& rdata,
int*& rcnt,
int*& rdispl,
int (*rendezvous_rank)(HAVEWANT_t)) {
int (*rendezvous_rank)(const T&)) {
int nhost = nrnmpi_numprocs;

// count what gets sent
Expand All @@ -87,8 +78,8 @@ static void rendezvous_rank_get(HAVEWANT_t* data,
sdispl = cnt2displ(scnt);
rcnt = srccnt2destcnt(scnt);
rdispl = cnt2displ(rcnt);
sdata = new HAVEWANT_t[sdispl[nhost] + 1]; // ensure not 0 size
rdata = new HAVEWANT_t[rdispl[nhost] + 1]; // ensure not 0 size
sdata = new T[sdispl[nhost] + 1]; // ensure not 0 size
rdata = new T[rdispl[nhost] + 1]; // ensure not 0 size
// scatter data into sdata by recalculating scnt.
for (int i = 0; i < nhost; ++i) {
scnt[i] = 0;
Expand All @@ -101,17 +92,18 @@ static void rendezvous_rank_get(HAVEWANT_t* data,
HAVEWANT_alltoallv(sdata, scnt, sdispl, rdata, rcnt, rdispl);
}

static void have_to_want(HAVEWANT_t* have,
int have_size,
HAVEWANT_t* want,
int want_size,
HAVEWANT_t*& send_to_want,
int*& send_to_want_cnt,
int*& send_to_want_displ,
HAVEWANT_t*& recv_from_have,
int*& recv_from_have_cnt,
int*& recv_from_have_displ,
int (*rendezvous_rank)(HAVEWANT_t)) {
template <typename T = int>
void have_to_want(T* have,
int have_size,
T* want,
int want_size,
T*& send_to_want,
int*& send_to_want_cnt,
int*& send_to_want_displ,
T*& recv_from_have,
int*& recv_from_have_cnt,
int*& recv_from_have_displ,
int (*rendezvous_rank)(const T&)) {
// 1) Send have and want to the rendezvous ranks.
// 2) Rendezvous rank matches have and want.
// 3) Rendezvous ranks tell the want ranks which ranks own the keys
Expand All @@ -120,9 +112,9 @@ static void have_to_want(HAVEWANT_t* have,
int nhost = nrnmpi_numprocs;

// 1) Send have and want to the rendezvous ranks.
HAVEWANT_t *have_s_data, *have_r_data;
T *have_s_data, *have_r_data;
int *have_s_cnt, *have_s_displ, *have_r_cnt, *have_r_displ;
rendezvous_rank_get(have,
rendezvous_rank_get<T>(have,
have_size,
have_s_data,
have_s_cnt,
Expand All @@ -139,7 +131,7 @@ static void have_to_want(HAVEWANT_t* have,
HAVEWANT2Int havekey2rank = HAVEWANT2Int(have_r_displ[nhost] + 1); // ensure not empty.
for (int r = 0; r < nhost; ++r) {
for (int i = 0; i < have_r_cnt[r]; ++i) {
HAVEWANT_t key = have_r_data[have_r_displ[r] + i];
T key = have_r_data[have_r_displ[r] + i];
if (havekey2rank.find(key) != havekey2rank.end()) {
hoc_execerr_ext(
"internal error in have_to_want: key %lld owned by multiple ranks\n",
Expand All @@ -152,9 +144,9 @@ static void have_to_want(HAVEWANT_t* have,
delete[] have_r_cnt;
delete[] have_r_displ;

HAVEWANT_t *want_s_data, *want_r_data;
T *want_s_data, *want_r_data;
int *want_s_cnt, *want_s_displ, *want_r_cnt, *want_r_displ;
rendezvous_rank_get(want,
rendezvous_rank_get<T>(want,
want_size,
want_s_data,
want_s_cnt,
Expand All @@ -173,7 +165,7 @@ static void have_to_want(HAVEWANT_t* have,
for (int r = 0; r < nhost; ++r) {
for (int i = 0; i < want_r_cnt[r]; ++i) {
int ix = want_r_displ[r] + i;
HAVEWANT_t key = want_r_data[ix];
T key = want_r_data[ix];
auto search = havekey2rank.find(key);
if (search == havekey2rank.end()) {
hoc_execerr_ext(
Expand Down Expand Up @@ -223,8 +215,8 @@ static void have_to_want(HAVEWANT_t* have,
for (int i = 0; i < nhost; ++i) {
want_s_cnt[i] = 0;
}
HAVEWANT_t* old_want_s_data = want_s_data;
want_s_data = new HAVEWANT_t[n];
T* old_want_s_data = want_s_data;
want_s_data = new T[n];
// compute the counts
for (int i = 0; i < n; ++i) {
int r = want_s_ownerranks[i];
Expand All @@ -236,15 +228,15 @@ static void have_to_want(HAVEWANT_t* have,
} // recount while filling
for (int i = 0; i < n; ++i) {
int r = want_s_ownerranks[i];
HAVEWANT_t key = old_want_s_data[i];
T key = old_want_s_data[i];
want_s_data[want_s_displ[r] + want_s_cnt[r]] = key;
++want_s_cnt[r];
}
delete[] want_s_ownerranks;
delete[] old_want_s_data;
want_r_cnt = srccnt2destcnt(want_s_cnt);
want_r_displ = cnt2displ(want_r_cnt);
want_r_data = new HAVEWANT_t[want_r_displ[nhost]];
want_r_data = new T[want_r_displ[nhost]];
HAVEWANT_alltoallv(
want_s_data, want_s_cnt, want_s_displ, want_r_data, want_r_cnt, want_r_displ);
// now the want_r_data on the have_ranks are grouped according to the ranks
Expand Down
25 changes: 12 additions & 13 deletions src/nrniv/partrans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,11 +579,10 @@ static void thread_transfer(NrnThread* _nt) {

// 22-08-2014 For setup of the All2allv pattern, use the rendezvous rank
// idiom.
#define HAVEWANT_t sgid_t
#define HAVEWANT_alltoallv sgid_alltoallv
#define HAVEWANT2Int MapSgid2Int
#if NRNMPI
#include "have2want.cpp"
#include "have2want.hpp"
#endif

void nrnmpi_setup_transfer() {
Expand Down Expand Up @@ -680,17 +679,17 @@ void nrnmpi_setup_transfer() {
sgid_t* recv_from_have;
int *recv_from_have_cnt, *recv_from_have_displ;

have_to_want(ownsrc,
sgids_.size(),
needsrc,
needsrc_cnt,
send_to_want,
send_to_want_cnt,
send_to_want_displ,
recv_from_have,
recv_from_have_cnt,
recv_from_have_displ,
default_rendezvous);
have_to_want<sgid_t>(ownsrc,
sgids_.size(),
needsrc,
needsrc_cnt,
send_to_want,
send_to_want_cnt,
send_to_want_displ,
recv_from_have,
recv_from_have_cnt,
recv_from_have_displ,
default_rendezvous<sgid_t>);

// sanity check. all the sgids we are asked to send, we actually have
int n = send_to_want_displ[nhost];
Expand Down

0 comments on commit 50749cb

Please sign in to comment.