Skip to content

Commit

Permalink
Resolve PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Nov 1, 2024
1 parent 2f317b8 commit 428ead4
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 179 deletions.
198 changes: 91 additions & 107 deletions cpp/include/raft/sparse/solver/detail/lanczos.cuh

Large diffs are not rendered by default.

70 changes: 40 additions & 30 deletions cpp/include/raft/sparse/solver/lanczos.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,63 +28,73 @@ namespace raft::sparse::solver {
// Eigensolver
// =========================================================

/**
* @brief Find the smallest eigenpairs using lanczos solver
* @tparam index_type_t the type of data used for indexing.
* @tparam value_type_t the type of data used for weights, distances.
* @param handle the raft handle.
* @param A Sparse matrix in CSR format.
* @param config lanczos config used to set hyperparameters
* @param v0 Initial lanczos vector
* @param eigenvalues output eigenvalues
* @param eigenvectors output eigenvectors
* @return Zero if successful. Otherwise non-zero.
*/
template <typename IndexTypeT, typename ValueTypeT>
auto lanczos_compute_smallest_eigenvectors(
raft::resources const& handle,
raft::device_vector_view<IndexTypeT, uint32_t, raft::row_major> rows,
raft::device_vector_view<IndexTypeT, uint32_t, raft::row_major> cols,
raft::device_vector_view<ValueTypeT, uint32_t, raft::row_major> vals,
lanczos_solver_config<IndexTypeT, ValueTypeT> const& config,
raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT> A,
lanczos_solver_config<ValueTypeT> const& config,
raft::device_vector_view<ValueTypeT, uint32_t, raft::row_major> v0,
raft::device_vector_view<ValueTypeT, uint32_t, raft::col_major> eigenvalues,
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors) -> int
{
IndexTypeT ncols = rows.extent(0) - 1;
IndexTypeT nrows = rows.extent(0) - 1;
IndexTypeT nnz = cols.extent(0);

auto csr_structure =
raft::make_device_compressed_structure_view<IndexTypeT, IndexTypeT, IndexTypeT>(
const_cast<IndexTypeT*>(rows.data_handle()),
const_cast<IndexTypeT*>(cols.data_handle()),
ncols,
nrows,
nnz);

auto csr_matrix =
raft::make_device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT>(
const_cast<ValueTypeT*>(vals.data_handle()), csr_structure);

return detail::lanczos_compute_smallest_eigenvectors<IndexTypeT, ValueTypeT>(
handle, csr_matrix, config, v0, eigenvalues, eigenvectors);
handle, A, config, v0, eigenvalues, eigenvectors);
}

/**
* @brief Find the smallest eigenpairs using lanczos solver
* @tparam index_type_t the type of data used for indexing.
* @tparam value_type_t the type of data used for weights, distances.
* @param handle the raft handle.
* @param rows Vector view of the rows of the sparse matrix.
* @param cols Vector view of the cols of the sparse matrix.
* @param vals Vector view of the vals of the sparse matrix.
* @param config lanczos config used to set hyperparameters
* @param v0 Initial lanczos vector
* @param eigenvalues output eigenvalues
* @param eigenvectors output eigenvectors
* @return Zero if successful. Otherwise non-zero.
*/
template <typename IndexTypeT, typename ValueTypeT>
auto lanczos_compute_smallest_eigenvectors(
raft::resources const& handle,
raft::spectral::matrix::sparse_matrix_t<IndexTypeT, ValueTypeT> const& A,
lanczos_solver_config<IndexTypeT, ValueTypeT> const& config,
raft::device_vector_view<IndexTypeT, uint32_t, raft::row_major> rows,
raft::device_vector_view<IndexTypeT, uint32_t, raft::row_major> cols,
raft::device_vector_view<ValueTypeT, uint32_t, raft::row_major> vals,
lanczos_solver_config<ValueTypeT> const& config,
raft::device_vector_view<ValueTypeT, uint32_t, raft::row_major> v0,
raft::device_vector_view<ValueTypeT, uint32_t, raft::col_major> eigenvalues,
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors) -> int
{
IndexTypeT ncols = A.ncols_;
IndexTypeT nrows = A.nrows_;
IndexTypeT nnz = A.nnz_;
IndexTypeT ncols = rows.extent(0) - 1;
IndexTypeT nrows = rows.extent(0) - 1;
IndexTypeT nnz = cols.extent(0);

auto csr_structure =
raft::make_device_compressed_structure_view<IndexTypeT, IndexTypeT, IndexTypeT>(
const_cast<IndexTypeT*>(A.row_offsets_),
const_cast<IndexTypeT*>(A.col_indices_),
const_cast<IndexTypeT*>(rows.data_handle()),
const_cast<IndexTypeT*>(cols.data_handle()),
ncols,
nrows,
nnz);

auto csr_matrix =
raft::make_device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT>(
const_cast<ValueTypeT*>(A.values_), csr_structure);
const_cast<ValueTypeT*>(vals.data_handle()), csr_structure);

return detail::lanczos_compute_smallest_eigenvectors<IndexTypeT, ValueTypeT>(
return lanczos_compute_smallest_eigenvectors<IndexTypeT, ValueTypeT>(
handle, csr_matrix, config, v0, eigenvalues, eigenvectors);
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/sparse/solver/lanczos_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

namespace raft::sparse::solver {

template <typename IndexTypeT, typename ValueTypeT>
template <typename ValueTypeT>
struct lanczos_solver_config {
int n_components;
int max_iterations;
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft_runtime/solver/lanczos.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace raft::runtime::solver {
raft::device_vector_view<IndexType, uint32_t, raft::row_major> rows, \
raft::device_vector_view<IndexType, uint32_t, raft::row_major> cols, \
raft::device_vector_view<ValueType, uint32_t, raft::row_major> vals, \
raft::sparse::solver::lanczos_solver_config<IndexType, ValueType> config, \
raft::sparse::solver::lanczos_solver_config<ValueType> config, \
raft::device_vector_view<ValueType, uint32_t, raft::row_major> v0, \
raft::device_vector_view<ValueType, uint32_t, raft::col_major> eigenvalues, \
raft::device_matrix_view<ValueType, uint32_t, raft::col_major> eigenvectors)
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/raft_runtime/solver/lanczos_solver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
raft::device_vector_view<IndexType, uint32_t, raft::row_major> rows, \
raft::device_vector_view<IndexType, uint32_t, raft::row_major> cols, \
raft::device_vector_view<ValueType, uint32_t, raft::row_major> vals, \
raft::sparse::solver::lanczos_solver_config<IndexType, ValueType> config, \
raft::sparse::solver::lanczos_solver_config<ValueType> config, \
raft::device_vector_view<ValueType, uint32_t, raft::row_major> v0, \
raft::device_vector_view<ValueType, uint32_t, raft::col_major> eigenvalues, \
raft::device_matrix_view<ValueType, uint32_t, raft::col_major> eigenvectors) \
Expand Down
33 changes: 27 additions & 6 deletions cpp/test/sparse/solver/lanczos.cu
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,23 @@ class rmat_lanczos_tests
symmetric_coo.vals(),
symmetric_coo.n_rows,
symmetric_coo.nnz};
raft::sparse::solver::lanczos_solver_config<IndexType, ValueType> config{
raft::sparse::solver::lanczos_solver_config<ValueType> config{
n_components, params.maxiter, params.restartiter, params.tol, rng.seed};

auto csr_structure =
raft::make_device_compressed_structure_view<IndexType, IndexType, IndexType>(
const_cast<IndexType*>(row_indices.data_handle()),
const_cast<IndexType*>(symmetric_coo.cols()),
symmetric_coo.n_rows,
symmetric_coo.n_rows,
symmetric_coo.nnz);

auto csr_matrix = raft::make_device_csr_matrix_view<ValueType, IndexType, IndexType, IndexType>(
const_cast<ValueType*>(symmetric_coo.vals()), csr_structure);

std::get<0>(stats) =
raft::sparse::solver::lanczos_compute_smallest_eigenvectors<IndexType, ValueType>(
handle, csr_m, config, v0.view(), eigenvalues.view(), eigenvectors.view());
handle, csr_matrix, config, v0.view(), eigenvalues.view(), eigenvectors.view());

ASSERT_TRUE(raft::devArrMatch<ValueType>(eigenvalues.data_handle(),
expected_eigenvalues.data_handle(),
Expand Down Expand Up @@ -251,13 +263,22 @@ class lanczos_tests : public ::testing::TestWithParam<lanczos_inputs<IndexType,
raft::random::uniform<ValueType>(handle, rng, v0.view(), 0, 1);
std::tuple<IndexType, ValueType, IndexType> stats;

raft::spectral::matrix::sparse_matrix_t<IndexType, ValueType> const csr_m{
handle, rows.data_handle(), cols.data_handle(), vals.data_handle(), n, nnz};
raft::sparse::solver::lanczos_solver_config<IndexType, ValueType> config{
raft::sparse::solver::lanczos_solver_config<ValueType> config{
params.n_components, params.maxiter, params.restartiter, params.tol, rng.seed};
auto csr_structure =
raft::make_device_compressed_structure_view<IndexType, IndexType, IndexType>(
const_cast<IndexType*>(rows.data_handle()),
const_cast<IndexType*>(cols.data_handle()),
n,
n,
nnz);

auto csr_matrix = raft::make_device_csr_matrix_view<ValueType, IndexType, IndexType, IndexType>(
const_cast<ValueType*>(vals.data_handle()), csr_structure);

std::get<0>(stats) =
raft::sparse::solver::lanczos_compute_smallest_eigenvectors<IndexType, ValueType>(
handle, csr_m, config, v0.view(), eigenvalues.view(), eigenvectors.view());
handle, csr_matrix, config, v0.view(), eigenvalues.view(), eigenvectors.view());

ASSERT_TRUE(raft::devArrMatch<ValueType>(eigenvalues.data_handle(),
expected_eigenvalues.data_handle(),
Expand Down
64 changes: 31 additions & 33 deletions python/pylibraft/pylibraft/solver/lanczos.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,15 @@ from pylibraft.random.cpp.rng_state cimport RngState
cdef extern from "raft/sparse/solver/lanczos_types.hpp" \
namespace "raft::sparse::solver" nogil:

cdef cppclass lanczos_solver_config[IndexTypeT, ValueTypeT]:
cdef cppclass lanczos_solver_config[ValueTypeT]:
int n_components
int max_iterations
int ncv
ValueTypeT tolerance
uint64_t seed

cdef lanczos_solver_config[int, float] config_int_float
cdef lanczos_solver_config[int64_t, float] config_int64_float
cdef lanczos_solver_config[int, double] config_int_double
cdef lanczos_solver_config[int64_t, double] config_int64_double
cdef lanczos_solver_config[float] config_float
cdef lanczos_solver_config[double] config_double

cdef extern from "raft_runtime/solver/lanczos.hpp" \
namespace "raft::runtime::solver" nogil:
Expand All @@ -64,7 +62,7 @@ cdef extern from "raft_runtime/solver/lanczos.hpp" \
device_vector_view[int64_t, uint32_t] rows,
device_vector_view[int64_t, uint32_t] cols,
device_vector_view[double, uint32_t] vals,
lanczos_solver_config[int64_t, double] config,
lanczos_solver_config[double] config,
device_vector_view[double, uint32_t] v0,
device_vector_view[double, uint32_t] eigenvalues,
device_matrix_view[double, uint32_t, col_major] eigenvectors) except +
Expand All @@ -74,7 +72,7 @@ cdef extern from "raft_runtime/solver/lanczos.hpp" \
device_vector_view[int64_t, uint32_t] rows,
device_vector_view[int64_t, uint32_t] cols,
device_vector_view[float, uint32_t] vals,
lanczos_solver_config[int64_t, float] config,
lanczos_solver_config[float] config,
device_vector_view[float, uint32_t] v0,
device_vector_view[float, uint32_t] eigenvalues,
device_matrix_view[float, uint32_t, col_major] eigenvectors) except +
Expand All @@ -84,7 +82,7 @@ cdef extern from "raft_runtime/solver/lanczos.hpp" \
device_vector_view[int, uint32_t] rows,
device_vector_view[int, uint32_t] cols,
device_vector_view[double, uint32_t] vals,
lanczos_solver_config[int, double] config,
lanczos_solver_config[double] config,
device_vector_view[double, uint32_t] v0,
device_vector_view[double, uint32_t] eigenvalues,
device_matrix_view[double, uint32_t, col_major] eigenvectors) except +
Expand All @@ -94,7 +92,7 @@ cdef extern from "raft_runtime/solver/lanczos.hpp" \
device_vector_view[int, uint32_t] rows,
device_vector_view[int, uint32_t] cols,
device_vector_view[float, uint32_t] vals,
lanczos_solver_config[int, float] config,
lanczos_solver_config[float] config,
device_vector_view[float, uint32_t] v0,
device_vector_view[float, uint32_t] eigenvalues,
device_matrix_view[float, uint32_t, col_major] eigenvectors) except +
Expand Down Expand Up @@ -193,68 +191,68 @@ def eigsh(A, k=6, v0=None, ncv=None, maxiter=None,
cdef device_resources *h = <device_resources*><size_t>handle.getHandle()

if IndexType == np.int32 and ValueType == np.float32:
config_int_float.n_components = k
config_int_float.max_iterations = maxiter
config_int_float.ncv = ncv
config_int_float.tolerance = tol
config_int_float.seed = seed
config_float.n_components = k
config_float.max_iterations = maxiter
config_float.ncv = ncv
config_float.tolerance = tol
config_float.seed = seed
lanczos_solver(
deref(h),
make_device_vector_view(<int *>rows_ptr, <uint32_t> (N + 1)),
make_device_vector_view(<int *>cols_ptr, <uint32_t> nnz),
make_device_vector_view(<float *>vals_ptr, <uint32_t> nnz),
<lanczos_solver_config[int, float]> config_int_float,
<lanczos_solver_config[float]> config_float,
make_device_vector_view(<float *>v0_ptr, <uint32_t> N),
make_device_vector_view(<float *>eigenvalues_ptr, <uint32_t> k),
make_device_matrix_view[float, uint32_t, col_major](
<float *>eigenvectors_ptr, <uint32_t> N, <uint32_t> k),
)
elif IndexType == np.int64 and ValueType == np.float32:
config_int64_float.n_components = k
config_int64_float.max_iterations = maxiter
config_int64_float.ncv = ncv
config_int64_float.tolerance = tol
config_int64_float.seed = seed
config_float.n_components = k
config_float.max_iterations = maxiter
config_float.ncv = ncv
config_float.tolerance = tol
config_float.seed = seed
lanczos_solver(
deref(h),
make_device_vector_view(<int64_t *>rows_ptr, <uint32_t> (N + 1)),
make_device_vector_view(<int64_t *>cols_ptr, <uint32_t> nnz),
make_device_vector_view(<float *>vals_ptr, <uint32_t> nnz),
<lanczos_solver_config[int64_t, float]> config_int64_float,
<lanczos_solver_config[float]> config_float,
make_device_vector_view(<float *>v0_ptr, <uint32_t> N),
make_device_vector_view(<float *>eigenvalues_ptr, <uint32_t> k),
make_device_matrix_view[float, uint32_t, col_major](
<float *>eigenvectors_ptr, <uint32_t> N, <uint32_t> k),
)
elif IndexType == np.int32 and ValueType == np.float64:
config_int_double.n_components = k
config_int_double.max_iterations = maxiter
config_int_double.ncv = ncv
config_int_double.tolerance = tol
config_int_double.seed = seed
config_double.n_components = k
config_double.max_iterations = maxiter
config_double.ncv = ncv
config_double.tolerance = tol
config_double.seed = seed
lanczos_solver(
deref(h),
make_device_vector_view(<int *>rows_ptr, <uint32_t> (N + 1)),
make_device_vector_view(<int *>cols_ptr, <uint32_t> nnz),
make_device_vector_view(<double *>vals_ptr, <uint32_t> nnz),
<lanczos_solver_config[int, double]> config_int_double,
<lanczos_solver_config[double]> config_double,
make_device_vector_view(<double *>v0_ptr, <uint32_t> N),
make_device_vector_view(<double *>eigenvalues_ptr, <uint32_t> k),
make_device_matrix_view[double, uint32_t, col_major](
<double *>eigenvectors_ptr, <uint32_t> N, <uint32_t> k),
)
elif IndexType == np.int64 and ValueType == np.float64:
config_int64_double.n_components = k
config_int64_double.max_iterations = maxiter
config_int64_double.ncv = ncv
config_int64_double.tolerance = tol
config_int64_double.seed = seed
config_double.n_components = k
config_double.max_iterations = maxiter
config_double.ncv = ncv
config_double.tolerance = tol
config_double.seed = seed
lanczos_solver(
deref(h),
make_device_vector_view(<int64_t *>rows_ptr, <uint32_t> (N + 1)),
make_device_vector_view(<int64_t *>cols_ptr, <uint32_t> nnz),
make_device_vector_view(<double *>vals_ptr, <uint32_t> nnz),
<lanczos_solver_config[int64_t, double]> config_int64_double,
<lanczos_solver_config[double]> config_double,
make_device_vector_view(<double *>v0_ptr, <uint32_t> N),
make_device_vector_view(<double *>eigenvalues_ptr, <uint32_t> k),
make_device_matrix_view[double, uint32_t, col_major](
Expand Down

0 comments on commit 428ead4

Please sign in to comment.