diff --git a/ci/test_wheel.sh b/ci/test_wheel.sh
index bc21b9fd81..e8403559d9 100755
--- a/ci/test_wheel.sh
+++ b/ci/test_wheel.sh
@@ -13,7 +13,7 @@ if [[ "$(arch)" == "aarch64" ]]; then
 fi
 
 # Always install latest dask for testing
-python -m pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.12
+python -m pip install git+https://github.com/dask/dask.git@2023.9.2 git+https://github.com/dask/distributed.git@2023.9.2 git+https://github.com/rapidsai/dask-cuda.git@branch-23.12
 
 # echo to expand wildcard before adding `[extra]` requires for pip
 python -m pip install $(echo ./dist/cuml*.whl)[test]
diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml
index 69780f59d6..6038499d0c 100644
--- a/conda/environments/all_cuda-118_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-118_arch-x86_64.yaml
@@ -16,12 +16,12 @@ dependencies:
 - cupy>=12.0.0
 - cxx-compiler
 - cython>=3.0.0
-- dask-core>=2023.7.1
+- dask-core==2023.9.2
 - dask-cuda==23.12.*
 - dask-cudf==23.12.*
 - dask-ml
-- dask>=2023.7.1
-- distributed>=2023.7.1
+- dask==2023.9.2
+- distributed==2023.9.2
 - doxygen=1.9.1
 - gcc_linux-64=11.*
 - gmock>=1.13.0
diff --git a/conda/environments/all_cuda-120_arch-x86_64.yaml b/conda/environments/all_cuda-120_arch-x86_64.yaml
index ad93b46da8..2f81709f6e 100644
--- a/conda/environments/all_cuda-120_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-120_arch-x86_64.yaml
@@ -18,12 +18,12 @@ dependencies:
 - cupy>=12.0.0
 - cxx-compiler
 - cython>=3.0.0
-- dask-core>=2023.7.1
+- dask-core==2023.9.2
 - dask-cuda==23.12.*
 - dask-cudf==23.12.*
 - dask-ml
-- dask>=2023.7.1
-- distributed>=2023.7.1
+- dask==2023.9.2
+- distributed==2023.9.2
 - doxygen=1.9.1
 - gcc_linux-64=11.*
 - gmock>=1.13.0
diff --git a/conda/recipes/cuml/meta.yaml b/conda/recipes/cuml/meta.yaml
index 2f2864f176..817776fa13 100644
--- a/conda/recipes/cuml/meta.yaml
+++ b/conda/recipes/cuml/meta.yaml
@@ -76,9 +76,9 @@ requirements:
     - cudf ={{ minor_version }}
    - cupy >=12.0.0
     - dask-cudf ={{ minor_version }}
-    - dask >=2023.7.1
-    - dask-core>=2023.7.1
-    - distributed >=2023.7.1
+    - dask ==2023.9.2
+    - dask-core==2023.9.2
+    - distributed ==2023.9.2
     - joblib >=0.11
     - libcuml ={{ version }}
     - libcumlprims ={{ minor_version }}
diff --git a/cpp/include/cuml/linear_model/qn_mg.hpp b/cpp/include/cuml/linear_model/qn_mg.hpp
index 89a79f0677..f70fd833e9 100644
--- a/cpp/include/cuml/linear_model/qn_mg.hpp
+++ b/cpp/include/cuml/linear_model/qn_mg.hpp
@@ -21,12 +21,24 @@
 #include 
 #include 
+#include 
 
 using namespace MLCommon;
 
 namespace ML {
 namespace GLM {
 namespace opg {
 
+/**
+ * @brief Calculate unique class labels across multiple GPUs in a multi-node environment.
+ * @param[in] handle: the internal cuml handle object
+ * @param[in] input_desc: PartDescriptor object for the input
+ * @param[in] labels: labels data
+ * @returns host vector that stores the distinct labels
+ */
+std::vector getUniquelabelsMG(const raft::handle_t& handle,
+                              Matrix::PartDescriptor& input_desc,
+                              std::vector*>& labels);
+
 /**
  * @brief performs MNMG fit operation for the logistic regression using quasi newton methods
  * @param[in] handle: the internal cuml handle object
diff --git a/cpp/src/glm/qn/mg/qn_mg.cuh b/cpp/src/glm/qn/mg/qn_mg.cuh
index a9a1ff1bf5..425240f87d 100644
--- a/cpp/src/glm/qn/mg/qn_mg.cuh
+++ b/cpp/src/glm/qn/mg/qn_mg.cuh
@@ -103,6 +103,12 @@ inline void qn_fit_x_mg(const raft::handle_t& handle,
       ML::GLM::opg::qn_fit_mg(
         handle, pams, loss, X, y, Z, w0_data, f, num_iters, n_samples, rank, n_ranks);
     } break;
+    case QN_LOSS_SOFTMAX: {
+      ASSERT(C > 2, "qn_mg.cuh: softmax invalid C");
+      ML::GLM::detail::Softmax loss(handle, D, C, pams.fit_intercept);
+      ML::GLM::opg::qn_fit_mg(
+        handle, pams, loss, X, y, Z, w0_data, f, num_iters, n_samples, rank, n_ranks);
+    } break;
     default: {
       ASSERT(false, "qn_mg.cuh: unknown loss function type (id = %d).", pams.loss);
     }
diff --git a/cpp/src/glm/qn_mg.cu b/cpp/src/glm/qn_mg.cu
index eb809aa688..5a60c01f79 100644
--- a/cpp/src/glm/qn_mg.cu
+++ b/cpp/src/glm/qn_mg.cu
@@ -21,15 +21,59 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
+#include 
 #include 
+#include 
 
 using namespace MLCommon;
 
 namespace ML {
 namespace GLM {
 namespace opg {
+template 
+std::vector distinct_mg(const raft::handle_t& handle, T* y, size_t n)
+{
+  cudaStream_t stream = handle.get_stream();
+  raft::comms::comms_t const& comm = raft::resource::get_comms(handle);
+  int rank = comm.get_rank();
+  int n_ranks = comm.get_size();
+
+  rmm::device_uvector unique_y(0, stream);
+  raft::label::getUniquelabels(unique_y, y, n, stream);
+
+  rmm::device_uvector recv_counts(n_ranks, stream);
+  auto send_count = raft::make_device_scalar(handle, unique_y.size());
+  comm.allgather(send_count.data_handle(), recv_counts.data(), 1, stream);
+  comm.sync_stream(stream);
+
+  std::vector recv_counts_host(n_ranks);
+  raft::copy(recv_counts_host.data(), recv_counts.data(), n_ranks, stream);
+
+  std::vector displs(n_ranks);
+  size_t pos = 0;
+  for (int i = 0; i < n_ranks; ++i) {
+    displs[i] = pos;
+    pos += recv_counts_host[i];
+  }
+
+  rmm::device_uvector recv_buff(displs.back() + recv_counts_host.back(), stream);
+  comm.allgatherv(
+    unique_y.data(), recv_buff.data(), recv_counts_host.data(), displs.data(), stream);
+  comm.sync_stream(stream);
+
+  rmm::device_uvector global_unique_y(0, stream);
+  int n_distinct =
+    raft::label::getUniquelabels(global_unique_y, recv_buff.data(), recv_buff.size(), stream);
+
+  std::vector global_unique_y_host(global_unique_y.size());
+  raft::copy(global_unique_y_host.data(), global_unique_y.data(), global_unique_y.size(), stream);
+
+  return global_unique_y_host;
+}
+
 template 
 void qnFit_impl(const raft::handle_t& handle,
                 const qn_params& pams,
@@ -46,17 +90,6 @@ void qnFit_impl(const raft::handle_t& handle,
                 int rank,
                 int n_ranks)
 {
-  switch (pams.loss) {
-    case QN_LOSS_LOGISTIC: {
-      RAFT_EXPECTS(
-        C == 2,
-        "qn_mg.cu: only the LOGISTIC loss is supported currently. The number of classes must be 2");
-    } break;
-    default: {
-      RAFT_EXPECTS(false, "qn_mg.cu: unknown loss function type (id = %d).", pams.loss);
-    }
-  }
-
   auto X_simple = SimpleDenseMat(X, N, D, X_col_major ? COL_MAJOR : ROW_MAJOR);
 
   ML::GLM::opg::qn_fit_x_mg(handle,
@@ -113,6 +146,17 @@ void qnFit_impl(raft::handle_t& handle,
              input_desc.uniqueRanks().size());
 }
 
+std::vector getUniquelabelsMG(const raft::handle_t& handle,
+                              Matrix::PartDescriptor& input_desc,
+                              std::vector*>& labels)
+{
+  RAFT_EXPECTS(labels.size() == 1,
+               "getUniqueLabelsMG currently does not accept more than one data chunk");
+  Matrix::Data* data_y = labels[0];
+  int n_rows = input_desc.totalElementsOwnedBy(input_desc.rank);
+  return distinct_mg(handle, data_y->ptr, n_rows);
+}
+
 void qnFit(raft::handle_t& handle,
            std::vector*>& input_data,
            Matrix::PartDescriptor& input_desc,
diff --git a/dependencies.yaml b/dependencies.yaml
index fe06dbd847..86307617b4 100644
--- a/dependencies.yaml
+++ b/dependencies.yaml
@@ -175,10 +175,10 @@ dependencies:
       - output_types: [conda, requirements, pyproject]
         packages:
           - cudf==23.12.*
-          - dask>=2023.7.1
+          - dask==2023.9.2
           - dask-cuda==23.12.*
           - dask-cudf==23.12.*
-          - distributed>=2023.7.1
+          - distributed==2023.9.2
           - joblib>=0.11
           - numba>=0.57
           # TODO: Is scipy really a hard dependency, or should
@@ -192,7 +192,7 @@ dependencies:
           - cupy>=12.0.0
       - output_types: conda
         packages:
-          - dask-core>=2023.7.1
+          - dask-core==2023.9.2
       - output_types: pyproject
         packages:
           - *treelite_runtime
diff --git a/python/README.md b/python/README.md
index bcb2bf6d17..41a1d366cd 100644
--- a/python/README.md
+++ b/python/README.md
@@ -70,8 +70,8 @@ Packages required for multigpu algorithms*:
 - ucx-py version matching the cuML version
 - dask-cudf version matching the cuML version
 - nccl>=2.5
-- dask>=2023.7.1
-- distributed>=2023.7.1
+- dask==2023.9.2
+- distributed==2023.9.2
 
 * this can be avoided with `--singlegpu` argument flag.
 
diff --git a/python/cuml/dask/linear_model/logistic_regression.py b/python/cuml/dask/linear_model/logistic_regression.py
index d33388a654..38366a1b50 100644
--- a/python/cuml/dask/linear_model/logistic_regression.py
+++ b/python/cuml/dask/linear_model/logistic_regression.py
@@ -174,4 +174,11 @@ def _create_model(sessionId, datatype, **kwargs):
 def _func_fit(f, data, n_rows, n_cols, partsToSizes, rank):
     inp_X = concatenate([X for X, _ in data])
     inp_y = concatenate([y for _, y in data])
-    return f.fit([(inp_X, inp_y)], n_rows, n_cols, partsToSizes, rank)
+    n_ranks = max([p[0] for p in partsToSizes]) + 1
+    aggregated_partsToSizes = [[i, 0] for i in range(n_ranks)]
+    for p in partsToSizes:
+        aggregated_partsToSizes[p[0]][1] += p[1]
+
+    return f.fit(
+        [(inp_X, inp_y)], n_rows, n_cols, aggregated_partsToSizes, rank
+    )
diff --git a/python/cuml/linear_model/logistic_regression_mg.pyx b/python/cuml/linear_model/logistic_regression_mg.pyx
index fce58a4c59..eecea81b04 100644
--- a/python/cuml/linear_model/logistic_regression_mg.pyx
+++ b/python/cuml/linear_model/logistic_regression_mg.pyx
@@ -79,11 +79,18 @@ cdef extern from "cuml/linear_model/qn_mg.hpp" namespace "ML::GLM::opg" nogil:
         float *f,
         int *num_iters) except +
 
+    cdef vector[float] getUniquelabelsMG(
+        const handle_t& handle,
+        PartDescriptor &input_desc,
+        vector[floatData_t*] labels) except+
+
 
 class LogisticRegressionMG(MGFitMixin, LogisticRegression):
 
     def __init__(self, **kwargs):
         super(LogisticRegressionMG, self).__init__(**kwargs)
+        if self.penalty != "l2" and self.penalty != "none":
+            assert False, "Currently only support 'l2' and 'none' penalty"
 
     @property
     @cuml.internals.api_base_return_array_skipall
@@ -102,8 +109,8 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
 
         self.solver_model.coef_ = value
 
-    def prepare_for_fit(self, n_classes):
-        self.solver_model.qnparams = QNParams(
+    def create_qnparams(self):
+        return QNParams(
             loss=self.loss,
             penalty_l1=self.l1_strength,
             penalty_l2=self.l2_strength,
@@ -118,8 +125,11 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
             penalty_normalized=self.penalty_normalized
         )
 
+    def prepare_for_fit(self, n_classes):
+        self.solver_model.qnparams = self.create_qnparams()
+
         # modified
-        qnpams = self.qnparams.params
+        qnpams = self.solver_model.qnparams.params
 
         # modified qnp
         solves_classification = qnpams['loss'] in {
@@ -174,8 +184,14 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
         cdef float objective32
         cdef int num_iters
 
-        # TODO: calculate _num_classes at runtime
-        self._num_classes = 2
+        cdef vector[float] c_classes_
+        c_classes_ = getUniquelabelsMG(
+            handle_[0],
+            deref(input_desc),
+            deref(y))
+        self.classes_ = np.sort(list(c_classes_)).astype('float32')
+
+        self._num_classes = len(self.classes_)
         self.loss = "sigmoid" if self._num_classes <= 2 else "softmax"
         self.prepare_for_fit(self._num_classes)
         cdef uintptr_t mat_coef_ptr = self.coef_.ptr
@@ -194,6 +210,8 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
                 self._num_classes,
                 &objective32,
                 &num_iters)
+        else:
+            assert False, "dtypes other than float32 are currently not supported yet. See issue: https://github.com/rapidsai/cuml/issues/5589"
 
         self.solver_model._calc_intercept()
 
diff --git a/python/cuml/tests/dask/test_dask_logistic_regression.py b/python/cuml/tests/dask/test_dask_logistic_regression.py
index 4a1bca911d..1df477d9c1 100644
--- a/python/cuml/tests/dask/test_dask_logistic_regression.py
+++ b/python/cuml/tests/dask/test_dask_logistic_regression.py
@@ -47,9 +47,13 @@ def _prep_training_data(c, X_train, y_train, partitions_per_worker):
     return X_train_df, y_train_df
 
 
-def make_classification_dataset(datatype, nrows, ncols, n_info):
+def make_classification_dataset(datatype, nrows, ncols, n_info, n_classes=2):
     X, y = make_classification(
-        n_samples=nrows, n_features=ncols, n_informative=n_info, random_state=0
+        n_samples=nrows,
+        n_features=ncols,
+        n_informative=n_info,
+        n_classes=n_classes,
+        random_state=0,
     )
     X = X.astype(datatype)
     y = y.astype(datatype)
@@ -176,6 +180,16 @@ def imp():
 
     assert_array_equal(preds, y, strict=True)
 
+    # assert error on float64
+    X = X.astype(np.float64)
+    y = y.astype(np.float64)
+    X_df, y_df = _prep_training_data(client, X, y, n_parts)
+    with pytest.raises(
+        RuntimeError,
+        match="dtypes other than float32 are currently not supported yet. See issue: https://github.com/rapidsai/cuml/issues/5589",
+    ):
+        lr.fit(X_df, y_df)
+
 
 def test_lbfgs_init(client):
     def imp():
@@ -267,6 +281,7 @@ def test_lbfgs(
     delayed,
     client,
     penalty="l2",
+    n_classes=2,
 ):
     tolerance = 0.005
 
@@ -283,7 +298,9 @@ def imp():
     n_info = 5
     nrows = int(nrows)
     ncols = int(ncols)
-    X, y = make_classification_dataset(datatype, nrows, ncols, n_info)
+    X, y = make_classification_dataset(
+        datatype, nrows, ncols, n_info, n_classes=n_classes
+    )
 
     X_df, y_df = _prep_training_data(client, X, y, n_parts)
 
@@ -303,12 +320,13 @@ def imp():
     assert lr_intercept == pytest.approx(sk_intercept, abs=tolerance)
 
     # test predict
-    cu_preds = lr.predict(X_df, delayed=delayed)
-    accuracy_cuml = accuracy_score(y, cu_preds.compute().to_numpy())
+    cu_preds = lr.predict(X_df, delayed=delayed).compute().to_numpy()
+    accuracy_cuml = accuracy_score(y, cu_preds)
 
     sk_preds = sk_model.predict(X)
     accuracy_sk = accuracy_score(y, sk_preds)
 
+    assert len(cu_preds) == len(sk_preds)
     assert (accuracy_cuml >= accuracy_sk) | (
         np.abs(accuracy_cuml - accuracy_sk) < 1e-3
     )
@@ -336,3 +354,77 @@ def test_noreg(fit_intercept, client):
     l1_strength, l2_strength = lr._get_qn_params()
     assert l1_strength == 0.0
     assert l2_strength == 0.0
+
+
+def test_n_classes_small(client):
+    def assert_small(X, y, n_classes):
+        X_df, y_df = _prep_training_data(client, X, y, partitions_per_worker=1)
+        from cuml.dask.linear_model import LogisticRegression as cumlLBFGS_dask
+
+        lr = cumlLBFGS_dask()
+        lr.fit(X_df, y_df)
+        assert lr._num_classes == n_classes
+        return lr
+
+    X = np.array([(1, 2), (1, 3)], np.float32)
+    y = np.array([1.0, 0.0], np.float32)
+    lr = assert_small(X=X, y=y, n_classes=2)
+    assert np.array_equal(
+        lr.classes_.to_numpy(), np.array([0.0, 1.0], np.float32)
+    )
+
+    X = np.array([(1, 2), (1, 3), (1, 2.5)], np.float32)
+    y = np.array([1.0, 0.0, 1.0], np.float32)
+    lr = assert_small(X=X, y=y, n_classes=2)
+    assert np.array_equal(
+        lr.classes_.to_numpy(), np.array([0.0, 1.0], np.float32)
+    )
+
+    X = np.array([(1, 2), (1, 2.5), (1, 3)], np.float32)
+    y = np.array([1.0, 1.0, 0.0], np.float32)
+    lr = assert_small(X=X, y=y, n_classes=2)
+    assert np.array_equal(
+        lr.classes_.to_numpy(), np.array([0.0, 1.0], np.float32)
+    )
+
+    X = np.array([(1, 2), (1, 3), (1, 2.5)], np.float32)
+    y = np.array([10.0, 50.0, 20.0], np.float32)
+    lr = assert_small(X=X, y=y, n_classes=3)
+    assert np.array_equal(
+        lr.classes_.to_numpy(), np.array([10.0, 20.0, 50.0], np.float32)
+    )
+
+
+@pytest.mark.parametrize("n_parts", [2, 23])
+@pytest.mark.parametrize("fit_intercept", [False, True])
+@pytest.mark.parametrize("n_classes", [8])
+def test_n_classes(n_parts, fit_intercept, n_classes, client):
+    lr = test_lbfgs(
+        nrows=1e5,
+        ncols=20,
+        n_parts=n_parts,
+        fit_intercept=fit_intercept,
+        datatype=np.float32,
+        delayed=True,
+        client=client,
+        penalty="l2",
+        n_classes=n_classes,
+    )
+
+    assert lr._num_classes == n_classes
+
+
+@pytest.mark.parametrize("penalty", ["l1", "elasticnet"])
+@pytest.mark.parametrize("l1_ratio", [0.1])
+def test_l1_and_elasticnet(penalty, l1_ratio, client):
+    X = np.array([(1, 2), (1, 3), (2, 1), (3, 1)], np.float32)
+    y = np.array([1.0, 1.0, 0.0, 0.0], np.float32)
+    X_df, y_df = _prep_training_data(client, X, y, partitions_per_worker=1)
+
+    from cuml.dask.linear_model import LogisticRegression
+
+    lr = LogisticRegression(penalty=penalty, l1_ratio=l1_ratio)
+    with pytest.raises(
+        RuntimeError, match="Currently only support 'l2' and 'none' penalty"
+    ):
+        lr.fit(X_df, y_df)
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 418c3b18fb..32f1b7a59e 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -61,8 +61,8 @@ dependencies = [
     "cupy-cuda11x>=12.0.0",
     "dask-cuda==23.12.*",
     "dask-cudf==23.12.*",
-    "dask>=2023.7.1",
-    "distributed>=2023.7.1",
+    "dask==2023.9.2",
+    "distributed==2023.9.2",
     "joblib>=0.11",
    "numba>=0.57",
     "raft-dask==23.12.*",
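
Illustrative usage note (not part of the change set): the diff above teaches the dask LogisticRegression to discover the class set across workers (getUniquelabelsMG) and to train with a softmax loss when more than two classes are present. The sketch below shows what a multi-class fit might look like from the user side; the cluster setup and the dask_cudf data preparation are assumptions made for the example, not code from this PR.

# Hypothetical end-to-end sketch: fit cuml.dask LogisticRegression on 3 classes.
import cudf
import dask_cudf
import numpy as np
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

from cuml.dask.linear_model import LogisticRegression

if __name__ == "__main__":
    cluster = LocalCUDACluster(n_workers=2)  # assumes at least 2 GPUs
    client = Client(cluster)

    # float32 features and labels, matching the only dtype supported by this change
    X = np.random.rand(1000, 20).astype(np.float32)
    y = np.random.randint(0, 3, size=1000).astype(np.float32)
    X_df = dask_cudf.from_cudf(cudf.DataFrame(X), npartitions=2)
    y_df = dask_cudf.from_cudf(cudf.Series(y), npartitions=2)

    lr = LogisticRegression(penalty="l2", fit_intercept=True)
    lr.fit(X_df, y_df)

    print(lr._num_classes)  # expected: 3, deduplicated across workers
    preds = lr.predict(X_df, delayed=True).compute()

    client.close()
    cluster.close()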
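
The distinct_mg helper added in qn_mg.cu computes the global label set in three steps: deduplicate labels locally on each rank, exchange the per-rank unique sets with allgatherv (hence the recv_counts/displs bookkeeping), then deduplicate the gathered buffer. A minimal single-process NumPy analogue of that idea, with illustrative names only:

# Sketch of the label-deduplication scheme, expressed without raft/NCCL primitives.
import numpy as np

def distinct_across_ranks(per_rank_labels):
    # Step 1: local unique labels on every rank.
    local_unique = [np.unique(y) for y in per_rank_labels]
    # Step 2: emulate allgatherv -- concatenate the variable-sized unique sets
    # from all ranks (the CUDA version tracks per-rank sizes and displacements).
    gathered = np.concatenate(local_unique)
    # Step 3: global unique over the gathered buffer.
    return np.unique(gathered)

# Example: two workers holding different label subsets.
print(distinct_across_ranks([np.array([1.0, 0.0, 1.0]), np.array([2.0, 0.0])]))
# -> [0. 1. 2.]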