diff --git a/src/concrete/ml/sklearn/tree.py b/src/concrete/ml/sklearn/tree.py index 02f876c70..081e0019b 100644 --- a/src/concrete/ml/sklearn/tree.py +++ b/src/concrete/ml/sklearn/tree.py @@ -227,16 +227,7 @@ def fit(self, X: numpy.ndarray, y: numpy.ndarray, *args, **kwargs): **kwargs: kwargs for super().fit """ qX = numpy.zeros_like(X) - # Check that there are only 2 classes - assert_true( - len(numpy.unique(numpy.asarray(y).flatten())) == 2, - "Only 2 classes are supported currently.", - ) - # Check that the classes are 0 and 1 - assert_true( - bool(numpy.all(numpy.unique(y.ravel()) == [0, 1])), - "y must be in [0, 1]", - ) + self.q_x_byfeatures = [] # Quantization of each feature in X for i in range(X.shape[1]): @@ -264,14 +255,17 @@ def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray: # mypy assert self.q_y is not None y_preds = self.q_y.update_quantized_values(y_preds) - y_preds = numpy.squeeze(y_preds) - assert_true(y_preds.ndim > 1, "y_preds should be a 2D array") - # Check if values are already probabilities - if any(numpy.abs(numpy.sum(y_preds, axis=1) - 1) > 1e-4): - # Apply softmax - # FIXME, https://github.com/zama-ai/concrete-ml-internal/issues/518, remove no-cover's - y_preds = numpy.exp(y_preds) # pragma: no cover - y_preds = y_preds / y_preds.sum(axis=1, keepdims=True) # pragma: no cover + + # Make sure the shape of y_preds has 3 dimensions(n_tree, n_samples, n_classes) + # and here n_tree = 1. + assert_true( + (y_preds.ndim == 3) and (y_preds.shape[0] == 1), + f"Wrong dimensions for y_preds: {y_preds.shape} " + f"when is should have shape (1, n_samples, n_classes)", + ) + + # Remove the first dimension in y_preds + y_preds = y_preds[0] return y_preds # pylint: disable=arguments-differ @@ -344,16 +338,12 @@ def _execute_in_fhe(self, X: numpy.ndarray) -> numpy.ndarray: f"You must call {self.compile.__name__} " f"before calling {self.predict.__name__} with execute_in_fhe=True.", ) - y_preds = numpy.zeros((qX.shape[0], self.n_classes_), dtype=numpy.int32) + y_preds = numpy.zeros((1, qX.shape[0], self.n_classes_), dtype=numpy.int32) for i in range(qX.shape[0]): # FIXME transpose workaround see #292 # expected x shape is (n_features, n_samples) fhe_pred = self.fhe_tree.run(qX[i].astype(numpy.uint8).reshape(qX[i].shape[0], 1)) - # Shape of y_pred is (n_trees, classes, n_examples) - # For a single decision tree we can squeeze the first dimension - # and get a shape of (classes, n_examples) - fhe_pred = numpy.squeeze(fhe_pred, axis=0) - y_preds[i, :] = fhe_pred.transpose() + y_preds[:, i, :] = numpy.transpose(fhe_pred, axes=(0, 2, 1)) return y_preds def _predict_with_tensors(self, X: numpy.ndarray) -> numpy.ndarray: @@ -379,17 +369,7 @@ def _predict_with_tensors(self, X: numpy.ndarray) -> numpy.ndarray: qX = qX.T y_pred = self._tensor_tree_predict(qX)[0] - # Shape of y_pred is (n_trees, classes, n_examples) - # For a single decision tree we can squeeze the first dimension - # and get a shape of (classes, n_examples) - y_pred = numpy.squeeze(y_pred, axis=0) - - # Transpose and reshape should be applied in clear. - assert_true( - (y_pred.shape[0] == self.n_classes_) and (y_pred.shape[1] == qX.shape[1]), - "y_pred should have shape (n_classes, n_examples)", - ) - y_pred = y_pred.transpose() + y_pred = numpy.transpose(y_pred, axes=(0, 2, 1)) return y_pred # TODO: https://github.com/zama-ai/concrete-ml-internal/issues/365 diff --git a/tests/sklearn/test_decision_tree.py b/tests/sklearn/test_decision_tree.py index 762d0e330..9cdd01e58 100644 --- a/tests/sklearn/test_decision_tree.py +++ b/tests/sklearn/test_decision_tree.py @@ -19,6 +19,17 @@ ), id="make_classification", ), + pytest.param( + lambda: make_classification( + n_samples=100, + n_features=10, + n_classes=4, + n_informative=10, + n_redundant=0, + random_state=numpy.random.randint(0, 2**15), + ), + id="make_classification_multiclass", + ), ], ) @pytest.mark.parametrize("use_virtual_lib", [True, False]) @@ -63,13 +74,14 @@ def test_decision_tree_classifier( for value in values # type: ignore ], ) -def test_decision_tree_hyperparameters(hyperparameters, check_accuracy, check_r2_score): +@pytest.mark.parametrize("n_classes", [2, 4]) +def test_decision_tree_hyperparameters(hyperparameters, n_classes, check_accuracy, check_r2_score): """Test that the hyperparameters are valid.""" x, y = make_classification( n_samples=1000, n_features=10, n_informative=5, - n_classes=2, + n_classes=n_classes, random_state=numpy.random.randint(0, 2**15), ) model = DecisionTreeClassifier(