Skip to content

Commit

Permalink
Update benchmark Python bindings for nanobind 2.0, and update to nano…
Browse files Browse the repository at this point in the history
…bind 2.0.

Incorporates the nanobind_bazel change from google#1795.

nanobind 2.0 reworked the nanobind::enum_ class so it uses a real Python enum or intenum rather than its previous hand-rolled implementation.
google-deepmind/clrs#119 (comment)

As a consequence of that change, nanobind now checks when casting an integer to a enum value that the integer corresponds to a valid enum. Counter::Flags is a bitmask, and many combinations are not valid enum members.

This change:
a) sets nb::is_arithmetic(), which means Counter::Flags becomes an IntEnum that can be freely cast to an integer.
b) defines the | operator for flags to return an integer, not an enum, avoiding the error.
c) changes Counter's constructor to accept an int, not a Counter::Flags enum. Since Counter::Flags is an IntEnum now, it can be freely coerced to an int.

If wjakob/nanobind#599 is merged into nanobind, then we can perhaps use a flag enum here instead.
  • Loading branch information
hawkinsp committed Jul 18, 2024
1 parent a6ad7fb commit 2542240
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
2 changes: 1 addition & 1 deletion MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ use_repo(pip, "tools_pip_deps")

# -- bazel_dep definitions -- #

bazel_dep(name = "nanobind_bazel", version = "1.0.0", dev_dependency = True)
bazel_dep(name = "nanobind_bazel", version = "2.0.0", dev_dependency = True)
19 changes: 13 additions & 6 deletions bindings/python/google_benchmark/benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ NB_MODULE(_benchmark, m) {
using benchmark::Counter;
nb::class_<Counter> py_counter(m, "Counter");

nb::enum_<Counter::Flags>(py_counter, "Flags")
nb::enum_<Counter::Flags>(py_counter, "Flags", nb::is_arithmetic())
.value("kDefaults", Counter::Flags::kDefaults)
.value("kIsRate", Counter::Flags::kIsRate)
.value("kAvgThreads", Counter::Flags::kAvgThreads)
Expand All @@ -130,18 +130,25 @@ NB_MODULE(_benchmark, m) {
.value("kAvgIterationsRate", Counter::Flags::kAvgIterationsRate)
.value("kInvert", Counter::Flags::kInvert)
.export_values()
.def(nb::self | nb::self);
.def("__or__", [](Counter::Flags a, Counter::Flags b) {
return static_cast<int>(a) | static_cast<int>(b);
});

nb::enum_<Counter::OneK>(py_counter, "OneK")
.value("kIs1000", Counter::OneK::kIs1000)
.value("kIs1024", Counter::OneK::kIs1024)
.export_values();

py_counter
.def(nb::init<double, Counter::Flags, Counter::OneK>(),
nb::arg("value") = 0., nb::arg("flags") = Counter::kDefaults,
nb::arg("k") = Counter::kIs1000)
.def("__init__", ([](Counter *c, double value) { new (c) Counter(value); }))
.def(
"__init__",
[](Counter* c, double value, int flags, Counter::OneK oneK) {
new (c) Counter(value, static_cast<Counter::Flags>(flags), oneK);
},
nb::arg("value") = 0., nb::arg("flags") = Counter::kDefaults,
nb::arg("k") = Counter::kIs1000)
.def("__init__",
([](Counter* c, double value) { new (c) Counter(value); }))
.def_rw("value", &Counter::value)
.def_rw("flags", &Counter::flags)
.def_rw("oneK", &Counter::oneK)
Expand Down

0 comments on commit 2542240

Please sign in to comment.