Skip to content

Commit

Permalink
py_ functions should take and return nb objects
Browse files Browse the repository at this point in the history
  • Loading branch information
alkino committed Nov 21, 2024
1 parent d910b74 commit 08d8f26
Showing 1 changed file with 14 additions and 19 deletions.
33 changes: 14 additions & 19 deletions src/nrnpython/nrnpy_p2h.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -687,9 +687,9 @@ static PyObject* char2pylist(char* buf, int np, int* cnt, int* displ) {
}

#if NRNMPI
static PyObject* py_allgather(PyObject* psrc) {
static nb::object py_allgather(nb::handle psrc) {
int np = nrnmpi_numprocs;
auto sbuf = pickle(psrc);
auto sbuf = pickle(psrc.ptr());
// what are the counts from each rank
int* rcnt = new int[np];
rcnt[nrnmpi_myid] = static_cast<int>(sbuf.size());
Expand All @@ -699,16 +699,16 @@ static PyObject* py_allgather(PyObject* psrc) {

nrnmpi_char_allgatherv(sbuf.data(), rbuf, rcnt, rdispl);

PyObject* pdest = char2pylist(rbuf, np, rcnt, rdispl);
nb::object pdest = nb::steal(char2pylist(rbuf, np, rcnt, rdispl));
delete[] rbuf;
delete[] rcnt;
delete[] rdispl;
return pdest;
}

static PyObject* py_gather(PyObject* psrc, int root) {
static nb::object py_gather(nb::handle psrc, int root) {
int np = nrnmpi_numprocs;
auto sbuf = pickle(psrc);
auto sbuf = pickle(psrc.ptr());
// what are the counts from each rank
int scnt = static_cast<int>(sbuf.size());
int* rcnt = NULL;
Expand All @@ -725,37 +725,32 @@ static PyObject* py_gather(PyObject* psrc, int root) {

nrnmpi_char_gatherv(sbuf.data(), scnt, rbuf, rcnt, rdispl, root);

PyObject* pdest = Py_None;
nb::object pdest = nb::none();
if (root == nrnmpi_myid) {
pdest = char2pylist(rbuf, np, rcnt, rdispl);
pdest = std::steal(char2pylist(rbuf, np, rcnt, rdispl));

Check failure on line 730 in src/nrnpython/nrnpy_p2h.cpp

View workflow job for this annotation

GitHub Actions / ubuntu-22.04 - cmake (-DNRN_ENABLE_CORENEURON=ON -DNRN_ENABLE_INTERVIEWS=OFF -DNMODL_SANITIZERS=undefinedundefined)

no member named 'steal' in namespace 'std'; did you mean 'nanobind::steal'?
delete[] rbuf;
delete[] rcnt;
delete[] rdispl;
} else {
Py_INCREF(pdest);
}
return pdest;
}

static PyObject* py_broadcast(PyObject* psrc, int root) {
static nb::object py_broadcast(nb::handle psrc, int root) {
// Note: root returns reffed psrc.
std::vector<char> buf{};
int cnt = 0;
if (root == nrnmpi_myid) {
buf = pickle(psrc);
buf = pickle(psrc.ptr());
cnt = static_cast<int>(buf.size());
}
nrnmpi_int_broadcast(&cnt, 1, root);
if (root != nrnmpi_myid) {
buf.resize(cnt);
}
nrnmpi_char_broadcast(buf.data(), cnt, root);
PyObject* pdest = psrc;
nb::object pdest = psrc;

Check failure on line 751 in src/nrnpython/nrnpy_p2h.cpp

View workflow job for this annotation

GitHub Actions / ubuntu-22.04 - cmake (-DNRN_ENABLE_CORENEURON=ON -DNRN_ENABLE_INTERVIEWS=OFF -DNMODL_SANITIZERS=undefinedundefined)

no viable conversion from 'nb::handle' to 'nb::object'
if (root != nrnmpi_myid) {
nb::object po = unpickle(buf);
pdest = po.release().ptr();
} else {
Py_INCREF(pdest);
pdest = unpickle(buf);
}
return pdest;
}
Expand Down Expand Up @@ -820,16 +815,16 @@ static Object* py_alltoall_type(int size, int type) {
nb::object pdest;

if (type == 2) {
pdest = nb::steal(py_allgather(psrc.ptr()));
pdest = py_allgather(psrc);
} else if (type != 1 && type != 5) {
root = size;
if (root < 0 || root >= np) {
hoc_execerror("root rank must be >= 0 and < nhost", 0);
}
if (type == 3) {
pdest = nb::steal(py_gather(psrc.ptr(), root));
pdest = py_gather(psrc, root);
} else if (type == 4) {
pdest = nb::steal(py_broadcast(psrc.ptr(), root));
pdest = py_broadcast(psrc, root);
}
} else {
if (type == 5) { // scatter
Expand Down

0 comments on commit 08d8f26

Please sign in to comment.