From 01c1e552c781ec02cdbb5a66851d1ff1737f41c1 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 22 Aug 2023 14:26:27 -0400 Subject: [PATCH] FIX attempt to fix list index --- sktree/tree/_honest_tree.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sktree/tree/_honest_tree.py b/sktree/tree/_honest_tree.py index 14f90cd99..cd91254f7 100644 --- a/sktree/tree/_honest_tree.py +++ b/sktree/tree/_honest_tree.py @@ -404,6 +404,7 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True, classes=None): classes=classes, ) return self + rng = np.random.default_rng(self.random_state) if sample_weight is None: @@ -423,7 +424,7 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True, classes=None): _sample_weight[self.honest_indices_] = 0 if classes is None: - classes = np.unique(y) + classes = np.unique(y).tolist() self.estimator_.partial_fit( X, @@ -560,7 +561,7 @@ def _fit( self.estimator_ = deepcopy(self.tree_estimator) if classes is None: - classes = np.unique(y) + classes = np.unique(y).tolist() # Learn structure on subsample self.estimator_._fit(