From 2542240eb8a304914e1d3d534b259416e728ab1d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 18 Jul 2024 14:58:04 +0000 Subject: [PATCH] Update benchmark Python bindings for nanobind 2.0, and update to nanobind 2.0. Incorporates the nanobind_bazel change from https://github.com/google/benchmark/pull/1795. nanobind 2.0 reworked the nanobind::enum_ class so it uses a real Python enum or intenum rather than its previous hand-rolled implementation. https://github.com/google-deepmind/clrs/pull/119#issuecomment-1883834196 As a consequence of that change, nanobind now checks when casting an integer to a enum value that the integer corresponds to a valid enum. Counter::Flags is a bitmask, and many combinations are not valid enum members. This change: a) sets nb::is_arithmetic(), which means Counter::Flags becomes an IntEnum that can be freely cast to an integer. b) defines the | operator for flags to return an integer, not an enum, avoiding the error. c) changes Counter's constructor to accept an int, not a Counter::Flags enum. Since Counter::Flags is an IntEnum now, it can be freely coerced to an int. If https://github.com/wjakob/nanobind/pull/599 is merged into nanobind, then we can perhaps use a flag enum here instead. --- MODULE.bazel | 2 +- bindings/python/google_benchmark/benchmark.cc | 19 +++++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/MODULE.bazel b/MODULE.bazel index 8b98a7a027..b86da20adb 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -38,4 +38,4 @@ use_repo(pip, "tools_pip_deps") # -- bazel_dep definitions -- # -bazel_dep(name = "nanobind_bazel", version = "1.0.0", dev_dependency = True) +bazel_dep(name = "nanobind_bazel", version = "2.0.0", dev_dependency = True) \ No newline at end of file diff --git a/bindings/python/google_benchmark/benchmark.cc b/bindings/python/google_benchmark/benchmark.cc index f44476901c..64ffb92b48 100644 --- a/bindings/python/google_benchmark/benchmark.cc +++ b/bindings/python/google_benchmark/benchmark.cc @@ -118,7 +118,7 @@ NB_MODULE(_benchmark, m) { using benchmark::Counter; nb::class_ py_counter(m, "Counter"); - nb::enum_(py_counter, "Flags") + nb::enum_(py_counter, "Flags", nb::is_arithmetic()) .value("kDefaults", Counter::Flags::kDefaults) .value("kIsRate", Counter::Flags::kIsRate) .value("kAvgThreads", Counter::Flags::kAvgThreads) @@ -130,7 +130,9 @@ NB_MODULE(_benchmark, m) { .value("kAvgIterationsRate", Counter::Flags::kAvgIterationsRate) .value("kInvert", Counter::Flags::kInvert) .export_values() - .def(nb::self | nb::self); + .def("__or__", [](Counter::Flags a, Counter::Flags b) { + return static_cast(a) | static_cast(b); + }); nb::enum_(py_counter, "OneK") .value("kIs1000", Counter::OneK::kIs1000) @@ -138,10 +140,15 @@ NB_MODULE(_benchmark, m) { .export_values(); py_counter - .def(nb::init(), - nb::arg("value") = 0., nb::arg("flags") = Counter::kDefaults, - nb::arg("k") = Counter::kIs1000) - .def("__init__", ([](Counter *c, double value) { new (c) Counter(value); })) + .def( + "__init__", + [](Counter* c, double value, int flags, Counter::OneK oneK) { + new (c) Counter(value, static_cast(flags), oneK); + }, + nb::arg("value") = 0., nb::arg("flags") = Counter::kDefaults, + nb::arg("k") = Counter::kIs1000) + .def("__init__", + ([](Counter* c, double value) { new (c) Counter(value); })) .def_rw("value", &Counter::value) .def_rw("flags", &Counter::flags) .def_rw("oneK", &Counter::oneK)