diff --git a/pycolmap/geometry/bindings.h b/pycolmap/geometry/bindings.h index 85f3d18..b3a8e2e 100644 --- a/pycolmap/geometry/bindings.h +++ b/pycolmap/geometry/bindings.h @@ -5,6 +5,7 @@ #include "pycolmap/geometry/homography_matrix.h" #include "pycolmap/helpers.h" +#include "pycolmap/pybind11_extension.h" #include @@ -20,7 +21,7 @@ using namespace pybind11::literals; void BindGeometry(py::module& m) { BindHomographyGeometry(m); - py::class_ PyRotation3d(m, "Rotation3d"); + py::class_ext_ PyRotation3d(m, "Rotation3d"); PyRotation3d.def(py::init([]() { return Eigen::Quaterniond::Identity(); })) .def(py::init(), "xyzw"_a, @@ -55,7 +56,7 @@ void BindGeometry(py::module& m) { py::implicitly_convertible(); MakeDataclass(PyRotation3d); - py::class_ PyRigid3d(m, "Rigid3d"); + py::class_ext_ PyRigid3d(m, "Rigid3d"); PyRigid3d.def(py::init<>()) .def(py::init()) .def(py::init([](const Eigen::Matrix3x4d& matrix) { @@ -87,7 +88,7 @@ void BindGeometry(py::module& m) { py::implicitly_convertible(); MakeDataclass(PyRigid3d); - py::class_ PySim3d(m, "Sim3d"); + py::class_ext_ PySim3d(m, "Sim3d"); PySim3d.def(py::init<>()) .def( py::init()) diff --git a/pycolmap/pybind11_extension.h b/pycolmap/pybind11_extension.h index a96d0c9..d7516c7 100644 --- a/pycolmap/pybind11_extension.h +++ b/pycolmap/pybind11_extension.h @@ -96,6 +96,38 @@ struct type_caster>> { } // namespace detail +template +class class_ext_ : public class_ { + public: + using Parent = class_; + using Parent::class_; // inherit constructors + using type = type_; + + template + class_ext_& def_readwrite(const char* name, D C::*pm, const Extra&... extra) { + static_assert( + std::is_same::value || std::is_base_of::value, + "def_readwrite() requires a class member (or base class member)"); + cpp_function fget([pm](type&c) -> D& { return c.*pm; }, is_method(*this)), + fset([pm](type&c, const D&value) { c.*pm = value; }, is_method(*this)); + this->def_property( + name, fget, fset, return_value_policy::reference_internal, extra...); + return *this; + } + + template + class_ext_& def(Args&&... args) { + Parent::def(std::forward(args)...); + return *this; + } + + template + class_ext_& def_property(Args&&... args) { + Parent::def_property(std::forward(args)...); + return *this; + } +}; + // Fix long-standing bug https://github.com/pybind/pybind11/issues/4529 // TODO(sarlinpe): remove when https://github.com/pybind/pybind11/pull/4972 // appears in the next release of pybind11. diff --git a/pycolmap/scene/point2D.h b/pycolmap/scene/point2D.h index 97917cf..909632f 100644 --- a/pycolmap/scene/point2D.h +++ b/pycolmap/scene/point2D.h @@ -45,7 +45,7 @@ void BindPoint2D(py::module& m) { return repr; }); - py::class_> PyPoint2D(m, "Point2D"); + py::class_ext_> PyPoint2D(m, "Point2D"); PyPoint2D.def(py::init<>()) .def(py::init(), "xy"_a, diff --git a/pycolmap/scene/point3D.h b/pycolmap/scene/point3D.h index acf4fd5..7dd8ef1 100644 --- a/pycolmap/scene/point3D.h +++ b/pycolmap/scene/point3D.h @@ -26,7 +26,7 @@ void BindPoint3D(py::module& m) { std::to_string(self.size()) + ")"; }); - py::class_> PyPoint3D(m, "Point3D"); + py::class_ext_> PyPoint3D(m, "Point3D"); PyPoint3D.def(py::init<>()) .def_readwrite("xyz", &Point3D::xyz) .def_readwrite("color", &Point3D::color)