Commit
feat: add multiclass capability to decision trees
bcm-at-zama committed Apr 8, 2022
1 parent 3ffe34b commit 6f9651e
Showing 2 changed files with 29 additions and 37 deletions.
50 changes: 15 additions & 35 deletions src/concrete/ml/sklearn/tree.py
@@ -227,16 +227,7 @@ def fit(self, X: numpy.ndarray, y: numpy.ndarray, *args, **kwargs):
**kwargs: kwargs for super().fit
"""
qX = numpy.zeros_like(X)
# Check that there are only 2 classes
assert_true(
len(numpy.unique(numpy.asarray(y).flatten())) == 2,
"Only 2 classes are supported currently.",
)
# Check that the classes are 0 and 1
assert_true(
bool(numpy.all(numpy.unique(y.ravel()) == [0, 1])),
"y must be in [0, 1]",
)

self.q_x_byfeatures = []
# Quantization of each feature in X
for i in range(X.shape[1]):
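The two removed assert_true checks are what previously limited fit() to binary targets labeled 0/1; with them gone, the per-feature quantization loop runs unchanged for any number of classes. Below is a minimal sketch of what per-feature quantization amounts to, using plain NumPy rather than Concrete-ML's own quantizer (the helper is illustrative, not the library's API):

```python
# Illustration only: a plain-NumPy stand-in for per-feature quantization; the real
# fit() uses Concrete-ML's quantizers, whose API is not reproduced here.
import numpy

def quantize_feature(values: numpy.ndarray, n_bits: int = 6) -> numpy.ndarray:
    """Map one float feature to integers in [0, 2**n_bits - 1]."""
    v_min, v_max = values.min(), values.max()
    scale = (v_max - v_min) / (2**n_bits - 1) if v_max > v_min else 1.0
    return numpy.round((values - v_min) / scale).astype(numpy.int64)

# Each column of X is quantized independently, regardless of how many classes y holds.
X = numpy.random.rand(100, 10)
qX = numpy.stack([quantize_feature(X[:, i]) for i in range(X.shape[1])], axis=1)
```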
@@ -264,14 +255,17 @@ def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
# mypy
assert self.q_y is not None
y_preds = self.q_y.update_quantized_values(y_preds)
y_preds = numpy.squeeze(y_preds)
assert_true(y_preds.ndim > 1, "y_preds should be a 2D array")
# Check if values are already probabilities
if any(numpy.abs(numpy.sum(y_preds, axis=1) - 1) > 1e-4):
# Apply softmax
# FIXME, https://github.com/zama-ai/concrete-ml-internal/issues/518, remove no-cover's
y_preds = numpy.exp(y_preds) # pragma: no cover
y_preds = y_preds / y_preds.sum(axis=1, keepdims=True) # pragma: no cover

# Make sure the shape of y_preds has 3 dimensions (n_trees, n_samples, n_classes)
# and here n_trees = 1.
assert_true(
(y_preds.ndim == 3) and (y_preds.shape[0] == 1),
f"Wrong dimensions for y_preds: {y_preds.shape} "
f"when is should have shape (1, n_samples, n_classes)",
)

# Remove the first dimension in y_preds
y_preds = y_preds[0]
return y_preds
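The reworked post_processing now expects de-quantized scores of shape (1, n_samples, n_classes), normalizes them with a softmax only when the rows are not already probabilities, and finally drops the single-tree leading dimension. A NumPy-only sketch of that flow (the axis handling for the 3D layout is an assumption here):

```python
import numpy

def post_process_sketch(y_preds: numpy.ndarray) -> numpy.ndarray:
    """Sketch: (1, n_samples, n_classes) scores -> (n_samples, n_classes) probabilities."""
    assert y_preds.ndim == 3 and y_preds.shape[0] == 1
    # Apply a softmax only if the rows do not already sum to 1
    if numpy.any(numpy.abs(numpy.sum(y_preds, axis=-1) - 1) > 1e-4):
        y_preds = numpy.exp(y_preds)
        y_preds = y_preds / y_preds.sum(axis=-1, keepdims=True)
    # Remove the single-tree dimension
    return y_preds[0]

scores = numpy.array([[[2.0, 1.0, 0.5, 0.1], [0.2, 3.0, 0.4, 0.1]]])
probs = post_process_sketch(scores)  # shape (2, 4), each row sums to 1
```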

# pylint: disable=arguments-differ
@@ -344,16 +338,12 @@ def _execute_in_fhe(self, X: numpy.ndarray) -> numpy.ndarray:
f"You must call {self.compile.__name__} "
f"before calling {self.predict.__name__} with execute_in_fhe=True.",
)
y_preds = numpy.zeros((qX.shape[0], self.n_classes_), dtype=numpy.int32)
y_preds = numpy.zeros((1, qX.shape[0], self.n_classes_), dtype=numpy.int32)
for i in range(qX.shape[0]):
# FIXME transpose workaround see #292
# expected x shape is (n_features, n_samples)
fhe_pred = self.fhe_tree.run(qX[i].astype(numpy.uint8).reshape(qX[i].shape[0], 1))
# Shape of y_pred is (n_trees, classes, n_examples)
# For a single decision tree we can squeeze the first dimension
# and get a shape of (classes, n_examples)
fhe_pred = numpy.squeeze(fhe_pred, axis=0)
y_preds[i, :] = fhe_pred.transpose()
y_preds[:, i, :] = numpy.transpose(fhe_pred, axes=(0, 2, 1))
return y_preds
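In the FHE path each sample is run through the compiled circuit on its own; the per-sample output is assumed to have shape (n_trees, n_classes, 1), and the axes=(0, 2, 1) transpose slots it into the (1, n_samples, n_classes) accumulator. A shape-bookkeeping sketch with a stand-in for fhe_tree.run (names and shapes below are assumptions):

```python
import numpy

n_trees, n_samples, n_classes, n_features = 1, 4, 3, 5

def fake_fhe_run(sample: numpy.ndarray) -> numpy.ndarray:
    # Stand-in for the encrypted circuit: integer scores of shape (n_trees, n_classes, 1)
    return numpy.random.randint(0, 8, size=(n_trees, n_classes, 1))

qX = numpy.random.randint(0, 64, size=(n_samples, n_features)).astype(numpy.uint8)
y_preds = numpy.zeros((n_trees, n_samples, n_classes), dtype=numpy.int32)
for i in range(n_samples):
    fhe_pred = fake_fhe_run(qX[i].reshape(n_features, 1))  # (n_trees, n_classes, 1)
    # Transpose to (n_trees, 1, n_classes), then drop the singleton sample axis
    y_preds[:, i, :] = numpy.transpose(fhe_pred, axes=(0, 2, 1))[:, 0, :]
```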

def _predict_with_tensors(self, X: numpy.ndarray) -> numpy.ndarray:
@@ -379,17 +369,7 @@ def _predict_with_tensors(self, X: numpy.ndarray) -> numpy.ndarray:
qX = qX.T
y_pred = self._tensor_tree_predict(qX)[0]

# Shape of y_pred is (n_trees, classes, n_examples)
# For a single decision tree we can squeeze the first dimension
# and get a shape of (classes, n_examples)
y_pred = numpy.squeeze(y_pred, axis=0)

# Transpose and reshape should be applied in clear.
assert_true(
(y_pred.shape[0] == self.n_classes_) and (y_pred.shape[1] == qX.shape[1]),
"y_pred should have shape (n_classes, n_examples)",
)
y_pred = y_pred.transpose()
y_pred = numpy.transpose(y_pred, axes=(0, 2, 1))
return y_pred
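The clear path applies the same axes=(0, 2, 1) transpose, so both back-ends hand post_processing an array of shape (n_trees, n_samples, n_classes). From there, turning multiclass probabilities into labels is presumably a plain argmax over the class axis (not shown in this diff); a tiny illustration:

```python
import numpy

# Assumed post_processing output: (n_samples, n_classes) probabilities
probas = numpy.array([[0.1, 0.7, 0.2],
                      [0.6, 0.3, 0.1]])
labels = numpy.argmax(probas, axis=1)  # array([1, 0])
```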

# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/365
16 changes: 14 additions & 2 deletions tests/sklearn/test_decision_tree.py
@@ -19,6 +19,17 @@
),
id="make_classification",
),
pytest.param(
lambda: make_classification(
n_samples=100,
n_features=10,
n_classes=4,
n_informative=10,
n_redundant=0,
random_state=numpy.random.randint(0, 2**15),
),
id="make_classification_multiclass",
),
],
)
@pytest.mark.parametrize("use_virtual_lib", [True, False])
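The added parametrization covers a 4-class dataset. make_classification requires n_classes * n_clusters_per_class <= 2**n_informative, which n_informative=10 satisfies comfortably (with the default 2 clusters per class, 4 * 2 = 8 <= 2**10). A standalone check of the generator settings (random_state fixed here for reproducibility):

```python
from sklearn.datasets import make_classification

x, y = make_classification(
    n_samples=100,
    n_features=10,
    n_classes=4,
    n_informative=10,
    n_redundant=0,
    random_state=42,
)
assert set(y) == {0, 1, 2, 3}
```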
@@ -63,13 +74,14 @@ def test_decision_tree_classifier(
for value in values # type: ignore
],
)
def test_decision_tree_hyperparameters(hyperparameters, check_accuracy, check_r2_score):
@pytest.mark.parametrize("n_classes", [2, 4])
def test_decision_tree_hyperparameters(hyperparameters, n_classes, check_accuracy, check_r2_score):
"""Test that the hyperparameters are valid."""
x, y = make_classification(
n_samples=1000,
n_features=10,
n_informative=5,
n_classes=2,
n_classes=n_classes,
random_state=numpy.random.randint(0, 2**15),
)
model = DecisionTreeClassifier(
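Stacking the new n_classes parametrization on top of the existing hyperparameter one makes pytest run the cross-product of cases, so every hyperparameter set is now checked for both binary and 4-class targets. A minimal illustration of that stacking behavior:

```python
import pytest

# Two stacked parametrize decorators yield the cross-product:
# 2 x 2 = 4 cases: (1, 2), (1, 4), (2, 2), (2, 4).
@pytest.mark.parametrize("a", [1, 2])
@pytest.mark.parametrize("n_classes", [2, 4])
def test_cross_product(a, n_classes):
    assert a in (1, 2) and n_classes in (2, 4)
```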
