diff --git a/cpp/src/glm/qn/mg/qn_mg.cuh b/cpp/src/glm/qn/mg/qn_mg.cuh
index 425240f87d..ef9c1db6c2 100644
--- a/cpp/src/glm/qn/mg/qn_mg.cuh
+++ b/cpp/src/glm/qn/mg/qn_mg.cuh
@@ -49,9 +49,12 @@ int qn_fit_mg(const raft::handle_t& handle,
   SimpleVec<T> w0(w0_data, loss.n_param);
 
   // Scale the regularization strength with the number of samples.
-  T l1 = 0;
+  T l1 = pams.penalty_l1;
   T l2 = pams.penalty_l2;
-  if (pams.penalty_normalized) { l2 /= n_samples; }
+  if (pams.penalty_normalized) {
+    l1 /= n_samples;
+    l2 /= n_samples;
+  }
 
   ML::GLM::detail::Tikhonov<T> reg(l2);
   ML::GLM::detail::RegularizedGLM<T, LossFunction, decltype(reg)> regularizer_obj(&loss, &reg);
diff --git a/python/cuml/linear_model/logistic_regression_mg.pyx b/python/cuml/linear_model/logistic_regression_mg.pyx
index eecea81b04..3330541b32 100644
--- a/python/cuml/linear_model/logistic_regression_mg.pyx
+++ b/python/cuml/linear_model/logistic_regression_mg.pyx
@@ -89,8 +89,6 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
 
     def __init__(self, **kwargs):
         super(LogisticRegressionMG, self).__init__(**kwargs)
-        if self.penalty != "l2" and self.penalty != "none":
-            assert False, "Currently only support 'l2' and 'none' penalty"
 
     @property
     @cuml.internals.api_base_return_array_skipall
@@ -210,9 +208,14 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
                 self._num_classes,
                 &objective32,
                 &num_iters)
+
+            self.solver_model.objective = objective32
+
         else:
             assert False, "dtypes other than float32 are currently not supported yet. See issue: https://github.com/rapidsai/cuml/issues/5589"
 
+        self.solver_model.num_iters = num_iters
+
         self.solver_model._calc_intercept()
 
         self.handle.sync()
diff --git a/python/cuml/tests/dask/test_dask_logistic_regression.py b/python/cuml/tests/dask/test_dask_logistic_regression.py
index 1df477d9c1..4f0cd7408b 100644
--- a/python/cuml/tests/dask/test_dask_logistic_regression.py
+++ b/python/cuml/tests/dask/test_dask_logistic_regression.py
@@ -20,6 +20,7 @@
 from sklearn.datasets import make_classification
 from sklearn.linear_model import LogisticRegression as skLR
 from cuml.internals.safe_imports import cpu_only_import
+from cuml.testing.utils import array_equal
 
 pd = cpu_only_import("pandas")
 np = cpu_only_import("numpy")
@@ -281,6 +282,8 @@ def test_lbfgs(
     delayed,
     client,
     penalty="l2",
+    l1_ratio=None,
+    C=1.0,
     n_classes=2,
 ):
     tolerance = 0.005
@@ -304,20 +307,42 @@ def imp():
 
     X_df, y_df = _prep_training_data(client, X, y, n_parts)
 
-    lr = cumlLBFGS_dask(fit_intercept=fit_intercept, penalty=penalty)
+    lr = cumlLBFGS_dask(
+        solver="qn",
+        fit_intercept=fit_intercept,
+        penalty=penalty,
+        l1_ratio=l1_ratio,
+        C=C,
+        verbose=True,
+    )
     lr.fit(X_df, y_df)
     lr_coef = lr.coef_.to_numpy()
     lr_intercept = lr.intercept_.to_numpy()
 
-    sk_model = skLR(fit_intercept=fit_intercept, penalty=penalty)
+    if penalty == "l2" or penalty == "none":
+        sk_solver = "lbfgs"
+    elif penalty == "l1" or penalty == "elasticnet":
+        sk_solver = "saga"
+    else:
+        raise ValueError(f"unexpected penalty {penalty}")
+
+    sk_model = skLR(
+        solver=sk_solver,
+        fit_intercept=fit_intercept,
+        penalty=penalty,
+        l1_ratio=l1_ratio,
+        C=C,
+    )
     sk_model.fit(X, y)
     sk_coef = sk_model.coef_
     sk_intercept = sk_model.intercept_
 
-    assert len(lr_coef) == len(sk_coef)
-    for i in range(len(lr_coef)):
-        assert lr_coef[i] == pytest.approx(sk_coef[i], abs=tolerance)
-    assert lr_intercept == pytest.approx(sk_intercept, abs=tolerance)
+    if sk_solver == "lbfgs":
+        assert len(lr_coef) == len(sk_coef)
+        assert array_equal(lr_coef, sk_coef, tolerance, with_sign=True)
+        assert array_equal(
+            lr_intercept, sk_intercept, tolerance, with_sign=True
+        )
 
     # test predict
     cu_preds = lr.predict(X_df, delayed=delayed).compute().to_numpy()
@@ -414,17 +439,55 @@ def test_n_classes(n_parts, fit_intercept, n_classes, client):
     assert lr._num_classes == n_classes
 
 
-@pytest.mark.parametrize("penalty", ["l1", "elasticnet"])
-@pytest.mark.parametrize("l1_ratio", [0.1])
-def test_l1_and_elasticnet(penalty, l1_ratio, client):
-    X = np.array([(1, 2), (1, 3), (2, 1), (3, 1)], np.float32)
-    y = np.array([1.0, 1.0, 0.0, 0.0], np.float32)
-    X_df, y_df = _prep_training_data(client, X, y, partitions_per_worker=1)
+@pytest.mark.mg
+@pytest.mark.parametrize("fit_intercept", [False, True])
+@pytest.mark.parametrize("datatype", [np.float32])
+@pytest.mark.parametrize("delayed", [True])
+@pytest.mark.parametrize("n_classes", [2, 8])
+@pytest.mark.parametrize("C", [1.0, 10.0])
+def test_l1(fit_intercept, datatype, delayed, n_classes, C, client):
+    lr = test_lbfgs(
+        nrows=1e5,
+        ncols=20,
+        n_parts=2,
+        fit_intercept=fit_intercept,
+        datatype=datatype,
+        delayed=delayed,
+        client=client,
+        penalty="l1",
+        n_classes=n_classes,
+        C=C,
+    )
+
+    l1_strength, l2_strength = lr._get_qn_params()
+    assert l1_strength == 1.0 / lr.C
+    assert l2_strength == 0.0
+
 
-    from cuml.dask.linear_model import LogisticRegression
+@pytest.mark.mg
+@pytest.mark.parametrize("fit_intercept", [False, True])
+@pytest.mark.parametrize("datatype", [np.float32])
+@pytest.mark.parametrize("delayed", [True])
+@pytest.mark.parametrize("n_classes", [2, 8])
+@pytest.mark.parametrize("l1_ratio", [0.2, 0.8])
+def test_elasticnet(
+    fit_intercept, datatype, delayed, n_classes, l1_ratio, client
+):
+    lr = test_lbfgs(
+        nrows=1e5,
+        ncols=20,
+        n_parts=2,
+        fit_intercept=fit_intercept,
+        datatype=datatype,
+        delayed=delayed,
+        client=client,
+        penalty="elasticnet",
+        n_classes=n_classes,
+        l1_ratio=l1_ratio,
+    )
 
-    lr = LogisticRegression(penalty=penalty, l1_ratio=l1_ratio)
-    with pytest.raises(
-        RuntimeError, match="Currently only support 'l2' and 'none' penalty"
-    ):
-        lr.fit(X_df, y_df)
+    l1_strength, l2_strength = lr._get_qn_params()
+
+    strength = 1.0 / lr.C
+    assert l1_strength == lr.l1_ratio * strength
+    assert l2_strength == (1.0 - lr.l1_ratio) * strength
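
Reviewer note (not part of the patch): the sketch below reproduces the (penalty, C, l1_ratio) -> (l1_strength, l2_strength) mapping that the new test_l1 / test_elasticnet assertions expect from lr._get_qn_params(), plus the per-sample normalization that qn_fit_mg now applies to both terms. map_penalty_to_strengths and the n_samples value are illustrative assumptions written for this note, not cuml API.

    # Illustrative sketch only; mirrors the assertions in the tests above.
    def map_penalty_to_strengths(penalty, C, l1_ratio=None):
        strength = 1.0 / C
        if penalty == "l1":
            return strength, 0.0
        if penalty == "l2":
            return 0.0, strength
        if penalty == "elasticnet":
            # Split the total strength between L1 and L2 by l1_ratio.
            return l1_ratio * strength, (1.0 - l1_ratio) * strength
        return 0.0, 0.0  # penalty == "none"

    # e.g. the split test_elasticnet expects for C=1.0, l1_ratio=0.2:
    l1_strength, l2_strength = map_penalty_to_strengths(
        "elasticnet", C=1.0, l1_ratio=0.2
    )
    assert abs(l1_strength - 0.2) < 1e-12 and abs(l2_strength - 0.8) < 1e-12

    # With pams.penalty_normalized set, qn_fit_mg now divides *both*
    # strengths by the sample count (previously only l2 was scaled):
    n_samples = 100_000  # illustrative; the global count across ranks
    l1, l2 = l1_strength / n_samples, l2_strength / n_samples

Note that in the C++ hunk only l2 feeds the smooth Tikhonov regularizer; the non-smooth l1 term is handled separately by cuml's quasi-Newton solver (the OWL-QN path), which is why L1/elastic-net support reduces to wiring pams.penalty_l1 through and normalizing it consistently with l2.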