Merge pull request #5602 from rapidsai/branch-23.10
Forward-merge branch-23.10 to branch-23.12
GPUtester authored Oct 4, 2023
2 parents e3dde52 + 39dfc7e commit cf7516d
Showing 3 changed files with 91 additions and 22 deletions.
7 changes: 5 additions & 2 deletions cpp/src/glm/qn/mg/qn_mg.cuh
@@ -49,9 +49,12 @@ int qn_fit_mg(const raft::handle_t& handle,
   SimpleVec<T> w0(w0_data, loss.n_param);
 
   // Scale the regularization strength with the number of samples.
-  T l1 = 0;
+  T l1 = pams.penalty_l1;
   T l2 = pams.penalty_l2;
-  if (pams.penalty_normalized) { l2 /= n_samples; }
+  if (pams.penalty_normalized) {
+    l1 /= n_samples;
+    l2 /= n_samples;
+  }
 
   ML::GLM::detail::Tikhonov<T> reg(l2);
   ML::GLM::detail::RegularizedGLM<T, LossFunction, decltype(reg)> regularizer_obj(&loss, &reg);
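Net effect of this hunk: the L1 strength is no longer hard-coded to zero; it is read from the solver parameters and, when penalty_normalized is set, divided by the global sample count exactly like L2. A minimal Python sketch of that scaling rule (the function name scale_penalties is illustrative, not part of cuml; the field names mirror pams above):

    def scale_penalties(penalty_l1, penalty_l2, n_samples, penalty_normalized=True):
        # Mirror of the C++ hunk above: when the penalty is normalized,
        # divide both strengths by the global number of samples.
        l1, l2 = penalty_l1, penalty_l2
        if penalty_normalized:
            l1 /= n_samples
            l2 /= n_samples
        return l1, l2

    # e.g. scale_penalties(0.2, 0.8, 100) -> (0.002, 0.008)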
7 changes: 5 additions & 2 deletions python/cuml/linear_model/logistic_regression_mg.pyx
@@ -89,8 +89,6 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
 
     def __init__(self, **kwargs):
         super(LogisticRegressionMG, self).__init__(**kwargs)
-        if self.penalty != "l2" and self.penalty != "none":
-            assert False, "Currently only support 'l2' and 'none' penalty"
 
     @property
     @cuml.internals.api_base_return_array_skipall
@@ -210,9 +208,14 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
                 self._num_classes,
                 <float*> &objective32,
                 <int*> &num_iters)
+
+            self.solver_model.objective = objective32
+
+        else:
+            assert False, "dtypes other than float32 are currently not supported yet. See issue: https://github.com/rapidsai/cuml/issues/5589"
 
         self.solver_model.num_iters = num_iters
 
         self.solver_model._calc_intercept()
 
         self.handle.sync()
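With the penalty assertion removed, the estimator accepts 'l1' and 'elasticnet' in addition to 'l2' and 'none', and the new tests below pin down how penalty, C, and l1_ratio translate into the raw L1/L2 strengths handed to the QN solver. A sketch of that mapping (only the 'l1' and 'elasticnet' branches are asserted in this diff; the 'l2' and 'none' branches are an assumption by analogy with the usual sklearn-style convention):

    def expected_qn_strengths(penalty, C, l1_ratio=None):
        # Expected (l1_strength, l2_strength) per the assertions in
        # test_l1 and test_elasticnet below; strength is the inverse of C.
        strength = 1.0 / C
        if penalty == "none":
            return 0.0, 0.0              # assumed: no regularization
        if penalty == "l2":
            return 0.0, strength         # assumed: all mass on L2
        if penalty == "l1":
            return strength, 0.0         # asserted in test_l1
        if penalty == "elasticnet":
            # asserted in test_elasticnet: split by l1_ratio
            return l1_ratio * strength, (1.0 - l1_ratio) * strength
        raise ValueError(f"unexpected penalty {penalty}")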
99 changes: 81 additions & 18 deletions python/cuml/tests/dask/test_dask_logistic_regression.py
@@ -20,6 +20,7 @@
 from sklearn.datasets import make_classification
 from sklearn.linear_model import LogisticRegression as skLR
 from cuml.internals.safe_imports import cpu_only_import
+from cuml.testing.utils import array_equal
 
 pd = cpu_only_import("pandas")
 np = cpu_only_import("numpy")
@@ -281,6 +282,8 @@ def test_lbfgs(
     delayed,
     client,
     penalty="l2",
+    l1_ratio=None,
+    C=1.0,
     n_classes=2,
 ):
     tolerance = 0.005
@@ -304,20 +307,42 @@ def imp():
 
     X_df, y_df = _prep_training_data(client, X, y, n_parts)
 
-    lr = cumlLBFGS_dask(fit_intercept=fit_intercept, penalty=penalty)
+    lr = cumlLBFGS_dask(
+        solver="qn",
+        fit_intercept=fit_intercept,
+        penalty=penalty,
+        l1_ratio=l1_ratio,
+        C=C,
+        verbose=True,
+    )
     lr.fit(X_df, y_df)
     lr_coef = lr.coef_.to_numpy()
     lr_intercept = lr.intercept_.to_numpy()
 
-    sk_model = skLR(fit_intercept=fit_intercept, penalty=penalty)
+    if penalty == "l2" or penalty == "none":
+        sk_solver = "lbfgs"
+    elif penalty == "l1" or penalty == "elasticnet":
+        sk_solver = "saga"
+    else:
+        raise ValueError(f"unexpected penalty {penalty}")
+
+    sk_model = skLR(
+        solver=sk_solver,
+        fit_intercept=fit_intercept,
+        penalty=penalty,
+        l1_ratio=l1_ratio,
+        C=C,
+    )
     sk_model.fit(X, y)
     sk_coef = sk_model.coef_
     sk_intercept = sk_model.intercept_
 
-    assert len(lr_coef) == len(sk_coef)
-    for i in range(len(lr_coef)):
-        assert lr_coef[i] == pytest.approx(sk_coef[i], abs=tolerance)
-    assert lr_intercept == pytest.approx(sk_intercept, abs=tolerance)
+    if sk_solver == "lbfgs":
+        assert len(lr_coef) == len(sk_coef)
+        assert array_equal(lr_coef, sk_coef, tolerance, with_sign=True)
+        assert array_equal(
+            lr_intercept, sk_intercept, tolerance, with_sign=True
+        )
 
     # test predict
     cu_preds = lr.predict(X_df, delayed=delayed).compute().to_numpy()
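Since the hard assertion in logistic_regression_mg.pyx is gone, the dask estimator now accepts these penalties directly. A hedged usage sketch, not taken from this diff (X_dask and y_dask stand in for any distributed dataframe/array inputs already on a dask cluster):

    from cuml.dask.linear_model import LogisticRegression

    # Elastic-net penalty on the multi-GPU estimator; float32 inputs only,
    # per the dtype assertion added above. X_dask / y_dask are placeholders.
    lr = LogisticRegression(penalty="elasticnet", C=10.0, l1_ratio=0.5)
    lr.fit(X_dask, y_dask)
    preds = lr.predict(X_dask).compute()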
@@ -414,17 +439,55 @@ def test_n_classes(n_parts, fit_intercept, n_classes, client):
     assert lr._num_classes == n_classes
 
 
-@pytest.mark.parametrize("penalty", ["l1", "elasticnet"])
-@pytest.mark.parametrize("l1_ratio", [0.1])
-def test_l1_and_elasticnet(penalty, l1_ratio, client):
-    X = np.array([(1, 2), (1, 3), (2, 1), (3, 1)], np.float32)
-    y = np.array([1.0, 1.0, 0.0, 0.0], np.float32)
-    X_df, y_df = _prep_training_data(client, X, y, partitions_per_worker=1)
+@pytest.mark.mg
+@pytest.mark.parametrize("fit_intercept", [False, True])
+@pytest.mark.parametrize("datatype", [np.float32])
+@pytest.mark.parametrize("delayed", [True])
+@pytest.mark.parametrize("n_classes", [2, 8])
+@pytest.mark.parametrize("C", [1.0, 10.0])
+def test_l1(fit_intercept, datatype, delayed, n_classes, C, client):
+    lr = test_lbfgs(
+        nrows=1e5,
+        ncols=20,
+        n_parts=2,
+        fit_intercept=fit_intercept,
+        datatype=datatype,
+        delayed=delayed,
+        client=client,
+        penalty="l1",
+        n_classes=n_classes,
+        C=C,
+    )
+
+    l1_strength, l2_strength = lr._get_qn_params()
+    assert l1_strength == 1.0 / lr.C
+    assert l2_strength == 0.0
+
 
-    from cuml.dask.linear_model import LogisticRegression
+@pytest.mark.mg
+@pytest.mark.parametrize("fit_intercept", [False, True])
+@pytest.mark.parametrize("datatype", [np.float32])
+@pytest.mark.parametrize("delayed", [True])
+@pytest.mark.parametrize("n_classes", [2, 8])
+@pytest.mark.parametrize("l1_ratio", [0.2, 0.8])
+def test_elasticnet(
+    fit_intercept, datatype, delayed, n_classes, l1_ratio, client
+):
+    lr = test_lbfgs(
+        nrows=1e5,
+        ncols=20,
+        n_parts=2,
+        fit_intercept=fit_intercept,
+        datatype=datatype,
+        delayed=delayed,
+        client=client,
+        penalty="elasticnet",
+        n_classes=n_classes,
+        l1_ratio=l1_ratio,
+    )
 
-    lr = LogisticRegression(penalty=penalty, l1_ratio=l1_ratio)
-    with pytest.raises(
-        RuntimeError, match="Currently only support 'l2' and 'none' penalty"
-    ):
-        lr.fit(X_df, y_df)
+    l1_strength, l2_strength = lr._get_qn_params()
+
+    strength = 1.0 / lr.C
+    assert l1_strength == lr.l1_ratio * strength
+    assert l2_strength == (1.0 - lr.l1_ratio) * strength
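A worked instance of the elastic-net assertions above, e.g. for C=10.0 and l1_ratio=0.2:

    C, l1_ratio = 10.0, 0.2
    strength = 1.0 / C                         # 0.1
    l1_strength = l1_ratio * strength          # 0.02
    l2_strength = (1.0 - l1_ratio) * strength  # 0.08
    # On the C++ side both strengths are then divided by n_samples
    # when penalty_normalized is set.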
