Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
optimize gpu code and add cagra impl (#942)
Browse files Browse the repository at this point in the history
Signed-off-by: Yusheng.Ma <[email protected]>
  • Loading branch information
Presburger authored Jun 20, 2023
1 parent 374c43f commit 7a292fb
Show file tree
Hide file tree
Showing 8 changed files with 462 additions and 122 deletions.
5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ if(WITH_COVERAGE)
endif()

knowhere_file_glob(GLOB_RECURSE KNOWHERE_SRCS src/common/*.cc src/index/*.cc
src/io/*.cc src/index/*.cu)
src/io/*.cc src/index/*.cu src/common/raft/*.cu)

set(KNOWHERE_LINKER_LIBS "")

Expand All @@ -109,7 +109,8 @@ list(REMOVE_ITEM KNOWHERE_SRCS ${KNOWHERE_GPU_SRCS})

if(NOT WITH_RAFT)
knowhere_file_glob(GLOB_RECURSE KNOWHERE_RAFT_SRCS src/index/ivf_raft/*.cc
src/index/ivf_raft/*.cu)
src/index/ivf_raft/*.cu src/index/cagra/*.cu
src/common/raft/*.cu)
list(REMOVE_ITEM KNOWHERE_SRCS ${KNOWHERE_RAFT_SRCS})
endif()

Expand Down
1 change: 1 addition & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ constexpr const char* INDEX_FAISS_GPU_IVFSQ8 = "GPU_FAISS_IVF_SQ8";

constexpr const char* INDEX_RAFT_IVFFLAT = "GPU_RAFT_IVF_FLAT";
constexpr const char* INDEX_RAFT_IVFPQ = "GPU_RAFT_IVF_PQ";
constexpr const char* INDEX_RAFT_CAGRA = "GPU_RAFT_CAGRA";

constexpr const char* INDEX_HNSW = "HNSW";
constexpr const char* INDEX_DISKANN = "DISKANN";
Expand Down
39 changes: 39 additions & 0 deletions src/common/raft/raft.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "../src/distance/specializations/fused_l2_nn_double_int.cu"
#include "../src/distance/specializations/fused_l2_nn_double_int64.cu"
#include "../src/distance/specializations/fused_l2_nn_float_int.cu"
#include "../src/distance/specializations/fused_l2_nn_float_int64.cu"
#include "../src/matrix/specializations/detail/select_k_float_int64_t.cu"
#include "../src/matrix/specializations/detail/select_k_float_uint32_t.cu"
#include "../src/matrix/specializations/detail/select_k_half_int64_t.cu"
#include "../src/matrix/specializations/detail/select_k_half_uint32_t.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu"
#include "../src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu"
#include "../src/neighbors/specializations/fused_l2_knn_int_float_false.cu"
#include "../src/neighbors/specializations/fused_l2_knn_int_float_true.cu"
#include "../src/neighbors/specializations/fused_l2_knn_long_float_false.cu"
#include "../src/neighbors/specializations/fused_l2_knn_long_float_true.cu"
#include "../src/neighbors/specializations/ivfflat_build_float_int64_t.cu"
#include "../src/neighbors/specializations/ivfflat_extend_float_int64_t.cu"
#include "../src/neighbors/specializations/ivfflat_search_float_int64_t.cu"
#include "../src/neighbors/specializations/ivfpq_build_float_int64_t.cu"
#include "../src/neighbors/specializations/ivfpq_extend_float_int64_t.cu"
#include "../src/neighbors/specializations/ivfpq_search_float_int64_t.cu"
42 changes: 42 additions & 0 deletions src/common/raft/res_pool.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include "res_pool.cuh"
namespace raft_res_pool {

resource&
resource::instance() {
static resource res;
return res;
}

void
resource::set_pool_size(std::size_t init_size, std::size_t max_size) {
this->initial_pool_size = init_size;
this->maximum_pool_size = max_size;
}

void
resource::init(rmm::cuda_device_id device_id) {
std::lock_guard<std::mutex> lock(mtx_);
auto it = map_.find(device_id.value());
if (it == map_.end()) {
char* env_str = getenv("KNOWHERE_GPU_MEM_POOL_SIZE");
if (env_str != NULL) {
std::size_t initial_pool_size_tmp, maximum_pool_size_tmp;
auto stat = sscanf(env_str, "%zu;%zu", &initial_pool_size_tmp, &maximum_pool_size_tmp);
if (stat == 2) {
LOG_KNOWHERE_INFO_ << "Get Gpu Pool Size From env, init size: " << initial_pool_size_tmp
<< " MB, max size: " << maximum_pool_size_tmp << " MB";
this->initial_pool_size = initial_pool_size_tmp;
this->maximum_pool_size = maximum_pool_size_tmp;
} else {
LOG_KNOWHERE_WARNING_ << "please check env format";
}
}

auto mr_ = std::make_unique<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>(
&up_mr_, initial_pool_size << 20, maximum_pool_size << 20);
rmm::mr::set_per_device_resource(device_id, mr_.get());
map_[device_id.value()] = std::move(mr_);
}
}

}; // namespace raft_res_pool
59 changes: 59 additions & 0 deletions src/common/raft/res_pool.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include "knowhere/log.h"
#include "raft/core/device_resources.hpp"

namespace raft_res_pool {

struct context {
context()
: resources_(
[]() {
return new rmm::cuda_stream(); // Avoid program exit datart
// unload error
}()
->view(),
nullptr, rmm::mr::get_current_device_resource()) {
}
~context() = default;
context(context&&) = delete;
context(context const&) = delete;
context&
operator=(context&&) = delete;
context&
operator=(context const&) = delete;
raft::device_resources resources_;
};

inline context&
get_context() {
thread_local context ctx;
return ctx;
};
class resource {
public:
static resource&
instance();
void
set_pool_size(std::size_t init_size, std::size_t max_size);

void
init(rmm::cuda_device_id device_id);

private:
resource(){};
~resource(){};
resource(resource&&) = delete;
resource(resource const&) = delete;
resource&
operator=(resource&&) = delete;
resource&
operator=(context const&) = delete;
rmm::mr::cuda_memory_resource up_mr_;
std::map<rmm::cuda_device_id::value_type,
std::unique_ptr<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>>
map_;
mutable std::mutex mtx_;
std::size_t initial_pool_size = 2048; // MB
std::size_t maximum_pool_size = 4096; // MB
};

}; // namespace raft_res_pool
Loading

0 comments on commit 7a292fb

Please sign in to comment.