diff --git a/sktree/tree/_honest_tree.py b/sktree/tree/_honest_tree.py index 14f90cd99..cd91254f7 100644 --- a/sktree/tree/_honest_tree.py +++ b/sktree/tree/_honest_tree.py @@ -404,6 +404,7 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True, classes=None): classes=classes, ) return self + rng = np.random.default_rng(self.random_state) if sample_weight is None: @@ -423,7 +424,7 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True, classes=None): _sample_weight[self.honest_indices_] = 0 if classes is None: - classes = np.unique(y) + classes = np.unique(y).tolist() self.estimator_.partial_fit( X, @@ -560,7 +561,7 @@ def _fit( self.estimator_ = deepcopy(self.tree_estimator) if classes is None: - classes = np.unique(y) + classes = np.unique(y).tolist() # Learn structure on subsample self.estimator_._fit(