Skip to content

Commit

Permalink
FEA Add warning to control against runtime might when wanting to run …
Browse files Browse the repository at this point in the history
…comight (#323)

* Add warning to control against runtime might when wanting to run comight
* Updating submodule

---------

Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 authored Oct 10, 2024
1 parent 980de16 commit 6eaf495
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 11 deletions.
2 changes: 1 addition & 1 deletion doc/sphinxext/allow_nan_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from docutils import nodes
from docutils.parsers.rst import Directive
from sklearn.utils import all_estimators
from sklearn.utils._test_common.instance_generator import _construct_instances
from sklearn.utils._testing import SkipTest
from sklearn.utils.estimator_checks import _construct_instance


class AllowNanEstimators(Directive):
Expand Down
2 changes: 1 addition & 1 deletion treeple/_lib/sklearn_fork
Submodule sklearn_fork updated 203 files
19 changes: 13 additions & 6 deletions treeple/ensemble/_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sklearn.base import _fit_context, clone
from sklearn.ensemble._base import _partition_estimators, _set_random_states
from sklearn.utils import compute_sample_weight, resample
from sklearn.utils._param_validation import Interval, RealNotInt, StrOptions
from sklearn.utils._param_validation import HasMethods, Interval, RealNotInt, StrOptions
from sklearn.utils.validation import check_is_fitted

from .._lib.sklearn.ensemble._forest import ForestClassifier
Expand Down Expand Up @@ -417,11 +417,18 @@ class labels (multi-output problem).
Interval(RealNotInt, 0.0, None, closed="right"),
Interval(Integral, 1, None, closed="left"),
]
_parameter_constraints["honest_fraction"] = [Interval(RealNotInt, 0.0, 1.0, closed="both")]
_parameter_constraints["honest_prior"] = [
StrOptions({"empirical", "uniform", "ignore"}),
]
_parameter_constraints["stratify"] = ["boolean"]
_parameter_constraints.update(
{
"tree_estimator": [
HasMethods(["fit", "predict", "predict_proba", "apply"]),
None,
],
"honest_fraction": [Interval(RealNotInt, 0.0, 1.0, closed="neither")],
"honest_prior": [StrOptions({"empirical", "uniform", "ignore"})],
"stratify": ["boolean"],
"tree_estimator_params": ["dict"],
}
)

def __init__(
self,
Expand Down
18 changes: 16 additions & 2 deletions treeple/stats/forest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import threading
from collections import namedtuple
from typing import Callable
from warnings import warn

import numpy as np
from joblib import Parallel, delayed
Expand All @@ -10,6 +11,8 @@
from sklearn.utils.multiclass import type_of_target

from .._lib.sklearn.ensemble._forest import ForestClassifier
from ..ensemble import HonestForestClassifier
from ..tree import MultiViewDecisionTreeClassifier
from ..tree._classes import DTYPE
from .permuteforest import PermutationHonestForestClassifier
from .utils import METRIC_FUNCTIONS, POSITIVE_METRICS, _compute_null_distribution_coleman
Expand Down Expand Up @@ -38,8 +41,8 @@ def _parallel_predict_proba_oob(predict_proba, X, out, idx, test_idx, lock):


def build_coleman_forest(
est,
perm_est,
est: HonestForestClassifier,
perm_est: PermutationHonestForestClassifier,
X,
y,
covariate_index=None,
Expand Down Expand Up @@ -111,13 +114,24 @@ def build_coleman_forest(
"""
metric_func: Callable[[ArrayLike, ArrayLike], float] = METRIC_FUNCTIONS[metric]

if not isinstance(est, HonestForestClassifier):
raise RuntimeError(f"Original forest must be a HonestForestClassifier, got {type(est)}")

# build two sets of forests
est, orig_forest_proba = build_oob_forest(est, X, y, verbose=verbose)

if not isinstance(perm_est, PermutationHonestForestClassifier):
raise RuntimeError(
f"Permutation forest must be a PermutationHonestForestClassifier, got {type(perm_est)}"
)

if covariate_index is None and isinstance(est.tree_estimator, MultiViewDecisionTreeClassifier):
warn(
"Covariate index is not defined, but a MultiViewDecisionTreeClassifier is used. "
"If using CoMIGHT, one should define the covariate index to permute. "
"Defaulting to use MIGHT."
)

perm_est, perm_forest_proba = build_oob_forest(
perm_est, X, y, verbose=verbose, covariate_index=covariate_index
)
Expand Down
27 changes: 27 additions & 0 deletions treeple/stats/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,30 @@ def test_build_oob_random_forest():
assert len(np.unique(structure_samples[tree_idx])) + len(oob_samples_list[tree_idx]) == len(
samples
), f"{tree_idx} {len(structure_samples[tree_idx])} + {len(oob_samples_list[tree_idx])} != {len(samples)}"


def test_build_coleman_warning_with_multiview_without_covariate_index():
"""Test warning is raised in build_coleman_forest with multiview without covariate_index."""

est = HonestForestClassifier(
n_estimators=100,
random_state=0,
bootstrap=True,
max_samples=1.0,
honest_fraction=0.5,
stratify=True,
tree_estimator=MultiViewDecisionTreeClassifier(),
)
perm_est = PermutationHonestForestClassifier(
n_estimators=100,
random_state=0,
bootstrap=True,
max_samples=1.0,
honest_fraction=0.5,
stratify=True,
tree_estimator=MultiViewDecisionTreeClassifier(),
)
X = rng.normal(0, 1, (100, 2))
y = np.array([0, 1] * 50)
with pytest.warns(UserWarning, match="Covariate index is not defined"):
build_coleman_forest(est, perm_est, X, y, metric="s@98", n_repeats=1000, seed=0)
1 change: 1 addition & 0 deletions treeple/tests/test_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def test_sklearn_compatible_estimator(estimator, check):
# TODO: this is an error. Somehow a segfault is raised when fit is called first and
# then partial_fit
"check_fit_score_takes_y",
"check_do_not_raise_errors_in_init_or_set_params",
]:
pytest.skip()
check(estimator)
Expand Down
1 change: 1 addition & 0 deletions treeple/tree/_honest_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ class frequency in the voting subsample.
"honest_fraction": [Interval(RealNotInt, 0.0, 1.0, closed="neither")],
"honest_prior": [StrOptions({"empirical", "uniform", "ignore"})],
"stratify": ["boolean"],
"tree_estimator_params": ["dict"],
}

def __init__(
Expand Down
6 changes: 5 additions & 1 deletion treeple/tree/tests/test_honest_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,11 @@ def test_sklearn_compatible_estimator(estimator, check):
# XXX: can include this "generalization" in the future if it's useful
# zero sample weight is not "really supported" in honest subsample trees since sample weight
# for fitting the tree's splits
if check.func.__name__ in ["check_class_weight_classifiers", "check_classifier_multioutput"]:
if check.func.__name__ in [
"check_class_weight_classifiers",
"check_classifier_multioutput",
"check_do_not_raise_errors_in_init_or_set_params",
]:
pytest.skip()
check(estimator)

Expand Down

0 comments on commit 6eaf495

Please sign in to comment.