Skip to content

Commit

Permalink
Fix pvalue sampling
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Sep 28, 2023
1 parent 64f2017 commit b8dc3a2
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 91 deletions.
204 changes: 134 additions & 70 deletions sktree/stats/forestht.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable, Dict

import numpy as np
from numpy.typing import ArrayLike
from sklearn.base import MetaEstimatorMixin, clone, is_classifier
Expand All @@ -8,7 +10,6 @@
from sklearn.utils.validation import _is_fitted, check_X_y

from sktree._lib.sklearn.ensemble._forest import (
BaseForest,
ForestClassifier,
ForestRegressor,
RandomForestClassifier,
Expand Down Expand Up @@ -109,12 +110,12 @@ def train_test_samples_(self):

def _statistic(
self,
estimator: BaseForest,
estimator: ForestClassifier,
X: ArrayLike,
y: ArrayLike,
covariate_index: ArrayLike = None,
metric="mse",
return_posteriors: bool = False,
covariate_index: ArrayLike,
metric: str,
return_posteriors: bool,
**metric_kwargs,
):
raise NotImplementedError("Subclasses should implement this!")
Expand Down Expand Up @@ -175,7 +176,8 @@ def statistic(
-------
stat : float
The test statistic.
posterior_final : ArrayLike of shape (n_samples_final, n_outputs), optional
posterior_final : ArrayLike of shape (n_estimators, n_samples_final, n_outputs) or
(n_estimators, n_samples_final), optional
If ``return_posteriors`` is True, then the posterior probabilities of the
samples used in the final test. ``n_samples_final`` is equal to ``n_samples``
if all samples are encountered in the test set of at least one tree in the
Expand Down Expand Up @@ -332,6 +334,7 @@ def test(
# Note: at this point, both `estimator` and `permuted_estimator_` should
# have been fitted already, so we can now compute on the null by resampling
# the posteriors and computing the test statistic on the resampled posteriors
print(observe_posteriors.shape, permute_posteriors.shape)
if self.sample_dataset_per_tree:
metric_star, metric_star_pi = _compute_null_distribution_coleman(
y_test=y[observe_samples, :],
Expand All @@ -342,11 +345,11 @@ def test(
seed=self.random_state,
)
else:
if not self.sample_dataset_per_tree:
_, indices_test = self.train_test_samples_[0]
y_test = y[indices_test, :]
else:
y_test = y
# If not sampling a new dataset per tree, then we may either be
# permuting the covariate index per tree or per forest. If not permuting
# there is only one train and test split, so we can just use that
_, indices_test = self.train_test_samples_[0]
y_test = y[indices_test, :]
metric_star, metric_star_pi = _compute_null_distribution_coleman(
y_test=y_test,
y_pred_proba_normal=observe_posteriors,
Expand Down Expand Up @@ -439,7 +442,7 @@ class FeatureImportanceForestRegressor(BaseForestHT):
y_true_final_ : ArrayLike of shape (n_samples_final,)
The true labels of the samples used in the final test.
posterior_final_ : ArrayLike of shape (n_samples_final,)
posterior_final_ : ArrayLike of shape (n_estimators, n_samples_final)
The predicted posterior probabilities of the samples used in the final test.
null_dist_ : ArrayLike of shape (n_repeats,)
Expand Down Expand Up @@ -490,16 +493,16 @@ def _get_estimator(self):

def _statistic(
self,
estimator: ForestRegressor,
estimator: ForestClassifier,
X: ArrayLike,
y: ArrayLike,
covariate_index: ArrayLike = None,
metric="mse",
return_posteriors: bool = False,
covariate_index: ArrayLike,
metric: str,
return_posteriors: bool,
**metric_kwargs,
):
"""Helper function to compute the test statistic."""
metric_func = METRIC_FUNCTIONS[metric]
metric_func: Callable[[ArrayLike, ArrayLike, Dict], float] = METRIC_FUNCTIONS[metric]
rng = np.random.default_rng(self.random_state)

if self.permute_per_tree:
Expand All @@ -516,18 +519,17 @@ def _statistic(
# Fill test set posteriors & set rest NaN
posterior_arr[idx, indices_test, :] = y_pred # posterior

y_true_final = y[indices_test, :]

# determine if there are any nans in the final posterior array
# Average all posteriors (n_samples_test, n_outputs)
posterior_final = np.nanmean(posterior_arr, axis=0)
posterior_forest = np.nanmean(posterior_arr, axis=0)

# Find the row indices with NaN values in any column
nonnan_indices = np.where(~np.isnan(posterior_final).any(axis=1))[0]
# Find the row indices with NaN values in any column
nonnan_indices = np.where(~np.isnan(posterior_forest).any(axis=1))[0]
samples = nonnan_indices

# Ignore all NaN values (samples not tested)
# Ignore all NaN values (samples not tested)
y_true_final = y[nonnan_indices, :]
posterior_final = posterior_final[nonnan_indices, :]
posterior_arr = posterior_arr[:, (nonnan_indices), :]
else:
# fitting a forest will only get one unique train/test split
indices_train, indices_test = self.train_test_samples_[0]
Expand All @@ -547,24 +549,31 @@ def _statistic(
X_train[:, covariate_index] = X_train[index_arr, covariate_index]

estimator.fit(X_train, y_train)
y_pred = estimator.predict(X_test)

# construct posterior array for all trees (n_trees, n_samples_test, n_outputs)
posterior_arr = np.full(
(len(estimator.estimators_), self.n_samples_test_, estimator.n_outputs_), np.nan
)
for itree, tree in enumerate(estimator.estimators_):
posterior_arr[itree, ...] = tree.predict(X_test)

# set variables to compute metric
samples = indices_test
y_true_final = y_test
posterior_final = y_pred

stat = metric_func(y_true_final, posterior_final, **metric_kwargs)
# Average all posteriors (n_samples_test, n_outputs) to compute the statistic
posterior_forest = np.nanmean(posterior_arr, axis=0)
stat = metric_func(y_true_final, posterior_forest, **metric_kwargs)
if covariate_index is None:
# Ignore all NaN values (samples not tested) -> (n_samples_final, n_outputs)
# arrays of y and predicted posterior
self.samples_ = samples
self.y_true_final_ = y_true_final
self.posterior_final_ = posterior_final
self.posterior_final_ = posterior_arr
self.stat_ = stat

if return_posteriors:
return stat, posterior_final, samples
return stat, posterior_arr, samples

return stat

Expand Down Expand Up @@ -625,12 +634,13 @@ class FeatureImportanceForestClassifier(BaseForestHT):
samples_ : ArrayLike of shape (n_samples_final,)
The indices of the samples used in the final test set that would slice
the original ``(X, y)`` input.
the original ``(X, y)`` input along the rows.
y_true_final_ : ArrayLike of shape (n_samples_final,)
The true labels of the samples used in the final test.
posterior_final_ : ArrayLike of shape (n_samples_final,)
posterior_final_ : ArrayLike of shape (n_estimators, n_samples_final, n_outputs) or
(n_estimators, n_samples_final)
The predicted posterior probabilities of the samples used in the final test.
null_dist_ : ArrayLike of shape (n_repeats,)
Expand Down Expand Up @@ -659,7 +669,7 @@ def __init__(
verbose=0,
test_size=0.2,
permute_per_tree=True,
sample_dataset_per_tree=True,
sample_dataset_per_tree=False,
):
super().__init__(
estimator=estimator,
Expand All @@ -685,55 +695,44 @@ def _statistic(
estimator: ForestClassifier,
X: ArrayLike,
y: ArrayLike,
covariate_index: ArrayLike = None,
metric="mi",
return_posteriors: bool = False,
covariate_index: ArrayLike,
metric: str,
return_posteriors: bool,
**metric_kwargs,
):
"""Helper function to compute the test statistic."""
metric_func = METRIC_FUNCTIONS[metric]
metric_func: Callable[[ArrayLike, ArrayLike, Dict], float] = METRIC_FUNCTIONS[metric]
rng = np.random.default_rng(self.random_state)

if metric in POSTERIOR_FUNCTIONS:
predict_posteriors = True
else:
predict_posteriors = False

if predict_posteriors:
# now initialize posterior array as (n_trees, n_samples_test, n_classes)
posterior_arr = np.full(
(self.n_estimators, self._n_samples_, estimator.n_classes_), np.nan
)
else:
# now initialize posterior array as (n_trees, n_samples_test, n_outputs)
posterior_arr = np.full(
(self.n_estimators, self._n_samples_, estimator.n_outputs_), np.nan
)
if self.permute_per_tree:
if predict_posteriors:
posterior_arr = np.full(
(self.n_estimators, self._n_samples_, estimator.n_classes_), np.nan
)
else:
# now initialize posterior array as (n_trees, n_samples_test, n_outputs)
posterior_arr = np.full(
(self.n_estimators, self._n_samples_, estimator.n_outputs_), np.nan
)

for idx, (indices_train, indices_test) in enumerate(self._get_estimators_indices()):
tree: DecisionTreeClassifier = estimator.estimators_[idx]
train_tree(tree, X[indices_train, :], y[indices_train, :], covariate_index)

if predict_posteriors:
# XXX: currently assumes n_outputs_ == 1
y_pred = tree.predict_proba(X[indices_test, :])
y_pred = tree.predict_proba(X[indices_test, :]).reshape(-1, tree.n_classes_)
else:
y_pred = tree.predict(X[indices_test, :]).reshape(-1, tree.n_outputs_)

# Fill test set posteriors & set rest NaN
# TODO: refactor so posterior_arr is just a large NaN array
posterior_arr[idx, indices_test, :] = y_pred # posterior

# Average all posteriors (n_samples_test, n_outputs)
posterior_final = np.nanmean(posterior_arr, axis=0)

# Find the row indices with NaN values in any column
nonnan_indices = np.where(~np.isnan(posterior_final).any(axis=1))[0]
samples = nonnan_indices

# Ignore all NaN values (samples not tested)
y_true_final = y[nonnan_indices, :]
posterior_final = posterior_final[nonnan_indices, :]
else:
# fitting a forest will only get one unique train/test split
indices_train, indices_test = self.train_test_samples_[0]
Expand All @@ -756,39 +755,104 @@ def _statistic(
y_train = y_train.ravel()
estimator.fit(X_train, y_train)

if predict_posteriors:
# XXX: currently assumes n_outputs_ == 1
y_pred = estimator.predict_proba(X_test)
else:
y_pred = estimator.predict(X_test)
# construct posterior array for all trees (n_trees, n_samples_test, n_outputs)
for itree, tree in enumerate(estimator.estimators_):
if predict_posteriors:
# XXX: currently assumes n_outputs_ == 1
posterior_arr[itree, indices_test, ...] = tree.predict_proba(X_test).reshape(
-1, tree.n_classes_
)
else:
posterior_arr[itree, indices_test, ...] = tree.predict(X_test).reshape(
-1, tree.n_outputs_
)

# set variables to compute metric
samples = indices_test
y_true_final = y_test
posterior_final = y_pred
if metric == "auc":
# at this point, posterior_final is the predicted posterior for only the positive class
# as more than one output is not supported.
if self._type_of_target_ == "binary":
posterior_final = posterior_final[:, 1]
posterior_arr = posterior_arr[..., (1,)]
else:
raise RuntimeError(
f"AUC metric is not supported for {self._type_of_target_} targets."
)

if np.isnan(posterior_final).any():
raise RuntimeError("NaN values encountered in posterior_final.")
# determine if there are any nans in the final posterior array
# Average all posteriors (n_samples_test, n_outputs)
posterior_forest = np.nanmean(posterior_arr, axis=0)

# Find the row indices with NaN values in any column
nonnan_indices = np.where(~np.isnan(posterior_forest).any(axis=1))[0]
samples = nonnan_indices

# Ignore all NaN values (samples not tested)
y_true_final = y[(nonnan_indices), :]
posterior_arr = posterior_arr[:, (nonnan_indices), :]

# Average all posteriors (n_samples_test, n_outputs) to compute the statistic
posterior_forest = np.nanmean(posterior_arr, axis=0)
stat = metric_func(y_true_final, posterior_forest, **metric_kwargs)

stat = metric_func(y_true_final, posterior_final, **metric_kwargs)
if covariate_index is None:
# Ignore all NaN values (samples not tested) -> (n_samples_final, n_outputs)
# arrays of y and predicted posterior
self.samples_ = samples
self.y_true_final_ = y_true_final
self.posterior_final_ = posterior_final
self.posterior_final_ = posterior_arr
self.stat_ = stat

if return_posteriors:
return stat, posterior_final, samples
return stat, posterior_arr, samples

return stat

def statistic(
self,
X: ArrayLike,
y: ArrayLike,
covariate_index: ArrayLike = None,
metric="mi",
return_posteriors: bool = False,
check_input: bool = True,
**metric_kwargs,
):
"""Compute the test statistic.
Parameters
----------
X : ArrayLike of shape (n_samples, n_features)
The data matrix.
y : ArrayLike of shape (n_samples, n_outputs)
The target matrix.
covariate_index : ArrayLike, optional of shape (n_covariates,)
The index array of covariates to shuffle, by default None.
metric : str, optional
The metric to compute, by default "mi", which computes Mutual Information.
return_posteriors : bool, optional
Whether or not to return the posteriors, by default False.
check_input : bool, optional
Whether or not to check the input, by default True.
**metric_kwargs : dict, optional
Additional keyword arguments to pass to the metric function.
Returns
-------
stat : float
The test statistic.
posterior_final : ArrayLike of shape (n_estimators, n_samples_final, n_outputs) or
(n_estimators, n_samples_final), optional
If ``return_posteriors`` is True, then the posterior probabilities of the
samples used in the final test. ``n_samples_final`` is equal to ``n_samples``
if all samples are encountered in the test set of at least one tree in the
posterior computation.
samples : ArrayLike of shape (n_samples_final,), optional
The indices of the samples used in the final test. ``n_samples_final`` is
equal to ``n_samples`` if all samples are encountered in the test set of at
least one tree in the posterior computation.
"""
return super().statistic(
X, y, covariate_index, metric, return_posteriors, check_input, **metric_kwargs
)
4 changes: 3 additions & 1 deletion sktree/stats/tests/test_forestht.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_featureimportance_forest_permute_pertree():
n_estimators=10,
),
permute_per_tree=True,
sample_dataset_per_tree=True,
sample_dataset_per_tree=False,
)
est.statistic(iris_X[:10], iris_y[:10])

Expand Down Expand Up @@ -323,6 +323,8 @@ def test_iris_pauc_statistic(
score = clf.statistic(iris_X, iris_y, metric="auc", max_fpr=limit)
assert score >= 0.8, "Failed with pAUC: {0} for max fpr: {1}".format(score, limit)

assert isinstance(clf.estimator_, HonestForestClassifier)


@pytest.mark.parametrize(
"forest_hyppo",
Expand Down
Loading

0 comments on commit b8dc3a2

Please sign in to comment.