Skip to content

Commit

Permalink
Use std::vector in nrnpy_p2h.cpp. (#3227)
Browse files Browse the repository at this point in the history
Removes numerous cases of manual memory management
in favour of `std::vector` (because it's RAII).
  • Loading branch information
1uc authored Nov 22, 2024
1 parent 75eb334 commit c9e0780
Showing 1 changed file with 33 additions and 57 deletions.
90 changes: 33 additions & 57 deletions src/nrnpython/nrnpy_p2h.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -650,8 +650,8 @@ std::vector<char> call_picklef(const std::vector<char>& fname, int narg) {

#include "nrnmpi.h"

int* mk_displ(int* cnts) {
int* displ = new int[nrnmpi_numprocs + 1];
static std::vector<int> mk_displ(int* cnts) {
std::vector<int> displ(nrnmpi_numprocs + 1);
displ[0] = 0;
for (int i = 0; i < nrnmpi_numprocs; ++i) {
displ[i + 1] = displ[i] + cnts[i];
Expand Down Expand Up @@ -680,18 +680,15 @@ static PyObject* py_allgather(PyObject* psrc) {
int np = nrnmpi_numprocs;
auto sbuf = pickle(psrc);
// what are the counts from each rank
int* rcnt = new int[np];
std::vector<int> rcnt(np);
rcnt[nrnmpi_myid] = static_cast<int>(sbuf.size());
nrnmpi_int_allgather_inplace(rcnt, 1);
int* rdispl = mk_displ(rcnt);
char* rbuf = new char[rdispl[np]];
nrnmpi_int_allgather_inplace(rcnt.data(), 1);
auto rdispl = mk_displ(rcnt.data());
std::vector<char> rbuf(rdispl[np]);

nrnmpi_char_allgatherv(sbuf.data(), rbuf, rcnt, rdispl);
nrnmpi_char_allgatherv(sbuf.data(), rbuf.data(), rcnt.data(), rdispl.data());

PyObject* pdest = char2pylist(rbuf, np, rcnt, rdispl);
delete[] rbuf;
delete[] rcnt;
delete[] rdispl;
PyObject* pdest = char2pylist(rbuf.data(), np, rcnt.data(), rdispl.data());
return pdest;
}

Expand All @@ -700,26 +697,23 @@ static PyObject* py_gather(PyObject* psrc, int root) {
auto sbuf = pickle(psrc);
// what are the counts from each rank
int scnt = static_cast<int>(sbuf.size());
int* rcnt = NULL;
std::vector<int> rcnt;
if (root == nrnmpi_myid) {
rcnt = new int[np];
rcnt.resize(np);
}
nrnmpi_int_gather(&scnt, rcnt, 1, root);
int* rdispl = NULL;
char* rbuf = NULL;
nrnmpi_int_gather(&scnt, rcnt.data(), 1, root);
std::vector<int> rdispl;
std::vector<char> rbuf;
if (root == nrnmpi_myid) {
rdispl = mk_displ(rcnt);
rbuf = new char[rdispl[np]];
rdispl = mk_displ(rcnt.data());
rbuf.resize(rdispl[np]);
}

nrnmpi_char_gatherv(sbuf.data(), scnt, rbuf, rcnt, rdispl, root);
nrnmpi_char_gatherv(sbuf.data(), scnt, rbuf.data(), rcnt.data(), rdispl.data(), root);

PyObject* pdest = Py_None;
if (root == nrnmpi_myid) {
pdest = char2pylist(rbuf, np, rcnt, rdispl);
delete[] rbuf;
delete[] rcnt;
delete[] rdispl;
pdest = char2pylist(rbuf.data(), np, rcnt.data(), rdispl.data());
} else {
Py_INCREF(pdest);
}
Expand Down Expand Up @@ -828,9 +822,6 @@ static Object* py_alltoall_type(int size, int type) {

std::vector<char> s{};
std::vector<int> scnt{};
int* sdispl = NULL;
int* rcnt = NULL;
int* rdispl = NULL;

// setup source buffer for transfer s, scnt, sdispl
// for alltoall, each rank handled identically
Expand Down Expand Up @@ -858,59 +849,44 @@ static Object* py_alltoall_type(int size, int type) {
if (type == 1) { // alltoall

// what are destination counts
int* ones = new int[np];
for (int i = 0; i < np; ++i) {
ones[i] = 1;
}
sdispl = mk_displ(ones);
rcnt = new int[np];
nrnmpi_int_alltoallv(scnt.data(), ones, sdispl, rcnt, ones, sdispl);
delete[] ones;
delete[] sdispl;
std::vector<int> ones(np, 1);
auto sdispl = mk_displ(ones.data());
std::vector<int> rcnt(np);
nrnmpi_int_alltoallv(
scnt.data(), ones.data(), sdispl.data(), rcnt.data(), ones.data(), sdispl.data());

// exchange
sdispl = mk_displ(scnt.data());
rdispl = mk_displ(rcnt);
auto rdispl = mk_displ(rcnt.data());
if (size < 0) {
pdest = nb::make_tuple(sdispl[np], rdispl[np]);
delete[] sdispl;
delete[] rcnt;
delete[] rdispl;
} else {
char* r = new char[rdispl[np] + 1]; // force > 0 for all None case
nrnmpi_char_alltoallv(s.data(), scnt.data(), sdispl, r, rcnt, rdispl);
delete[] sdispl;

pdest = nb::steal(char2pylist(r, np, rcnt, rdispl));
std::vector<char> r(rdispl[np] + 1); // force > 0 for all None case
nrnmpi_char_alltoallv(
s.data(), scnt.data(), sdispl.data(), r.data(), rcnt.data(), rdispl.data());

delete[] r;
delete[] rcnt;
delete[] rdispl;
pdest = nb::steal(char2pylist(r.data(), np, rcnt.data(), rdispl.data()));
}

} else { // scatter

// destination counts
rcnt = new int[1];
nrnmpi_int_scatter(scnt.data(), rcnt, 1, root);
std::vector<char> r(rcnt[0] + 1); // rcnt[0] can be 0
int rcnt = -1;
nrnmpi_int_scatter(scnt.data(), &rcnt, 1, root);
std::vector<char> r(rcnt + 1); // rcnt can be 0
std::vector<int> sdispl;

// exchange
if (nrnmpi_myid == root) {
sdispl = mk_displ(scnt.data());
}
nrnmpi_char_scatterv(s.data(), scnt.data(), sdispl, r.data(), rcnt[0], root);
if (sdispl)
delete[] sdispl;
nrnmpi_char_scatterv(s.data(), scnt.data(), sdispl.data(), r.data(), rcnt, root);

if (rcnt[0]) {
if (rcnt) {
pdest = unpickle(r);
} else {
pdest = nb::none();
}

delete[] rcnt;
assert(rdispl == NULL);
}
}

Expand Down

0 comments on commit c9e0780

Please sign in to comment.