Skip to content

Commit

Permalink
Modernize DECREF (part 9). (#3234)
Browse files Browse the repository at this point in the history
This commit refactors a part of the NRN Python bindings to use `nanobind`
objects instead of `Py_DECREF`. The purpose is to simplify the DECREFing logic
on error paths; and the risk of leaking when exceptions are thrown.

As part of the refactoring, if needed, the scope of certain variables might be
reduced or a given a new name. Additionally, NULL pointers are replaced with
`nullptr`.

This commit doesn't intentionally change reference counts.
  • Loading branch information
1uc authored Nov 26, 2024
1 parent e08dc8d commit fa7d530
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 143 deletions.
1 change: 1 addition & 0 deletions src/nrnpython/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set(NRNPYTHON_FILES_LIST
nrn_metaclass.cpp
nrnpy_nrn.cpp
nrnpy_p2h.cpp
nrnpy_utils.cpp
grids.cpp
rxd.cpp
rxd_extracellular.cpp
Expand Down
7 changes: 3 additions & 4 deletions src/nrnpython/nrnpy_hoc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2765,10 +2765,9 @@ static Object** gui_helper_(const char* name, Object* obj) {

static Object** vec_as_numpy_helper(int size, double* data) {
if (vec_as_numpy) {
PyObject* po = (*vec_as_numpy)(size, data);
if (po != Py_None) {
Object* ho = nrnpy_po2ho(po);
Py_DECREF(po);
auto po = nb::steal((*vec_as_numpy)(size, data));
if (!po.is_none()) {
Object* ho = nrnpy_po2ho(po.release().ptr());
--ho->refcount;
return hoc_temp_objptr(ho);
}
Expand Down
40 changes: 20 additions & 20 deletions src/nrnpython/nrnpy_nrn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2970,13 +2970,14 @@ static void rangevars_add(Symbol* sym) {
PyDict_SetItemString(rangevars_, sym->name, (PyObject*) r);
}

// Returns a borrowed reference.
PyObject* nrnpy_nrn(void) {
PyObject* m;
nb::object m;

int err = 0;
PyObject* modules = PyImport_GetModuleDict();
if ((m = PyDict_GetItemString(modules, "nrn")) != NULL && PyModule_Check(m)) {
return m;
if ((m = nb::borrow(PyDict_GetItemString(modules, "nrn"))) && PyModule_Check(m.ptr())) {
return m.ptr();
}
psection_type = (PyTypeObject*) PyType_FromSpec(&nrnpy_SectionType_spec);
psection_type->tp_new = PyType_GenericNew;
Expand Down Expand Up @@ -3019,18 +3020,18 @@ PyObject* nrnpy_nrn(void) {
goto fail;
Py_INCREF(opaque_pointer_type);

m = PyModule_Create(&nrnsectionmodule); // like nrn but namespace will not include mechanims.
PyModule_AddObject(m, "Section", (PyObject*) psection_type);
PyModule_AddObject(m, "Segment", (PyObject*) psegment_type);
m = nb::steal(PyModule_Create(&nrnsectionmodule)); // like nrn but namespace will not include
// mechanims.
PyModule_AddObject(m.ptr(), "Section", (PyObject*) psection_type);
PyModule_AddObject(m.ptr(), "Segment", (PyObject*) psegment_type);

err = PyDict_SetItemString(modules, "_neuron_section", m);
err = PyDict_SetItemString(modules, "_neuron_section", m.ptr());
assert(err == 0);
Py_DECREF(m);
m = PyModule_Create(&nrnmodule); //
nrnmodule_ = m;
PyModule_AddObject(m, "Section", (PyObject*) psection_type);
PyModule_AddObject(m, "Segment", (PyObject*) psegment_type);
PyModule_AddObject(m, "OpaquePointer", (PyObject*) opaque_pointer_type);
m = nb::steal(PyModule_Create(&nrnmodule)); //
nrnmodule_ = m.ptr();
PyModule_AddObject(m.ptr(), "Section", (PyObject*) psection_type);
PyModule_AddObject(m.ptr(), "Segment", (PyObject*) psegment_type);
PyModule_AddObject(m.ptr(), "OpaquePointer", (PyObject*) opaque_pointer_type);

pmech_generic_type = (PyTypeObject*) PyType_FromSpec(&nrnpy_MechanismType_spec);
pmechfunc_generic_type = (PyTypeObject*) PyType_FromSpec(&nrnpy_MechFuncType_spec);
Expand All @@ -3052,10 +3053,10 @@ PyObject* nrnpy_nrn(void) {
Py_INCREF(pmechfunc_generic_type);
Py_INCREF(pmech_of_seg_iter_generic_type);
Py_INCREF(pvar_of_mech_iter_generic_type);
PyModule_AddObject(m, "Mechanism", (PyObject*) pmech_generic_type);
PyModule_AddObject(m, "MechFunc", (PyObject*) pmechfunc_generic_type);
PyModule_AddObject(m, "MechOfSegIterator", (PyObject*) pmech_of_seg_iter_generic_type);
PyModule_AddObject(m, "VarOfMechIterator", (PyObject*) pvar_of_mech_iter_generic_type);
PyModule_AddObject(m.ptr(), "Mechanism", (PyObject*) pmech_generic_type);
PyModule_AddObject(m.ptr(), "MechFunc", (PyObject*) pmechfunc_generic_type);
PyModule_AddObject(m.ptr(), "MechOfSegIterator", (PyObject*) pmech_of_seg_iter_generic_type);
PyModule_AddObject(m.ptr(), "VarOfMechIterator", (PyObject*) pvar_of_mech_iter_generic_type);
remake_pmech_types();
nrnpy_reg_mech_p_ = nrnpy_reg_mech;
nrnpy_ob_is_seg = ob_is_seg;
Expand All @@ -3067,10 +3068,9 @@ PyObject* nrnpy_nrn(void) {
nrnpy_pysec_cell_p_ = pysec_cell;
nrnpy_pysec_cell_equals_p_ = pysec_cell_equals;

err = PyDict_SetItemString(modules, "nrn", m);
err = PyDict_SetItemString(modules, "nrn", m.ptr());
assert(err == 0);
Py_DECREF(m);
return m;
return m.ptr();
fail:
return NULL;
}
Expand Down
20 changes: 10 additions & 10 deletions src/nrnpython/nrnpy_p2h.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,7 @@ static nb::list char2pylist(const std::vector<char>& buf,
}

#if NRNMPI
// Returns a new reference.
static PyObject* py_allgather(PyObject* psrc) {
int np = nrnmpi_numprocs;
auto sbuf = pickle(psrc);
Expand All @@ -679,6 +680,7 @@ static PyObject* py_allgather(PyObject* psrc) {
return char2pylist(rbuf, rcnt, rdispl).release().ptr();
}

// Returns a new reference.
static PyObject* py_gather(PyObject* psrc, int root) {
int np = nrnmpi_numprocs;
auto sbuf = pickle(psrc);
Expand All @@ -698,15 +700,14 @@ static PyObject* py_gather(PyObject* psrc, int root) {

nrnmpi_char_gatherv(sbuf.data(), scnt, rbuf.data(), rcnt.data(), rdispl.data(), root);

PyObject* pdest = Py_None;
nb::object pdest = nb::none();
if (root == nrnmpi_myid) {
pdest = char2pylist(rbuf, rcnt, rdispl).release().ptr();
} else {
Py_INCREF(pdest);
pdest = char2pylist(rbuf, rcnt, rdispl);
}
return pdest;
return pdest.release().ptr();
}

// Returns a new reference.
static PyObject* py_broadcast(PyObject* psrc, int root) {
// Note: root returns reffed psrc.
std::vector<char> buf{};
Expand All @@ -720,14 +721,13 @@ static PyObject* py_broadcast(PyObject* psrc, int root) {
buf.resize(cnt);
}
nrnmpi_char_broadcast(buf.data(), cnt, root);
PyObject* pdest = psrc;
nb::object pdest;
if (root != nrnmpi_myid) {
nb::object po = unpickle(buf);
pdest = po.release().ptr();
pdest = unpickle(buf);
} else {
Py_INCREF(pdest);
pdest = nb::borrow(psrc);
}
return pdest;
return pdest.release().ptr();
}
#endif

Expand Down
102 changes: 102 additions & 0 deletions src/nrnpython/nrnpy_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#include "nrnpy_utils.h"

#include <tuple>
#include <nanobind/nanobind.h>

namespace nb = nanobind;

inline std::tuple<nb::object, nb::object, nb::object> fetch_pyerr() {
PyObject* ptype = NULL;
PyObject* pvalue = NULL;
PyObject* ptraceback = NULL;
PyErr_Fetch(&ptype, &pvalue, &ptraceback);

return std::make_tuple(nb::steal(ptype), nb::steal(pvalue), nb::steal(ptraceback));
}


Py2NRNString::Py2NRNString(PyObject* python_string, bool disable_release) {
disable_release_ = disable_release;
str_ = NULL;
if (PyUnicode_Check(python_string)) {
auto py_bytes = nb::steal(PyUnicode_AsASCIIString(python_string));
if (py_bytes) {
str_ = strdup(PyBytes_AsString(py_bytes.ptr()));
if (!str_) { // errno is ENOMEM
PyErr_SetString(PyExc_MemoryError, "strdup in Py2NRNString");
}
}
} else if (PyBytes_Check(python_string)) {
str_ = strdup(PyBytes_AsString(python_string));
// assert(strlen(str_) == PyBytes_Size(python_string))
// not checking for embedded '\0'
if (!str_) { // errno is ENOMEM
PyErr_SetString(PyExc_MemoryError, "strdup in Py2NRNString");
}
} else { // Neither Unicode or PyBytes
PyErr_SetString(PyExc_TypeError, "Neither Unicode or PyBytes");
}
}

void Py2NRNString::set_pyerr(PyObject* type, const char* message) {
nb::object err_type;
nb::object err_value;
nb::object err_traceback;

if (err()) {
std::tie(err_type, err_value, err_traceback) = fetch_pyerr();
}
if (err_value && err_type) {
auto umes = nb::steal(
PyUnicode_FromFormat("%s (Note: %S: %S)", message, err_type.ptr(), err_value.ptr()));
PyErr_SetObject(type, umes.ptr());
} else {
PyErr_SetString(type, message);
}
}

char* Py2NRNString::get_pyerr() {
if (err()) {
auto [ptype, pvalue, ptraceback] = fetch_pyerr();
if (pvalue) {
auto pstr = nb::steal(PyObject_Str(pvalue.ptr()));
if (pstr) {
const char* err_msg = PyUnicode_AsUTF8(pstr.ptr());
if (err_msg) {
str_ = strdup(err_msg);
} else {
str_ = strdup("get_pyerr failed at PyUnicode_AsUTF8");
}
} else {
str_ = strdup("get_pyerr failed at PyObject_Str");
}
} else {
str_ = strdup("get_pyerr failed at PyErr_Fetch");
}
}
PyErr_Clear(); // in case could not turn pvalue into c_str.
return str_;
}

char* PyErr2NRNString::get_pyerr() {
if (PyErr_Occurred()) {
auto [ptype, pvalue, ptraceback] = fetch_pyerr();
if (pvalue) {
auto pstr = nb::steal(PyObject_Str(pvalue.ptr()));
if (pstr) {
const char* err_msg = PyUnicode_AsUTF8(pstr.ptr());
if (err_msg) {
str_ = strdup(err_msg);
} else {
str_ = strdup("get_pyerr failed at PyUnicode_AsUTF8");
}
} else {
str_ = strdup("get_pyerr failed at PyObject_Str");
}
} else {
str_ = strdup("get_pyerr failed at PyErr_Fetch");
}
}
PyErr_Clear(); // in case could not turn pvalue into c_str.
return str_;
}
110 changes: 9 additions & 101 deletions src/nrnpython/nrnpy_utils.h
Original file line number Diff line number Diff line change
@@ -1,37 +1,17 @@
#pragma once

#include "nrnwrap_Python.h"
#include "nrn_export.hpp"
#include <cassert>


inline bool is_python_string(PyObject* python_string) {
return PyUnicode_Check(python_string) || PyBytes_Check(python_string);
}

class Py2NRNString {
class NRN_EXPORT Py2NRNString {
public:
Py2NRNString(PyObject* python_string, bool disable_release = false) {
disable_release_ = disable_release;
str_ = NULL;
if (PyUnicode_Check(python_string)) {
PyObject* py_bytes = PyUnicode_AsASCIIString(python_string);
if (py_bytes) {
str_ = strdup(PyBytes_AsString(py_bytes));
if (!str_) { // errno is ENOMEM
PyErr_SetString(PyExc_MemoryError, "strdup in Py2NRNString");
}
}
Py_XDECREF(py_bytes);
} else if (PyBytes_Check(python_string)) {
str_ = strdup(PyBytes_AsString(python_string));
// assert(strlen(str_) == PyBytes_Size(python_string))
// not checking for embedded '\0'
if (!str_) { // errno is ENOMEM
PyErr_SetString(PyExc_MemoryError, "strdup in Py2NRNString");
}
} else { // Neither Unicode or PyBytes
PyErr_SetString(PyExc_TypeError, "Neither Unicode or PyBytes");
}
}
Py2NRNString(PyObject* python_string, bool disable_release = false);

~Py2NRNString() {
if (!disable_release_ && str_) {
Expand All @@ -44,53 +24,9 @@ class Py2NRNString {
inline bool err() const {
return str_ == NULL;
}
inline void set_pyerr(PyObject* type, const char* message) {
PyObject* ptype = NULL;
PyObject* pvalue = NULL;
PyObject* ptraceback = NULL;
if (err()) {
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
}
if (pvalue && ptype) {
PyObject* umes = PyUnicode_FromFormat("%s (Note: %S: %S)", message, ptype, pvalue);
PyErr_SetObject(type, umes); // umes is borrowed reference
Py_XDECREF(umes);
} else {
PyErr_SetString(type, message);
}
Py_XDECREF(ptype);
Py_XDECREF(pvalue);
Py_XDECREF(ptraceback);
}
inline char* get_pyerr() {
PyObject* ptype = NULL;
PyObject* pvalue = NULL;
PyObject* ptraceback = NULL;
if (err()) {
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
if (pvalue) {
PyObject* pstr = PyObject_Str(pvalue);
if (pstr) {
const char* err_msg = PyUnicode_AsUTF8(pstr);
if (err_msg) {
str_ = strdup(err_msg);
} else {
str_ = strdup("get_pyerr failed at PyUnicode_AsUTF8");
}
Py_XDECREF(pstr);
} else {
str_ = strdup("get_pyerr failed at PyObject_Str");
}
} else {
str_ = strdup("get_pyerr failed at PyErr_Fetch");
}
}
PyErr_Clear(); // in case could not turn pvalue into c_str.
Py_XDECREF(ptype);
Py_XDECREF(pvalue);
Py_XDECREF(ptraceback);
return str_;
}

void set_pyerr(PyObject* type, const char* message);
char* get_pyerr();

private:
Py2NRNString();
Expand All @@ -107,7 +43,7 @@ class Py2NRNString {
* hoc_execerr_ext("hoc message : %s", e.c_str());
* e will be automatically deleted even though execerror does not return.
*/
class PyErr2NRNString {
class NRN_EXPORT PyErr2NRNString {
public:
PyErr2NRNString() {
str_ = NULL;
Expand All @@ -123,35 +59,7 @@ class PyErr2NRNString {
return str_;
}

inline char* get_pyerr() {
PyObject* ptype = NULL;
PyObject* pvalue = NULL;
PyObject* ptraceback = NULL;
if (PyErr_Occurred()) {
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
if (pvalue) {
PyObject* pstr = PyObject_Str(pvalue);
if (pstr) {
const char* err_msg = PyUnicode_AsUTF8(pstr);
if (err_msg) {
str_ = strdup(err_msg);
} else {
str_ = strdup("get_pyerr failed at PyUnicode_AsUTF8");
}
Py_XDECREF(pstr);
} else {
str_ = strdup("get_pyerr failed at PyObject_Str");
}
} else {
str_ = strdup("get_pyerr failed at PyErr_Fetch");
}
}
PyErr_Clear(); // in case could not turn pvalue into c_str.
Py_XDECREF(ptype);
Py_XDECREF(pvalue);
Py_XDECREF(ptraceback);
return str_;
}
char* get_pyerr();

private:
PyErr2NRNString(const PyErr2NRNString&);
Expand Down
Loading

0 comments on commit fa7d530

Please sign in to comment.