Skip to content

Commit

Permalink
char2pylist return nb::list and take std::vector (#3241)
Browse files Browse the repository at this point in the history
  • Loading branch information
alkino authored Nov 22, 2024
1 parent c9e0780 commit c199d9f
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions src/nrnpython/nrnpy_p2h.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,17 +659,15 @@ static std::vector<int> mk_displ(int* cnts) {
return displ;
}

static PyObject* char2pylist(char* buf, int np, int* cnt, int* displ) {
PyObject* plist = PyList_New(np);
assert(plist != NULL);
for (int i = 0; i < np; ++i) {
static nb::list char2pylist(const std::vector<char>& buf,
const std::vector<int>& cnt,
const std::vector<int>& displ) {
nb::list plist{};
for (int i = 0; i < cnt.size(); ++i) {
if (cnt[i] == 0) {
Py_INCREF(Py_None); // 'Fatal Python error: deallocating None' eventually
PyList_SetItem(plist, i, Py_None);
plist.append(nb::none());
} else {
nb::object po = unpickle(buf + displ[i], cnt[i]);
PyObject* p = po.release().ptr();
PyList_SetItem(plist, i, p);
plist.append(unpickle(buf.data() + displ[i], cnt[i]));
}
}
return plist;
Expand All @@ -688,8 +686,7 @@ static PyObject* py_allgather(PyObject* psrc) {

nrnmpi_char_allgatherv(sbuf.data(), rbuf.data(), rcnt.data(), rdispl.data());

PyObject* pdest = char2pylist(rbuf.data(), np, rcnt.data(), rdispl.data());
return pdest;
return char2pylist(rbuf, rcnt, rdispl).release().ptr();
}

static PyObject* py_gather(PyObject* psrc, int root) {
Expand All @@ -713,7 +710,7 @@ static PyObject* py_gather(PyObject* psrc, int root) {

PyObject* pdest = Py_None;
if (root == nrnmpi_myid) {
pdest = char2pylist(rbuf.data(), np, rcnt.data(), rdispl.data());
pdest = char2pylist(rbuf, rcnt, rdispl).release().ptr();
} else {
Py_INCREF(pdest);
}
Expand Down Expand Up @@ -865,7 +862,7 @@ static Object* py_alltoall_type(int size, int type) {
nrnmpi_char_alltoallv(
s.data(), scnt.data(), sdispl.data(), r.data(), rcnt.data(), rdispl.data());

pdest = nb::steal(char2pylist(r.data(), np, rcnt.data(), rdispl.data()));
pdest = char2pylist(r, rcnt, rdispl);
}

} else { // scatter
Expand Down

0 comments on commit c199d9f

Please sign in to comment.