Skip to content

Commit

Permalink
FIX revert to sample weight method
Browse files Browse the repository at this point in the history
  • Loading branch information
PSSF23 committed Aug 22, 2023
1 parent 4a8f9c8 commit 6620f37
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions sktree/tree/_honest_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,13 +420,11 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True, classes=None):
)
self.honest_indices_ = np.setdiff1d(nonzero_indices, self.structure_indices_)

_X = X[self.structure_indices_]
_y = y[self.structure_indices_]
_sample_weight = _sample_weight[self.structure_indices_]
_sample_weight[self.honest_indices_] = 0

self.estimator_.partial_fit(
_X,
_y,
X,
y,
sample_weight=_sample_weight,
check_input=check_input,
classes=classes if classes else np.unique(y),
Expand Down Expand Up @@ -535,9 +533,7 @@ def _fit(
)
self.honest_indices_ = np.setdiff1d(nonzero_indices, self.structure_indices_)

_X = X[self.structure_indices_]
_y = y[self.structure_indices_]
_sample_weight = _sample_weight[self.structure_indices_]
_sample_weight[self.honest_indices_] = 0

if not self.tree_estimator:
self.estimator_ = DecisionTreeClassifier(
Expand All @@ -562,8 +558,8 @@ def _fit(

# Learn structure on subsample
self.estimator_._fit(
_X,
_y,
X,
y,
sample_weight=_sample_weight,
check_input=check_input,
missing_values_in_feature_mask=missing_values_in_feature_mask,
Expand Down

0 comments on commit 6620f37

Please sign in to comment.