Skip to content

Commit

Permalink
Adding workflow to test against main
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Aug 8, 2024
1 parent dbb8044 commit 15f6377
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ jobs:
pip install -r build_requirements.txt
pip install -r test_requirements.txt
- name: Install nightly wheels for scikit-learn (only for ubuntu 3.12)
if: ${{ matrix.python-version == '3.12' }} && ${{ matrix.os == 'ubuntu-latest' }}
run: |
pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple scikit-learn --force
- name: Prepare compiler cache
id: prep-ccache
shell: bash
Expand Down
3 changes: 3 additions & 0 deletions treeple/_lib/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ tree_extension_metadata = {
'_tree':
{'sources': ['./sklearn/tree/' + '_tree.pyx'],
'override_options': ['cython_language=cpp', 'optimization=3']},
'_partitioner':
{'sources': ['./sklearn/tree/' + '_partitioner.pyx'],
'override_options': ['cython_language=cpp', 'optimization=3']},
'_splitter':
{'sources': ['./sklearn/tree/' + '_splitter.pyx'],
'override_options': ['cython_language=cpp', 'optimization=3']},
Expand Down
9 changes: 2 additions & 7 deletions treeple/tree/_honest_prune.pxd
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
# from .._lib.sklearn.tree._tree import _build_pruned_tree

from .._lib.sklearn.tree._criterion cimport Criterion
from .._lib.sklearn.tree._splitter cimport (
SplitRecord,
Splitter,
shift_missing_values_to_left_if_required,
)
from .._lib.sklearn.tree._partitioner cimport shift_missing_values_to_left_if_required
from .._lib.sklearn.tree._splitter cimport SplitRecord, Splitter
from .._lib.sklearn.tree._tree cimport Node, ParentInfo, Tree
from .._lib.sklearn.utils._typedefs cimport float32_t, float64_t, int8_t, intp_t, uint8_t, uint32_t

Expand Down
3 changes: 0 additions & 3 deletions treeple/tree/_honest_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,8 +689,6 @@ def _fit_leaves(self, X, y, sample_weight):
y = y_encoded
self.n_classes_ = np.array(self.n_classes_, dtype=np.intp)

# XXX: implement honest pruning
print(self.honest_method)
if self.honest_method == "apply":
# Fit leaves using other subsample
honest_leaves = self.tree_.apply(X[self.honest_indices_])
Expand Down Expand Up @@ -744,7 +742,6 @@ def _fit_leaves(self, X, y, sample_weight):
pruned_tree, self.tree_, pruner, X, y, sample_weight, missing_values_in_feature_mask
)
self.tree_ = pruned_tree
# raise NotImplementedError("Pruning is not yet implemented.")

if self.n_outputs_ == 1:
self.n_classes_ = self.n_classes_[0]
Expand Down

0 comments on commit 15f6377

Please sign in to comment.