diff --git a/examples/plot_MI_gigantic_hypothesis_testing_forest.py b/examples/plot_MI_gigantic_hypothesis_testing_forest.py new file mode 100644 index 000000000..2af3e4412 --- /dev/null +++ b/examples/plot_MI_gigantic_hypothesis_testing_forest.py @@ -0,0 +1,118 @@ +""" +=========================================================== +Mutual Information for Gigantic Hypothesis Testing (MIGHT) +=========================================================== + +An example using :class:`~sktree.FeatureImportanceForestClassifier` for a nonparametric +multivariate hypothesis test, on simulated datasets. Here, we present a simulation +of how MIGHT is used to test the hypothesis that a "feature set is important for +predicting the target". + +We simulate a dataset with 1000 features, 1000 samples, and a binary class target +variable. Among the features, there are 500 features associated with one feature +set, and another 500 features associated with another feature set. One could think of +these for example as different datasets collected on the same patient in a biomedical setting. +The first feature set (X) is strongly correlated with the target, and the second +feature set (W) is weakly correlated with the target (y). Here, we are testing the +null hypothesis: + +``H0: I(X; y) - I(X, W; y) = 0`` +``HA: I(X; y) - I(X, W; y) < 0`` + +where ``I`` is mutual information. + +We present causal settings where this would be true: + +- ``W X -> y``: here ``W`` is completely disconnected from X and y. +- ``W -> X -> y``: here ``W`` is d-separated from y given X. +- ``W -> y <- X``: here ``W`` is a weak predictor of y, and X is a strong predictor of y. +- ``X <- W -> y; X -> y``: here ``W`` is a weak confounder of the relationship between X and y. + +We then use MIGHT to test the hypothesis that the first feature set is important for +predicting the target, and the second feature set is not important for predicting the +target. We use :class:`~sktree.FeatureImportanceForestClassifier`. 
+""" + +import numpy as np +from scipy.special import expit + +from sktree import HonestForestClassifier +from sktree.stats import FeatureImportanceForestClassifier +from sktree.tree import DecisionTreeClassifier + +seed = 12345 +rng = np.random.default_rng(seed) + +# %% +# Simulate data +# ------------- +# We simulate the two feature sets, and the target variable. We then combine them +# into a single dataset to perform hypothesis testing. + +n_samples = 1000 +n_features_set = 500 +mean = 1.0 +sigma = 2.0 +beta = 5.0 + +unimportant_mean = 0.0 +unimportant_sigma = 4.5 + +# first sample the informative features, and then the uninformative features +X_important = rng.normal(loc=mean, scale=sigma, size=(n_samples, 10)) +X_important = np.hstack( + [ + X_important, + rng.normal( + loc=unimportant_mean, scale=unimportant_sigma, size=(n_samples, n_features_set - 10) + ), + ] +) + +X_unimportant = rng.normal( + loc=unimportant_mean, scale=unimportant_sigma, size=(n_samples, n_features_set) +) +X = np.hstack([X_important, X_unimportant]) + +# simulate the binary target variable +y = rng.binomial(n=1, p=expit(beta * X_important[:, :10].sum(axis=1)), size=n_samples) + +# %% +# Perform hypothesis testing using Mutual Information +# --------------------------------------------------- +n_estimators = 100 +max_features = 1.0 +test_size = 0.2 +n_repeats = 500 + +# TODO: This can be improved since HonestForestClassifier should be able to extract +# the relevant hyperparameters +est = FeatureImportanceForestClassifier( + estimator=HonestForestClassifier( + n_estimators=n_estimators, + max_features=max_features, + tree_estimator=DecisionTreeClassifier(), + random_state=seed, + honest_fraction=0.7, + ), + random_state=seed, + test_size=test_size, + permute_per_tree=True, + sample_dataset_per_tree=True, +) + +# we test for the first feature set, which is important and thus should return a pvalue < 0.05 +stat, pvalue = est.test( + X, y, covariate_index=np.arange(n_features_set, 
dtype=int), metric="mi", n_repeats=n_repeats +) +print(f"Estimated MI difference: {stat} with Pvalue: {pvalue}") + +# we test for the second feature set, which is unimportant and thus should return a pvalue > 0.05 +stat, pvalue = est.test( + X, + y, + covariate_index=np.arange(n_features_set, dtype=int) + n_features_set, + metric="mi", + n_repeats=n_repeats, +) +print(f"Estimated MI difference: {stat} with Pvalue: {pvalue}") diff --git a/sktree/stats/forestht.py b/sktree/stats/forestht.py index df37c7b62..57ab24ea4 100644 --- a/sktree/stats/forestht.py +++ b/sktree/stats/forestht.py @@ -377,6 +377,9 @@ class FeatureImportanceForestRegressor(BaseForestHT): verbose : int, default=0 Controls the verbosity when fitting and predicting. + test_size : float, default=0.2 + Proportion of samples per tree to use for the test set. + permute_per_tree : bool, default=True Whether to permute the covariate index per tree or per forest. @@ -554,6 +557,9 @@ class FeatureImportanceForestClassifier(BaseForestHT): verbose : int, default=0 Controls the verbosity when fitting and predicting. + test_size : float, default=0.2 + Proportion of samples per tree to use for the test set. + permute_per_tree : bool, default=True Whether to permute the covariate index per tree or per forest. diff --git a/sktree/stats/permutationforest.py b/sktree/stats/permutationforest.py index 026b89eca..6512f28ae 100644 --- a/sktree/stats/permutationforest.py +++ b/sktree/stats/permutationforest.py @@ -387,9 +387,6 @@ class PermutationForestClassifier(BasePermutationForest): test_size : float, default=0.2 The proportion of samples to leave out for each tree to compute metric on. - n_jobs : int, default=None - The number of jobs to run in parallel. 
- random_state : int, RandomState instance or None, default=None Controls both the randomness of the bootstrapping of the samples used when building trees (if ``bootstrap=True``) and the sampling of the diff --git a/sktree/stats/tests/test_forestht.py b/sktree/stats/tests/test_forestht.py index f015a0273..ba020a403 100644 --- a/sktree/stats/tests/test_forestht.py +++ b/sktree/stats/tests/test_forestht.py @@ -29,6 +29,50 @@ iris_y = iris_y[p] +def test_featureimportance_forest_permute_pertree(): + est = FeatureImportanceForestClassifier( + estimator=RandomForestClassifier( + n_estimators=10, + ), + permute_per_tree=True, + sample_dataset_per_tree=True, + ) + est.statistic(iris_X[:10], iris_y[:10]) + + assert ( + len(est.train_test_samples_[0][1]) == 10 * est.test_size + ), f"{len(est.train_test_samples_[0][1])} {10 * est.test_size}" + assert len(est.train_test_samples_[0][0]) == est._n_samples_ - 10 * est.test_size + + est.test(iris_X[:10], iris_y[:10], [0, 1], n_repeats=10, metric="mse") + assert ( + len(est.train_test_samples_[0][1]) == 10 * est.test_size + ), f"{len(est.train_test_samples_[0][1])} {10 * est.test_size}" + assert len(est.train_test_samples_[0][0]) == est._n_samples_ - 10 * est.test_size + + +def test_featureimportance_forest_errors(): + permute_per_tree = False + sample_dataset_per_tree = True + est = FeatureImportanceForestClassifier( + estimator=RandomForestClassifier( + n_estimators=10, + ), + permute_per_tree=permute_per_tree, + sample_dataset_per_tree=sample_dataset_per_tree, + ) + with pytest.raises(ValueError, match="sample_dataset_per_tree"): + est.statistic(iris_X[:10], iris_y[:10]) + + est = FeatureImportanceForestClassifier(estimator=RandomForestRegressor) + with pytest.raises(RuntimeError, match="Estimator must be"): + est.statistic(iris_X[:10], iris_y[:10]) + + est = FeatureImportanceForestRegressor(estimator=RandomForestClassifier) + with pytest.raises(RuntimeError, match="Estimator must be"): + est.statistic(iris_X[:10], iris_y[:10]) 
+ + @flaky(max_runs=3) @pytest.mark.slowtest @pytest.mark.parametrize( @@ -215,8 +259,12 @@ def test_correlated_logit_model(hypotester, model_kwargs, n_samples, n_repeats, ObliqueDecisionTreeClassifier(), ], ) -@pytest.mark.parametrize("limit", [0.05, 0.1]) -def test_iris_pauc_statistic(criterion, honest_prior, estimator, limit): +@pytest.mark.parametrize("permute_per_tree", [True, False]) +@pytest.mark.parametrize("sample_dataset_per_tree", [True, False]) +def test_iris_pauc_statistic( + criterion, honest_prior, estimator, permute_per_tree, sample_dataset_per_tree +): + limit = 0.1 max_features = "sqrt" n_repeats = 200 n_estimators = 100 @@ -234,14 +282,17 @@ def test_iris_pauc_statistic(criterion, honest_prior, estimator, limit): n_jobs=-1, ), test_size=test_size, - sample_dataset_per_tree=True, - permute_per_tree=True, + sample_dataset_per_tree=sample_dataset_per_tree, + permute_per_tree=permute_per_tree, ) # now add completely uninformative feature X = np.hstack((iris_X, rng.standard_normal(size=(iris_X.shape[0], 4)))) # test for unimportant feature set clf.reset() + if sample_dataset_per_tree and not permute_per_tree: + # this combination is covered by test_featureimportance_forest_errors + pytest.skip() stat, pvalue = clf.test( X, iris_y, diff --git a/sktree/stats/utils.py b/sktree/stats/utils.py index 9b1c3debe..9d06dc8fd 100644 --- a/sktree/stats/utils.py +++ b/sktree/stats/utils.py @@ -15,11 +15,30 @@ from sktree._lib.sklearn.tree import DecisionTreeClassifier -def _mutual_information(y_true, y_pred_proba): +def _mutual_information(y_true: ArrayLike, y_pred_proba: ArrayLike) -> float: + """Compute estimate of mutual information. + + Parameters + ---------- + y_true : ArrayLike of shape (n_samples,) + The true class labels. + y_pred_proba : ArrayLike of shape (n_samples, n_outputs) + Posterior probabilities. + + Returns + ------- + float : + The estimated MI. 
+ """ + if y_true.squeeze().ndim != 1: + raise ValueError(f"y_true must be 1d, not {y_true.shape}") + + # entropy averaged over n_samples H_YX = np.mean(entropy(y_pred_proba, base=np.exp(1), axis=1)) + # empirical count of each class (n_classes) _, counts = np.unique(y_true, return_counts=True) H_Y = entropy(counts, base=np.exp(1)) - return max(H_Y - H_YX, 0) + return H_Y - H_YX METRIC_FUNCTIONS = { diff --git a/sktree/tree/__init__.py b/sktree/tree/__init__.py index c4a706c99..be8baf5db 100644 --- a/sktree/tree/__init__.py +++ b/sktree/tree/__init__.py @@ -1,3 +1,4 @@ +from .._lib.sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from ._classes import ( ExtraObliqueDecisionTreeClassifier, ExtraObliqueDecisionTreeRegressor, @@ -22,4 +23,6 @@ "PatchObliqueDecisionTreeClassifier", "PatchObliqueDecisionTreeRegressor", "HonestTreeClassifier", + "DecisionTreeClassifier", + "DecisionTreeRegressor", ] diff --git a/sktree/tree/_honest_tree.py b/sktree/tree/_honest_tree.py index 4ffa72c0f..43abe93fe 100644 --- a/sktree/tree/_honest_tree.py +++ b/sktree/tree/_honest_tree.py @@ -1,10 +1,9 @@ # Authors: Ronan Perry, Sambit Panda, Haoyin Xu # Adopted from: https://github.com/neurodata/honest-forests -from copy import deepcopy - import numpy as np -from sklearn.base import ClassifierMixin, MetaEstimatorMixin, _fit_context +from sklearn.base import ClassifierMixin, MetaEstimatorMixin, _fit_context, clone +from sklearn.ensemble._base import _set_random_states from sklearn.utils.multiclass import _check_partial_fit_first_call, check_classification_targets from sklearn.utils.validation import check_is_fitted, check_X_y @@ -536,7 +535,7 @@ def _fit( _sample_weight[self.honest_indices_] = 0 - if not self.tree_estimator: + if self.tree_estimator is None: self.estimator_ = DecisionTreeClassifier( criterion=self.criterion, splitter=self.splitter, @@ -555,7 +554,28 @@ def _fit( ) else: # XXX: maybe error out if the tree_estimator is already fitted - self.estimator_ = 
deepcopy(self.tree_estimator) + self.estimator_ = clone(self.tree_estimator) + self.estimator_.set_params( + **dict( + criterion=self.criterion, + splitter=self.splitter, + max_depth=self.max_depth, + min_samples_split=self.min_samples_split, + min_samples_leaf=self.min_samples_leaf, + min_weight_fraction_leaf=self.min_weight_fraction_leaf, + max_features=self.max_features, + max_leaf_nodes=self.max_leaf_nodes, + class_weight=self.class_weight, + random_state=self.random_state, + min_impurity_decrease=self.min_impurity_decrease, + ccp_alpha=self.ccp_alpha, + monotonic_cst=self.monotonic_cst, + store_leaf_values=self.store_leaf_values, + ) + ) + + if self.random_state is not None: + _set_random_states(self.estimator_, self.random_state) # Learn structure on subsample self.estimator_._fit(