Skip to content

Commit

Permalink
Refactoring to separate cython files
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Aug 9, 2024
1 parent 9810ede commit 0246b6a
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 133 deletions.
3 changes: 3 additions & 0 deletions treeple/tree/_honest_prune.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ cdef class HonestPruner(Splitter):
cdef intp_t pos # The current position to split left/right children
cdef intp_t n_missing # The number of missing values in the feature currently considered
cdef uint8_t missing_go_to_left

# TODO: only supports dense X for now (sparse input is not handled yet).
cdef const float32_t[:, :] X

cdef int init(
Expand All @@ -32,6 +34,7 @@ cdef class HonestPruner(Splitter):
const uint8_t[::1] missing_values_in_feature_mask,
) except -1

# This function is not used, and should be disabled for pruners
cdef int node_split(
self,
ParentInfo* parent_record,
Expand Down
145 changes: 12 additions & 133 deletions treeple/tree/_honest_prune.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ from libc.stdlib cimport free, malloc
from libcpp.stack cimport stack
from sklearn.tree._tree cimport ParentInfo

from ._prune cimport _build_pruned_tree

# Sentinel node ids exposed at Python level.
# NOTE(review): presumably these mirror the constants of the same name in
# scikit-learn's tree module -- confirm.
TREE_LEAF = -1
TREE_UNDEFINED = -2
# C-typed copy of TREE_LEAF so that it can be used in ``nogil`` sections.
cdef intp_t _TREE_LEAF = TREE_LEAF
Expand Down Expand Up @@ -265,7 +267,7 @@ cdef class HonestPruner(Splitter):
Returns 0 if a split cannot be done, 1 if a split can be done
and -1 in case of failure to allocate memory (and raise MemoryError).
"""
pass
raise NotImplementedError("node_split is not used in honest pruning")


cdef _honest_prune(
Expand Down Expand Up @@ -327,10 +329,6 @@ cdef _honest_prune(
float64_t lower_bound, upper_bound
float64_t left_child_min, left_child_max, right_child_min, right_child_max, middle_value

cdef bint first = 0
cdef ParentInfo parent_record
_init_parent_record(&parent_record)

# find parent node ids and leaves
with nogil:
# Push the root node
Expand All @@ -352,9 +350,7 @@ cdef _honest_prune(
pruning_stack.pop()
start = stack_record.start
end = stack_record.end
parent_record.impurity = stack_record.impurity
parent_record.lower_bound = stack_record.lower_bound
parent_record.upper_bound = stack_record.upper_bound
impurity = stack_record.impurity
lower_bound = stack_record.lower_bound
upper_bound = stack_record.upper_bound

Expand All @@ -366,7 +362,7 @@ cdef _honest_prune(

# get the impurity to initialize passing into its children
if first:
parent_record.impurity = pruner.node_impurity()
impurity = pruner.node_impurity()
first = 0

# partition samples into left/right child based on the
Expand All @@ -377,6 +373,7 @@ cdef _honest_prune(
split_ptr.feature = orig_tree.nodes[node_idx].feature
invalid_split = pruner.check_node_partition_conditions(
split_ptr,
impurity,
lower_bound,
upper_bound
)
Expand All @@ -402,12 +399,12 @@ cdef _honest_prune(
# Current bounds must always be propagated to both children.
# If a monotonic constraint is active, bounds are used in
# node value clipping.
left_child_min = right_child_min = parent_record.lower_bound
left_child_max = right_child_max = parent_record.upper_bound
left_child_min = right_child_min = lower_bound
left_child_max = right_child_max = upper_bound
elif pruner.monotonic_cst[split_ptr.feature] == 1:
# Split on a feature with monotonic increase constraint
left_child_min = parent_record.lower_bound
right_child_max = parent_record.upper_bound
left_child_min = lower_bound
right_child_max = upper_bound

# Lower bound for right child and upper bound for left child
# are set to the same value.
Expand All @@ -416,8 +413,8 @@ cdef _honest_prune(
left_child_max = middle_value
else: # i.e. pruner.monotonic_cst[split.feature] == -1
# Split on a feature with monotonic decrease constraint
right_child_min = parent_record.lower_bound
left_child_max = parent_record.upper_bound
right_child_min = lower_bound
left_child_max = upper_bound

# Lower bound for left child and upper bound for right child
# are set to the same value.
Expand All @@ -444,121 +441,3 @@ cdef _honest_prune(

# free the memory created for the SplitRecord pointer
free(split_ptr)


from libc.stdint cimport INTPTR_MAX
from libc.string cimport memcpy


# Work item kept on the explicit stack used by ``_build_pruned_tree``.
cdef struct BuildPrunedRecord:
    intp_t start    # id of the node in the original tree that is to be copied
    intp_t depth    # depth of that node (the root is pushed with depth 0)
    intp_t parent   # id of the parent node in the pruned tree
    bint is_left    # 1 if the node is the left child of ``parent``, else 0


cdef _build_pruned_tree(
    Tree tree,  # OUT
    Tree orig_tree,
    const uint8_t[:] leaves_in_subtree,
    intp_t capacity
):
    """Build a pruned tree.

    Build a pruned tree from the original tree by transforming the nodes in
    ``leaves_in_subtree`` into leaves. The original tree is traversed from
    node 0 with an explicit stack; each visited node and its value block are
    copied into ``tree``, and the children of a node marked as a leaf are not
    visited.

    Parameters
    ----------
    tree : Tree
        Location to place the pruned tree
    orig_tree : Tree
        Original tree
    leaves_in_subtree : uint8_t memoryview, shape=(node_count, )
        Boolean mask for leaves to include in subtree
    capacity : intp_t
        Number of nodes to initially allocate in pruned tree

    Raises
    ------
    ValueError
        If a node whose children are both ``_TREE_LEAF`` in the original tree
        is not marked in ``leaves_in_subtree``.
    MemoryError
        If ``tree._add_node`` returns ``INTPTR_MAX``.
    """
    tree._resize(capacity)

    cdef:
        intp_t orig_node_id
        intp_t new_node_id
        intp_t depth
        intp_t parent
        bint is_left
        bint is_leaf

        # value_stride for original tree and new tree are the same
        intp_t value_stride = orig_tree.value_stride
        intp_t max_depth_seen = -1
        # stays 0 unless adding a node fails, in which case it is set to -1
        intp_t rc = 0
        Node* node
        float64_t* orig_value_ptr
        float64_t* new_value_ptr

        stack[BuildPrunedRecord] prune_stack
        BuildPrunedRecord stack_record

        SplitRecord split

    with nogil:
        # push root node onto stack
        prune_stack.push({"start": 0, "depth": 0, "parent": _TREE_UNDEFINED, "is_left": 0})

        while not prune_stack.empty():
            stack_record = prune_stack.top()
            prune_stack.pop()

            orig_node_id = stack_record.start
            depth = stack_record.depth
            parent = stack_record.parent
            is_left = stack_record.is_left

            is_leaf = leaves_in_subtree[orig_node_id]
            node = &orig_tree.nodes[orig_node_id]

            # redefine to a SplitRecord to pass into _add_node
            # (only ``feature`` and ``threshold`` are filled in here)
            split.feature = node.feature
            split.threshold = node.threshold

            # protect against an infinite loop as a runtime error, when leaves_in_subtree
            # are improperly set where a node is not marked as a leaf, but is a node
            # in the original tree. Thus, it violates the assumption that the node
            # is a leaf in the pruned tree, or has a descendant that will be pruned.
            if (not is_leaf and node.left_child == _TREE_LEAF
                    and node.right_child == _TREE_LEAF):
                raise ValueError(
                    "Node has reached a leaf in the original tree, but is not "
                    "marked as a leaf in the leaves_in_subtree mask."
                )

            new_node_id = tree._add_node(
                parent, is_left, is_leaf, &split,
                node.impurity, node.n_node_samples,
                node.weighted_n_node_samples, node.missing_go_to_left)

            # INTPTR_MAX is treated as the failure value of _add_node
            if new_node_id == INTPTR_MAX:
                rc = -1
                break

            # copy value from original tree to new tree
            orig_value_ptr = orig_tree.value + value_stride * orig_node_id
            new_value_ptr = tree.value + value_stride * new_node_id
            memcpy(new_value_ptr, orig_value_ptr, sizeof(float64_t) * value_stride)

            if not is_leaf:
                # Push right child on stack
                prune_stack.push({"start": node.right_child, "depth": depth + 1,
                                  "parent": new_node_id, "is_left": 0})
                # push left child on stack
                # (pushed last, so it is the next record to be popped)
                prune_stack.push({"start": node.left_child, "depth": depth + 1,
                                  "parent": new_node_id, "is_left": 1})

            if depth > max_depth_seen:
                max_depth_seen = depth

        # only record the depth when every node was added successfully
        if rc >= 0:
            tree.max_depth = max_depth_seen
    if rc == -1:
        raise MemoryError("pruning tree")
19 changes: 19 additions & 0 deletions treeple/tree/_prune.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copied from scikit-learn (sklearn/tree/_tree.pyx)

from libc.stdint cimport INTPTR_MAX
from libc.string cimport memcpy


# Work item kept on the explicit stack used by ``_build_pruned_tree``.
cdef struct BuildPrunedRecord:
    intp_t start    # id of the node in the original tree that is to be copied
    intp_t depth    # depth of that node (the root is pushed with depth 0)
    intp_t parent   # id of the parent node in the pruned tree
    bint is_left    # 1 if the node is the left child of ``parent``, else 0


# Build ``tree`` as a pruned copy of ``orig_tree``: nodes flagged in
# ``leaves_in_subtree`` become leaves; ``capacity`` is the initial node
# allocation of the pruned tree.
#
# NOTE(review): this is declared ``void ... noexcept`` while the
# implementation in _prune.pyx raises ValueError / MemoryError. Per the Cython
# documentation, an exception raised inside a ``noexcept`` function is
# reported as a warning and then discarded, so callers would never observe
# these errors. Consider ``cdef int _build_pruned_tree(...) except -1`` here
# and in the .pyx (both must change together) -- confirm intended behavior.
cdef void _build_pruned_tree(
    Tree tree,  # OUT
    Tree orig_tree,
    const uint8_t[:] leaves_in_subtree,
    intp_t capacity
) noexcept
109 changes: 109 additions & 0 deletions treeple/tree/_prune.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# cython: boundscheck=False
# cython: wraparound=False
# cython: initializedcheck=False

cdef void _build_pruned_tree(
    Tree tree,  # OUT
    Tree orig_tree,
    const uint8_t[:] leaves_in_subtree,
    intp_t capacity
) noexcept:
    """Copy ``orig_tree`` into ``tree``, cutting it off at the masked nodes.

    The original tree is walked from node 0 with an explicit stack. Every
    visited node is appended to ``tree`` together with its value block. A node
    flagged in ``leaves_in_subtree`` is added as a leaf and its children are
    never visited.

    Parameters
    ----------
    tree : Tree
        Output tree that receives the pruned copy.
    orig_tree : Tree
        Tree whose nodes and values are copied.
    leaves_in_subtree : uint8_t memoryview, shape=(node_count, )
        Mask indexed by original node id; a non-zero entry marks a node that
        becomes a leaf of the pruned tree.
    capacity : intp_t
        Node capacity handed to ``tree._resize`` before copying starts.

    Notes
    -----
    NOTE(review): the body raises ``ValueError`` (inconsistent mask) and
    ``MemoryError`` (``_add_node`` returned ``INTPTR_MAX``), but the function
    is declared ``void ... noexcept``. Per the Cython documentation such
    exceptions are reported as a warning and discarded instead of propagating
    to the caller. Consider switching this definition and its .pxd declaration
    to ``cdef int ... except -1`` -- confirm intended behavior.
    """
    tree._resize(capacity)

    cdef:
        intp_t src_id
        intp_t dst_id
        bint make_leaf

        # both trees use the same number of values per node
        intp_t value_stride = orig_tree.value_stride
        intp_t deepest = -1
        intp_t status = 0
        Node* src_node

        stack[BuildPrunedRecord] pending
        BuildPrunedRecord item
        BuildPrunedRecord child

        SplitRecord split_record

    # the traversal starts at the root of the original tree
    item.start = 0
    item.depth = 0
    item.parent = _TREE_UNDEFINED
    item.is_left = 0

    with nogil:
        pending.push(item)

        while not pending.empty():
            item = pending.top()
            pending.pop()

            src_id = item.start
            src_node = &orig_tree.nodes[src_id]
            make_leaf = leaves_in_subtree[src_id]

            # _add_node takes a SplitRecord; only these two fields are set
            split_record.feature = src_node.feature
            split_record.threshold = src_node.threshold

            # A node without children in the original tree must be flagged as
            # a leaf in the mask; otherwise the mask is inconsistent with the
            # tree and the traversal cannot proceed sensibly.
            if not make_leaf:
                if (src_node.left_child == _TREE_LEAF
                        and src_node.right_child == _TREE_LEAF):
                    raise ValueError(
                        "Node has reached a leaf in the original tree, but is not "
                        "marked as a leaf in the leaves_in_subtree mask."
                    )

            dst_id = tree._add_node(
                item.parent, item.is_left, make_leaf, &split_record,
                src_node.impurity, src_node.n_node_samples,
                src_node.weighted_n_node_samples, src_node.missing_go_to_left)

            # INTPTR_MAX is treated as the failure value of _add_node
            if dst_id == INTPTR_MAX:
                status = -1
                break

            # transfer the node's value block from the original to the new tree
            memcpy(
                tree.value + value_stride * dst_id,
                orig_tree.value + value_stride * src_id,
                sizeof(float64_t) * value_stride,
            )

            if not make_leaf:
                child.depth = item.depth + 1
                child.parent = dst_id

                # right child goes on first ...
                child.start = src_node.right_child
                child.is_left = 0
                pending.push(child)

                # ... so that the left child is the next record popped
                child.start = src_node.left_child
                child.is_left = 1
                pending.push(child)

            if item.depth > deepest:
                deepest = item.depth

    if status == -1:
        raise MemoryError("pruning tree")
    # reached only when every node was added successfully
    tree.max_depth = deepest
3 changes: 3 additions & 0 deletions treeple/tree/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ tree_extension_metadata = {
'_honest_prune':
{'sources': ['_honest_prune.pyx'],
'override_options': ['cython_language=cpp', 'optimization=3']},
'_prune':
{'sources': ['_prune.pyx'],
'override_options': ['cython_language=cpp', 'optimization=3']},
'_marginal':
{'sources': ['_marginal.pyx'],
'override_options': ['cython_language=cpp', 'optimization=3']},
Expand Down

0 comments on commit 0246b6a

Please sign in to comment.