Skip to content

Commit

Permalink
Cleanup several aspects of the Python bindings (copy of PR #1696) (#1696
Browse files Browse the repository at this point in the history
)

Signed-off-by: Ken Museth <[email protected]>
  • Loading branch information
kmuseth authored Oct 28, 2023
1 parent 77f28d1 commit fc4a559
Show file tree
Hide file tree
Showing 6 changed files with 462 additions and 215 deletions.
50 changes: 19 additions & 31 deletions openvdb/openvdb/python/pyGrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ inline bool
sharesWith(const GridType& grid, py::object other)
{
if (py::isinstance<GridType>(other)) {
typename GridType::ConstPtr otherGrid = other.cast<typename GridType::Ptr>();
typename GridType::ConstPtr otherGrid = py::cast<typename GridType::Ptr>(other);
return (&otherGrid->tree() == &grid.tree());
}
return false;
Expand Down Expand Up @@ -893,7 +893,7 @@ applyMap(const char* methodName, GridType& grid, py::object funcObj)

// Verify that the result is of type GridType::ValueType.
try {
result.cast<ValueT>();
py::cast<ValueT>(result);
} catch (py::cast_error&) {
std::ostringstream os;
os << "expected callable argument to ";
Expand All @@ -904,7 +904,7 @@ applyMap(const char* methodName, GridType& grid, py::object funcObj)
throw py::type_error(os.str());
}

it.setValue(result.cast<ValueT>());
it.setValue(py::cast<ValueT>(result));
}
}

Expand Down Expand Up @@ -955,7 +955,7 @@ struct TreeCombineOp
throw py::type_error(os.str());
}

result = resultObj.cast<ValueT>();
result = py::cast<ValueT>(resultObj);
}
py::function op;
};
Expand Down Expand Up @@ -1177,15 +1177,15 @@ class IterValueProxy
py::object getItem(py::object keyObj) const
{
if (py::isinstance<std::string>(keyObj)) {
const std::string key = keyObj.cast<std::string>();
const std::string key = py::cast<std::string>(keyObj);
if (key == "value") return py::cast(this->getValue());
else if (key == "active") return py::cast(this->getActive());
else if (key == "depth") return py::cast(this->getDepth());
else if (key == "min") return py::cast(this->getBBoxMin());
else if (key == "max") return py::cast(this->getBBoxMax());
else if (key == "count") return py::cast(this->getVoxelCount());
}
throw py::key_error(keyObj.attr("__repr__")().cast<std::string>());
throw py::key_error(py::cast<std::string>(keyObj.attr("__repr__")()));
return py::object();
}

Expand All @@ -1195,20 +1195,20 @@ class IterValueProxy
void setItem(py::object keyObj, py::object valObj)
{
if (py::isinstance<std::string>(keyObj)) {
const std::string key = keyObj.cast<std::string>();
const std::string key = py::cast<std::string>(keyObj);
if (key == "value") {
this->setValue(valObj.cast<ValueT>()); return;
this->setValue(py::cast<ValueT>(valObj)); return;
} else if (key == "active") {
this->setActive(valObj.cast<bool>()); return;
this->setActive(py::cast<bool>(valObj)); return;
} else if (this->hasKey(key)) {
std::ostringstream os;
os << "can't set attribute '";
os << keyObj.attr("__repr__")().cast<std::string>();
os << py::cast<std::string>(keyObj.attr("__repr__")());
os << "'";
throw py::attribute_error(os.str());
}
}
throw py::key_error(keyObj.attr("__repr__")().cast<std::string>());
throw py::key_error(py::cast<std::string>(keyObj.attr("__repr__")()));
}

bool operator==(const IterValueProxy& other) const
Expand All @@ -1235,7 +1235,7 @@ class IterValueProxy
}
// print ", ".join(valuesAsStrings)
py::object joined = py::str(", ").attr("join")(valuesAsStrings);
std::string s = joined.cast<std::string>();
std::string s = py::cast<std::string>(joined);
os << "{" << s << "}";
return os;
}
Expand Down Expand Up @@ -1379,13 +1379,9 @@ struct PickleSuite
}

// Construct a state tuple for the serialized Grid.
#if PY_MAJOR_VERSION >= 3
// Convert the byte string to a "bytes" sequence.
const std::string s = ostr.str();
py::bytes bytesObj(s);
#else
py::str bytesObj(ostr.str());
#endif
return py::make_tuple(bytesObj);
}

Expand All @@ -1397,25 +1393,16 @@ struct PickleSuite
std::string serialized;
if (!badState) {
// Extract the sequence containing the serialized Grid.
#if PY_MAJOR_VERSION >= 3
if (py::isinstance<py::bytes>(state[0]))
serialized = state[0].cast<py::bytes>();
#else
if (py::isinstance<std::string>(state[0]))
serialized = state[0].cast<std::string>();
#endif
serialized = py::cast<py::bytes>(state[0]);
else
badState = true;
}

if (badState) {
std::ostringstream os;
#if PY_MAJOR_VERSION >= 3
os << "expected (dict, bytes) tuple in call to __setstate__; found ";
#else
os << "expected (dict, str) tuple in call to __setstate__; found ";
#endif
os << state.attr("__repr__")().cast<std::string>();
os << py::cast<std::string>(state.attr("__repr__")());
throw py::value_error(os.str());
}

Expand Down Expand Up @@ -1457,14 +1444,15 @@ exportGrid(py::module_ m)
using ValueAllIterT = typename GridType::ValueAllIter;

const std::string pyGridTypeName = Traits::name();
const std::string defaultCtorDescr = "Initialize with a background value of "
+ pyutil::str(pyGrid::getZeroValue<GridType>()) + ".";
std::stringstream docstream;
docstream << "Initialize with a background value of " << pyGrid::getZeroValue<GridType>() << ".";
std::string docstring = docstream.str();

// Define the Grid wrapper class and make it the current scope.
py::class_<GridType, GridPtr, GridBase>(m,
/*classname=*/pyGridTypeName.c_str(),
/*docstring=*/(Traits::descr()).c_str())
.def(py::init<>(), defaultCtorDescr.c_str())
.def(py::init<>(), docstring.c_str())
.def(py::init<const ValueT&>(), py::arg("background"),
"Initialize with the given background value.")

Expand Down Expand Up @@ -1711,7 +1699,7 @@ exportGrid(py::module_ m)
IterWrap<GridType, ValueAllIterT>::wrap(m);

// Add the Python type object for this grid type to the module-level list.
m.attr("GridTypes").cast<py::list>().append(m.attr(pyGridTypeName.c_str()));
py::cast<py::list>(m.attr("GridTypes")).append(m.attr(pyGridTypeName.c_str()));
}

} // namespace pyGrid
Expand Down
8 changes: 2 additions & 6 deletions openvdb/openvdb/python/pyGridBase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,8 @@ exportGridBase(py::module_ m)


auto getMetadataKeys = [](GridBase::ConstPtr grid) {
#if PY_MAJOR_VERSION >= 3
// Return an iterator over the "keys" view of a dict.
return py::make_key_iterator(static_cast<const MetaMap&>(*grid).beginMeta(), static_cast<const MetaMap&>(*grid).endMeta());
#else
return py::dict(py::cast(static_cast<const MetaMap&>(*grid))).iterkeys();
#endif
};


Expand All @@ -118,7 +114,7 @@ exportGridBase(py::module_ m)
MetaMap metamap;
metamap.insertMeta(name, *metadata);
// todo: Add/refactor out type_casters for each TypedMetadata from MetaMap's type_caster
return py::dict(py::cast(metamap))[py::str(name)].cast<py::object>();
return py::cast<py::object>(py::dict(py::cast(metamap))[py::str(name)]);
};


Expand All @@ -135,7 +131,7 @@ exportGridBase(py::module_ m)
// todo: Add/refactor out type_casters for each TypedMetadata from MetaMap's type_caster
py::dict dictObj;
dictObj[py::str(name)] = value;
MetaMap metamap = dictObj.cast<MetaMap>();
MetaMap metamap = py::cast<MetaMap>(dictObj);

if (Metadata::Ptr metadata = metamap[name]) {
grid->removeMeta(name);
Expand Down
72 changes: 0 additions & 72 deletions openvdb/openvdb/python/pyOpenVDBModule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -371,78 +371,6 @@ PYBIND11_MODULE(PY_OPENVDB_MODULE_NAME, m)

#undef PYOPENVDB_TRANSLATE_EXCEPTION

// Basic bindings for these Vec types are required to support them as
// default arguments to functions.
py::class_<openvdb::Coord>(m, "Coord")
.def(py::init<>())
.def(py::init<openvdb::Coord::Int32>())
.def(py::init<openvdb::Coord::Int32, openvdb::Coord::Int32, openvdb::Coord::Int32>())
.def(py::self == py::self)
.def(py::self != py::self);

py::class_<openvdb::Vec2i>(m, "Vec2i")
.def(py::init<>())
.def(py::init<int32_t>())
.def(py::init<int32_t, int32_t>())
.def(py::self == py::self)
.def(py::self != py::self);

py::class_<openvdb::Vec2f>(m, "Vec2f")
.def(py::init<>())
.def(py::init<float>())
.def(py::init<float, float>())
.def(py::self == py::self)
.def(py::self != py::self);

py::class_<openvdb::Vec2d>(m, "Vec2d")
.def(py::init<>())
.def(py::init<double>())
.def(py::init<double, double>())
.def(py::self == py::self)
.def(py::self != py::self);

py::class_<openvdb::Vec3i>(m, "Vec3i")
.def(py::init<>())
.def(py::init<int32_t>())
.def(py::init<int32_t, int32_t, int32_t>())
.def(py::self == py::self)
.def(py::self != py::self);

py::class_<openvdb::Vec3f>(m, "Vec3f")
.def(py::init<>())
.def(py::init<float>())
.def(py::init<float, float, float>())
.def(py::self == py::self)
.def(py::self != py::self);

py::class_<openvdb::Vec3d>(m, "Vec3d")
.def(py::init<>())
.def(py::init<double>())
.def(py::init<double, double, double>())
.def(py::self == py::self)
.def(py::self != py::self);

py::class_<openvdb::Vec4i>(m, "Vec4i")
.def(py::init<>())
.def(py::init<int32_t>())
.def(py::init<int32_t, int32_t, int32_t, int32_t>())
.def(py::self == py::self)
.def(py::self != py::self);

py::class_<openvdb::Vec4f>(m, "Vec4f")
.def(py::init<>())
.def(py::init<float>())
.def(py::init<float, float, float, float>())
.def(py::self == py::self)
.def(py::self != py::self);

py::class_<openvdb::Vec4d>(m, "Vec4d")
.def(py::init<>())
.def(py::init<double>())
.def(py::init<double, double, double, double>())
.def(py::self == py::self)
.def(py::self != py::self);

py::class_<openvdb::PointDataIndex32>(m, "PointDataIndex32")
.def(py::init<openvdb::Index32>(), py::arg("i") = openvdb::Index32(0));

Expand Down
19 changes: 3 additions & 16 deletions openvdb/openvdb/python/pyTransform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,8 @@ struct PickleSuite

// Construct a state tuple comprising the version numbers of
// the serialization format and the serialized Transform.
#if PY_MAJOR_VERSION >= 3
// Convert the byte string to a "bytes" sequence.
py::bytes bytesObj(ostr.str());
#else
py::str bytesObj(ostr.str());
#endif
return py::make_tuple(
uint32_t(OPENVDB_LIBRARY_MAJOR_VERSION),
uint32_t(OPENVDB_LIBRARY_MINOR_VERSION),
Expand All @@ -125,7 +121,7 @@ struct PickleSuite
uint32_t version[3] = { 0, 0, 0 };
for (int i = 0; i < 3 && !badState; ++i) {
if (py::isinstance<py::int_>(state[idx[i]]))
version[i] = state[idx[i]].cast<uint32_t>();
version[i] = py::cast<uint32_t>(state[idx[i]]);
else badState = true;
}
libVersion.first = version[0];
Expand All @@ -137,24 +133,15 @@ struct PickleSuite
if (!badState) {
// Extract the sequence containing the serialized Transform.
py::object bytesObj = state[int(STATE_XFORM)];
#if PY_MAJOR_VERSION >= 3
if (py::isinstance<py::bytes>(bytesObj))
serialized = bytesObj.cast<py::bytes>();
#else
if (py::isinstance<std::string>(bytesObj))
serialized = bytesObj.cast<std::string>();
#endif
serialized = py::cast<py::bytes>(bytesObj);
else badState = true;
}

if (badState) {
std::ostringstream os;
#if PY_MAJOR_VERSION >= 3
os << "expected (int, int, int, bytes) tuple in call to __setstate__; found ";
#else
os << "expected (int, int, int, str) tuple in call to __setstate__; found ";
#endif
os << state.attr("__repr__")().cast<std::string>();
os << py::cast<std::string>(state.attr("__repr__")());
throw py::value_error(os.str());
}

Expand Down
Loading

0 comments on commit fc4a559

Please sign in to comment.