From 4444d04ff06468708caadef895b537fe8327debb Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 23 Aug 2023 10:04:29 -0400 Subject: [PATCH 01/11] ENH optimize sktree imports --- sktree/ensemble/_honest_forest.py | 2 +- sktree/ensemble/_supervised_forest.py | 4 ++-- sktree/ensemble/_unsupervised_forest.py | 6 +++--- sktree/tree/_honest_tree.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sktree/ensemble/_honest_forest.py b/sktree/ensemble/_honest_forest.py index 050cb906e..f46daf52d 100644 --- a/sktree/ensemble/_honest_forest.py +++ b/sktree/ensemble/_honest_forest.py @@ -8,7 +8,7 @@ from sklearn.ensemble._base import _partition_estimators from sklearn.utils.validation import check_is_fitted, check_X_y -from sktree._lib.sklearn.ensemble._forest import ForestClassifier +from .._lib.sklearn.ensemble._forest import ForestClassifier from .._lib.sklearn.tree import _tree as _sklearn_tree from ..tree import HonestTreeClassifier diff --git a/sktree/ensemble/_supervised_forest.py b/sktree/ensemble/_supervised_forest.py index fa04f0f06..70687cef4 100644 --- a/sktree/ensemble/_supervised_forest.py +++ b/sktree/ensemble/_supervised_forest.py @@ -1,7 +1,7 @@ from sklearn.utils._param_validation import StrOptions -from sktree._lib.sklearn.ensemble._forest import ForestClassifier, ForestRegressor -from sktree.tree import ( +from .._lib.sklearn.ensemble._forest import ForestClassifier, ForestRegressor +from ..tree import ( ObliqueDecisionTreeClassifier, ObliqueDecisionTreeRegressor, PatchObliqueDecisionTreeClassifier, diff --git a/sktree/ensemble/_unsupervised_forest.py b/sktree/ensemble/_unsupervised_forest.py index e369d57b5..d84be0aa6 100644 --- a/sktree/ensemble/_unsupervised_forest.py +++ b/sktree/ensemble/_unsupervised_forest.py @@ -23,9 +23,9 @@ from sklearn.utils.parallel import Parallel, delayed from sklearn.utils.validation import _check_sample_weight, check_is_fitted, check_random_state -from sktree._lib.sklearn.ensemble._forest import BaseForest -from sktree._lib.sklearn.tree._tree import DTYPE -from sktree.tree import UnsupervisedDecisionTree, UnsupervisedObliqueDecisionTree +from .._lib.sklearn.ensemble._forest import BaseForest +from .._lib.sklearn.tree._tree import DTYPE +from ..tree import UnsupervisedDecisionTree, UnsupervisedObliqueDecisionTree from ..tree._neighbors import SimMatrixMixin diff --git a/sktree/tree/_honest_tree.py b/sktree/tree/_honest_tree.py index b8991230e..4ffa72c0f 100644 --- a/sktree/tree/_honest_tree.py +++ b/sktree/tree/_honest_tree.py @@ -8,8 +8,8 @@ from sklearn.utils.multiclass import _check_partial_fit_first_call, check_classification_targets from sklearn.utils.validation import check_is_fitted, check_X_y -from sktree._lib.sklearn.tree import DecisionTreeClassifier -from sktree._lib.sklearn.tree._classes import BaseDecisionTree +from .._lib.sklearn.tree import DecisionTreeClassifier +from .._lib.sklearn.tree._classes import BaseDecisionTree class HonestTreeClassifier(MetaEstimatorMixin, ClassifierMixin, BaseDecisionTree): From be877d2c9e3aa0883d6a54e646e8fef8487369e7 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 23 Aug 2023 12:50:01 -0400 Subject: [PATCH 02/11] ENH add SDF class for streaming trees --- sktree/experimental/__init__.py | 3 +- sktree/experimental/meson.build | 3 +- sktree/experimental/sdf.py | 264 ++++++++++++++++++++++++++++++++ 3 files changed, 268 insertions(+), 2 deletions(-) create mode 100644 sktree/experimental/sdf.py diff --git a/sktree/experimental/__init__.py b/sktree/experimental/__init__.py index bf88ce488..67ae6ba18 100644 --- a/sktree/experimental/__init__.py +++ b/sktree/experimental/__init__.py @@ -1,4 +1,4 @@ -from . import mutual_info, simulate +from . import mutual_info, simulate, sdf from .mutual_info import ( cmi_from_entropy, cmi_gaussian, @@ -9,3 +9,4 @@ mi_gaussian, mutual_info_ksg, ) +from .sdf import StreamDecisionForest diff --git a/sktree/experimental/meson.build b/sktree/experimental/meson.build index e92935a50..2be7ae9e8 100644 --- a/sktree/experimental/meson.build +++ b/sktree/experimental/meson.build @@ -2,6 +2,7 @@ python_sources = [ '__init__.py', 'mutual_info.py', 'simulate.py', + 'sdf.py', ] py3.install_sources( @@ -10,4 +11,4 @@ py3.install_sources( subdir: 'sktree/experimental' ) -subdir('tests') \ No newline at end of file +subdir('tests') diff --git a/sktree/experimental/sdf.py b/sktree/experimental/sdf.py new file mode 100644 index 000000000..029631488 --- /dev/null +++ b/sktree/experimental/sdf.py @@ -0,0 +1,264 @@ +""" +Main Author: Haoyin Xu +Corresponding Email: haoyinxu@gmail.com +""" +# import the necessary packages +import numpy as np +from joblib import Parallel, delayed + +from .._lib.sklearn.ensemble._forest import ( + RandomForestClassifier, + _generate_sample_indices, + _get_n_samples_bootstrap, +) +from sklearn.metrics import accuracy_score +from .._lib.sklearn.tree import DecisionTreeClassifier +from sklearn.utils.multiclass import _check_partial_fit_first_call + + +def _partial_fit(tree, X, y, n_samples_bootstrap, classes): + """Internal function to partially fit a tree.""" + indices = _generate_sample_indices(tree.random_state, X.shape[0], n_samples_bootstrap) + tree.partial_fit(X[indices, :], y[indices], classes=classes) + + return tree + + +class StreamDecisionForest(RandomForestClassifier): + """ + A class used to represent a naive ensemble of + random stream decision trees. + + Parameters + ---------- + n_estimators : int, default=100 + An integer that represents the number of stream decision trees. + + splitter : {"best", "random"}, default="best" + The strategy used to choose the split at each node. Supported + strategies are "best" to choose the best split and "random" to choose + the best random split. + + max_features : {"sqrt", "log2"}, int or float, default="sqrt" + The number of features to consider when looking for the best split: + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a fraction and + `round(max_features * n_features)` features are considered at each + split. + - If "sqrt", then `max_features=sqrt(n_features)`. + - If "log2", then `max_features=log2(n_features)`. + - If None, then `max_features=n_features`. + + bootstrap : bool, default=True + Whether bootstrap samples are used when building trees. If False, the + whole dataset is used to build each tree. + + n_jobs : int, default=None + The number of jobs to run in parallel. + + max_samples : int or float, default=None + If bootstrap is True, the number of samples to draw from X + to train each base estimator. + - If None (default), then draw `X.shape[0]` samples. + - If int, then draw `max_samples` samples. + - If float, then draw `max_samples * X.shape[0]` samples. Thus, + `max_samples` should be in the interval `(0.0, 1.0]`. + + n_swaps : int, default=1 + The number of trees to swap at each partial fitting. The actual + swaps occur with `1/n_batches_` probability. + + Attributes + ---------- + estimators_ : list of sklearn.tree.DecisionTreeClassifier + An internal list that contains all + sklearn.tree.DecisionTreeClassifier. + + classes_ : list of all unique class labels + An internal list that stores class labels after the first call + to `partial_fit`. + + n_batches_ : int + The number of batches seen with `partial_fit`. + """ + + def __init__( + self, + n_estimators=100, + *, + criterion="gini", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features="sqrt", + max_leaf_nodes=None, + min_impurity_decrease=0.0, + bootstrap=True, + oob_score=False, + n_jobs=None, + random_state=None, + verbose=0, + warm_start=False, + class_weight=None, + ccp_alpha=0.0, + max_samples=None, + max_bins=None, + store_leaf_values=False, + monotonic_cst=None, + n_swaps=1, + ): + + super().__init__( + n_estimators=n_estimators, + criterion=criterion, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + min_weight_fraction_leaf=min_weight_fraction_leaf, + max_features=max_features, + max_leaf_nodes=max_leaf_nodes, + min_impurity_decrease=min_impurity_decrease, + ccp_alpha=ccp_alpha, + monotonic_cst=monotonic_cst, + bootstrap=bootstrap, + oob_score=oob_score, + n_jobs=n_jobs, + random_state=random_state, + verbose=verbose, + warm_start=warm_start, + class_weight=class_weight, + max_samples=max_samples, + max_bins=max_bins, + store_leaf_values=store_leaf_values, + ) + self.n_batches_ = 0 + self.n_swaps = n_swaps + + def fit(self, X, y, classes=None): + """ + Partially fits the forest to data X with labels y. + + Parameters + ---------- + X : ndarray + Input data matrix. + + y : ndarray + Output (i.e. response data matrix). + + classes : ndarray, default=None + List of all the classes that can possibly appear in the y vector. + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + + Returns + ------- + self : StreamDecisionForest + The object itself. + """ + self.n_batches_ = 1 + + return super().fit(X, y, classes=classes) + + def partial_fit(self, X, y, sample_weight=None, classes=None): + """ + Partially fits the forest to data X with labels y. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The target values (class labels) as integers or strings. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + classes : ndarray, default=None + List of all the classes that can possibly appear in the y vector. + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + + Returns + ------- + self : StreamDecisionForest + The object itself. + """ + self.n_batches_ += 1 + self._validate_params() + + # validate input parameters + first_call = _check_partial_fit_first_call(self, classes=classes) + + # Fit if no tree exists yet + if first_call: + self.fit( + X, + y, + sample_weight=sample_weight, + classes=classes, + ) + return self + + if self.bootstrap: + n_samples_bootstrap = _get_n_samples_bootstrap(X.shape[0], self.max_samples) + else: + n_samples_bootstrap = X.shape[0] + + # Calculate probability of swaps + swap_prob = 1 / self.n_batches_ + if self.n_swaps > 0 and self.n_batches_ > 2 and np.random.random() <= swap_prob: + # Evaluate forest performance + results = Parallel(n_jobs=self.n_jobs)( + delayed(tree.predict)(X) for tree in self.estimators_ + ) + + # Sort predictions by accuracy + acc_l = [] + for idx, result in enumerate(results): + acc_l.append([accuracy_score(result, y), idx]) + acc_l = sorted(acc_l, key=lambda x: x[0]) + + # Generate new trees + new_trees = Parallel(n_jobs=self.n_jobs)( + delayed(_partial_fit)( + DecisionTreeClassifier( + criterion=self.criterion, + splitter=self.splitter, + max_depth=self.max_depth, + min_samples_split=self.min_samples_split, + min_samples_leaf=self.min_samples_leaf, + min_weight_fraction_leaf=self.min_weight_fraction_leaf, + max_features=self.max_features, + max_leaf_nodes=self.max_leaf_nodes, + class_weight=cself.lass_weight, + random_state=self.random_state, + min_impurity_decrease=self.min_impurity_decrease, + monotonic_cst=self.monotonic_cst, + ccp_alpha=self.ccp_alpha, + store_leaf_values=self.store_leaf_values, + ), + X, + y, + n_samples_bootstrap=n_samples_bootstrap, + classes=self.classes_, + ) + for i in range(self.n_swaps) + ) + + # Swap worst performing trees with new trees + for i in range(self.n_swaps): + self.estimators_[acc_l[i][1]] = new_trees[i] + + # Update existing stream decision trees + super().partial_fit(X, y, classes=classes) + + return self From c494b3ae8c5b403d0a148c3273850c25aac438d1 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 23 Aug 2023 12:55:37 -0400 Subject: [PATCH 03/11] FIX correct import orders --- sktree/ensemble/_honest_forest.py | 1 - sktree/ensemble/_supervised_forest.py | 1 - sktree/ensemble/_unsupervised_forest.py | 1 - sktree/experimental/__init__.py | 2 +- sktree/experimental/sdf.py | 4 ++-- 5 files changed, 3 insertions(+), 6 deletions(-) diff --git a/sktree/ensemble/_honest_forest.py b/sktree/ensemble/_honest_forest.py index f46daf52d..1c5edbefb 100644 --- a/sktree/ensemble/_honest_forest.py +++ b/sktree/ensemble/_honest_forest.py @@ -9,7 +9,6 @@ from sklearn.utils.validation import check_is_fitted, check_X_y from .._lib.sklearn.ensemble._forest import ForestClassifier - from .._lib.sklearn.tree import _tree as _sklearn_tree from ..tree import HonestTreeClassifier diff --git a/sktree/ensemble/_supervised_forest.py b/sktree/ensemble/_supervised_forest.py index 70687cef4..403fc85df 100644 --- a/sktree/ensemble/_supervised_forest.py +++ b/sktree/ensemble/_supervised_forest.py @@ -7,7 +7,6 @@ PatchObliqueDecisionTreeClassifier, PatchObliqueDecisionTreeRegressor, ) - from ..tree._neighbors import SimMatrixMixin diff --git a/sktree/ensemble/_unsupervised_forest.py b/sktree/ensemble/_unsupervised_forest.py index d84be0aa6..cdadaf62b 100644 --- a/sktree/ensemble/_unsupervised_forest.py +++ b/sktree/ensemble/_unsupervised_forest.py @@ -26,7 +26,6 @@ from .._lib.sklearn.ensemble._forest import BaseForest from .._lib.sklearn.tree._tree import DTYPE from ..tree import UnsupervisedDecisionTree, UnsupervisedObliqueDecisionTree - from ..tree._neighbors import SimMatrixMixin diff --git a/sktree/experimental/__init__.py b/sktree/experimental/__init__.py index 67ae6ba18..cdf4b4295 100644 --- a/sktree/experimental/__init__.py +++ b/sktree/experimental/__init__.py @@ -1,4 +1,4 @@ -from . import mutual_info, simulate, sdf +from . import mutual_info, sdf, simulate from .mutual_info import ( cmi_from_entropy, cmi_gaussian, diff --git a/sktree/experimental/sdf.py b/sktree/experimental/sdf.py index 029631488..3cc1bf2e5 100644 --- a/sktree/experimental/sdf.py +++ b/sktree/experimental/sdf.py @@ -5,15 +5,15 @@ # import the necessary packages import numpy as np from joblib import Parallel, delayed +from sklearn.metrics import accuracy_score +from sklearn.utils.multiclass import _check_partial_fit_first_call from .._lib.sklearn.ensemble._forest import ( RandomForestClassifier, _generate_sample_indices, _get_n_samples_bootstrap, ) -from sklearn.metrics import accuracy_score from .._lib.sklearn.tree import DecisionTreeClassifier -from sklearn.utils.multiclass import _check_partial_fit_first_call def _partial_fit(tree, X, y, n_samples_bootstrap, classes): From 5f773350482246dae52fd5cdf25feb8f186d89d3 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 23 Aug 2023 12:58:10 -0400 Subject: [PATCH 04/11] FIX correct variable --- sktree/experimental/sdf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sktree/experimental/sdf.py b/sktree/experimental/sdf.py index 3cc1bf2e5..d487a9958 100644 --- a/sktree/experimental/sdf.py +++ b/sktree/experimental/sdf.py @@ -239,7 +239,7 @@ def partial_fit(self, X, y, sample_weight=None, classes=None): min_weight_fraction_leaf=self.min_weight_fraction_leaf, max_features=self.max_features, max_leaf_nodes=self.max_leaf_nodes, - class_weight=cself.lass_weight, + class_weight=self.lass_weight, random_state=self.random_state, min_impurity_decrease=self.min_impurity_decrease, monotonic_cst=self.monotonic_cst, From f4b64152ea2cdca9d75f728c003cb1bea8e6ffb0 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 23 Aug 2023 14:04:15 -0400 Subject: [PATCH 05/11] TST add unit tests for SDF --- sktree/experimental/tests/test_sdf.py | 120 ++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 sktree/experimental/tests/test_sdf.py diff --git a/sktree/experimental/tests/test_sdf.py b/sktree/experimental/tests/test_sdf.py new file mode 100644 index 000000000..ff2ef33d9 --- /dev/null +++ b/sktree/experimental/tests/test_sdf.py @@ -0,0 +1,120 @@ +import numpy as np +import pytest +from sklearn import datasets +from sklearn.metrics import accuracy_score, r2_score +from sklearn.utils.estimator_checks import parametrize_with_checks + +from sktree.experimental import StreamDecisionForest + +CLF_CRITERIONS = ("gini", "entropy") + +# also load the iris dataset +# and randomly permute it +iris = datasets.load_iris() +rng = np.random.RandomState(1) +perm = rng.permutation(iris.target.size) +iris.data = iris.data[perm] +iris.target = iris.target[perm] + + +def test_toy_accuracy(): + clf = StreamDecisionForest(n_estimators=10) + X = np.ones((20, 4)) + X[10:] *= -1 + y = [0] * 10 + [1] * 10 + clf = clf.fit(X, y) + np.testing.assert_array_equal(clf.predict(X), y) + + +def test_first_fit(): + clf = StreamDecisionForest(n_estimators=10) + with pytest.raises( + ValueError, match="classes must be passed on the first call to partial_fit." + ): + clf.partial_fit(iris.data, iris.target) + + +@pytest.mark.parametrize("criterion", ["gini", "entropy"]) +@pytest.mark.parametrize("max_features", [None, 2]) +def test_iris(criterion, max_features): + # Check consistency on dataset iris. + clf = StreamDecisionForest( + criterion=criterion, + random_state=0, + max_features=max_features, + n_estimators=10, + ) + + clf.partial_fit(iris.data, iris.target, classes=np.unique(iris.target)) + score = accuracy_score(clf.predict(iris.data), iris.target) + + assert score > 0.5 and score < 1.0, "Failed with {0}, criterion = {1} and score = {2}".format( + "SDF", criterion, score + ) + + score = accuracy_score(clf.predict(iris.data), clf.predict_proba(iris.data).argmax(1)) + assert score == 1.0, "Failed with {0}, criterion = {1} and score = {2}".format( + "SDF", criterion, score + ) + + clf.partial_fit(iris.data, iris.target) + score = accuracy_score(clf.predict(iris.data), iris.target) + + assert ( + score > 0.5 and score < 1.0 + ), "Failed partial_fit with {0}, criterion = {1} and score = {2}".format( + "SDF", criterion, score + ) + + score = accuracy_score(clf.predict(iris.data), clf.predict_proba(iris.data).argmax(1)) + assert score == 1.0, "Failed partial_fit with {0}, criterion = {1} and score = {2}".format( + "SDF", criterion, score + ) + + +@pytest.mark.parametrize("criterion", ["gini", "entropy"]) +@pytest.mark.parametrize("max_features", [None, 2]) +def test_iris_multi(criterion, max_features): + # Check consistency on dataset iris. + clf = StreamDecisionForest( + criterion=criterion, + random_state=0, + max_features=max_features, + n_estimators=10, + ) + + second_y = np.concatenate([(np.ones(50) * 3), (np.ones(50) * 4), (np.ones(50) * 5)]) + + X = iris.data + y = np.stack((iris.target, second_y[perm])).T + + clf.fit(X, y) + score = r2_score(clf.predict(X), y) + assert score > 0.9 and score < 1.0, "Failed with {0}, criterion = {1} and score = {2}".format( + "SDF", criterion, score + ) + + +def test_max_samples(): + max_samples_list = [8, 0.5, None] + depths = [] + X = rng.normal(0, 1, (100, 2)) + X[:50] *= -1 + y = [0, 1] * 50 + for ms in max_samples_list: + uf = StreamDecisionForest(n_estimators=2, random_state=0, max_samples=ms, bootstrap=True) + uf = uf.fit(X, y) + depths.append(uf.estimators_[0].get_depth()) + + assert all(np.diff(depths) > 0) + + +@parametrize_with_checks([StreamDecisionForest(n_estimators=10, random_state=0)]) +def test_sklearn_compatible_estimator(estimator, check): + # 1. check_class_weight_classifiers is not supported since it requires sample weight + # XXX: can include this "generalization" in the future if it's useful + if check.func.__name__ in [ + "check_class_weight_classifiers", + ]: + pytest.skip() + check(estimator) From 9cba01b842e83a05875fca1a3ed22a50f94de031 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 23 Aug 2023 14:11:41 -0400 Subject: [PATCH 06/11] FIX add test file to source --- sktree/experimental/tests/meson.build | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sktree/experimental/tests/meson.build b/sktree/experimental/tests/meson.build index ac5c5d40f..b5d1ef79c 100644 --- a/sktree/experimental/tests/meson.build +++ b/sktree/experimental/tests/meson.build @@ -2,10 +2,11 @@ python_sources = [ '__init__.py', 'test_mutual_info.py', 'test_simulate.py', + 'test_sdf.py', ] py3.install_sources( python_sources, pure: false, subdir: 'sktree/experimental/tests' -) \ No newline at end of file +) From 8ab3566f861219ea0b61066f3efe353419fa49a3 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 23 Aug 2023 14:29:54 -0400 Subject: [PATCH 07/11] FIX correct param --- sktree/experimental/sdf.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sktree/experimental/sdf.py b/sktree/experimental/sdf.py index d487a9958..fafa028ae 100644 --- a/sktree/experimental/sdf.py +++ b/sktree/experimental/sdf.py @@ -132,10 +132,9 @@ def __init__( max_bins=max_bins, store_leaf_values=store_leaf_values, ) - self.n_batches_ = 0 self.n_swaps = n_swaps - def fit(self, X, y, classes=None): + def fit(self, X, y, sample_weight=None, classes=None): """ Partially fits the forest to data X with labels y. @@ -159,7 +158,7 @@ def fit(self, X, y, classes=None): """ self.n_batches_ = 1 - return super().fit(X, y, classes=classes) + return super().fit(X, y, sample_weight=sample_weight, classes=classes) def partial_fit(self, X, y, sample_weight=None, classes=None): """ @@ -192,7 +191,6 @@ def partial_fit(self, X, y, sample_weight=None, classes=None): self : StreamDecisionForest The object itself. """ - self.n_batches_ += 1 self._validate_params() # validate input parameters @@ -207,6 +205,7 @@ def partial_fit(self, X, y, sample_weight=None, classes=None): classes=classes, ) return self + self.n_batches_ += 1 if self.bootstrap: n_samples_bootstrap = _get_n_samples_bootstrap(X.shape[0], self.max_samples) From 11759c8c93987920860b302d6f05859d57527a1f Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 12 Sep 2023 09:49:14 -0400 Subject: [PATCH 08/11] ENH update submodule --- sktree/_lib/sklearn_fork | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sktree/_lib/sklearn_fork b/sktree/_lib/sklearn_fork index 68015082c..e2fee00aa 160000 --- a/sktree/_lib/sklearn_fork +++ b/sktree/_lib/sklearn_fork @@ -1 +1 @@ -Subproject commit 68015082cc740d7859fad964a77f9e684544d868 +Subproject commit e2fee00aa461c21b8cfa59eb907d27972415c99b From 3f216abbf3f83c1bbab3af69d3dd527bf060f0ae Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 12 Sep 2023 09:56:47 -0400 Subject: [PATCH 09/11] FIX allow test to have perfect score --- sktree/experimental/tests/test_sdf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sktree/experimental/tests/test_sdf.py b/sktree/experimental/tests/test_sdf.py index ff2ef33d9..f92e58466 100644 --- a/sktree/experimental/tests/test_sdf.py +++ b/sktree/experimental/tests/test_sdf.py @@ -61,7 +61,7 @@ def test_iris(criterion, max_features): score = accuracy_score(clf.predict(iris.data), iris.target) assert ( - score > 0.5 and score < 1.0 + score > 0.5 and score <= 1.0 ), "Failed partial_fit with {0}, criterion = {1} and score = {2}".format( "SDF", criterion, score ) @@ -90,7 +90,7 @@ def test_iris_multi(criterion, max_features): clf.fit(X, y) score = r2_score(clf.predict(X), y) - assert score > 0.9 and score < 1.0, "Failed with {0}, criterion = {1} and score = {2}".format( + assert score > 0.9 and score <= 1.0, "Failed with {0}, criterion = {1} and score = {2}".format( "SDF", criterion, score ) From 2e29a051d2d9c7e4753daa2e529e64b6143aa152 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 12 Sep 2023 10:02:01 -0400 Subject: [PATCH 10/11] FIX allow test to have perfect score --- sktree/experimental/tests/test_sdf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sktree/experimental/tests/test_sdf.py b/sktree/experimental/tests/test_sdf.py index f92e58466..cf093b0e6 100644 --- a/sktree/experimental/tests/test_sdf.py +++ b/sktree/experimental/tests/test_sdf.py @@ -48,7 +48,7 @@ def test_iris(criterion, max_features): clf.partial_fit(iris.data, iris.target, classes=np.unique(iris.target)) score = accuracy_score(clf.predict(iris.data), iris.target) - assert score > 0.5 and score < 1.0, "Failed with {0}, criterion = {1} and score = {2}".format( + assert score > 0.5 and score <= 1.0, "Failed with {0}, criterion = {1} and score = {2}".format( "SDF", criterion, score ) From 5945f35177632d12a23874533c2e5e5c5de19354 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 12 Sep 2023 11:09:15 -0400 Subject: [PATCH 11/11] DOC add changelog --- doc/whats_new/v0.2.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats_new/v0.2.rst b/doc/whats_new/v0.2.rst index 4321069e3..5d4112e07 100644 --- a/doc/whats_new/v0.2.rst +++ b/doc/whats_new/v0.2.rst @@ -30,6 +30,7 @@ Changelog - |Feature| Implementation of ExtraObliqueDecisionTreeClassifier, ExtraObliqueDecisionTreeRegressor by `SUKI-O`_ (:pr:`75`) - |Efficiency| Around 1.5-2x speed improvement for unsupervised forests, by `Adam Li`_ (:pr:`114`) - |API| Allow ``sqrt`` and ``log2`` keywords to be used for ``min_samples_split`` parameter in unsupervised forests, by `Adam Li`_ (:pr:`114`) +- |Feature| Implementation of StreamDecisionForest, by `Haoyin Xu`_ and `Adam Li`_ (:pr:`116`) Code and Documentation Contributors @@ -40,3 +41,4 @@ the project since version inception, including: * `Adam Li`_ * `SUKI-O`_ +* `Haoyin Xu`_