raft index support bitset filter #850

Merged 1 commit on May 4, 2023
119 changes: 113 additions & 6 deletions src/index/ivf_raft/ivf_raft.cuh
@@ -40,6 +40,45 @@

namespace knowhere {

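// filter: GPU-side post-filtering of an oversampled top-k2 result list.
// One thread block handles one query row; each thread checks one of the k2
// candidates against the bitset `bs`, where a set bit marks an id to exclude.
// Surviving candidates are then compacted into the first k1 output slots.
// If fewer than k1 candidates survive, a -1 sentinel is written into
// ids_before[0] so the caller can retry the search with a larger k2.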
__global__ void
filter(const int k1, const int k2, const int nq, const uint8_t* bs, int64_t* ids_before, float* dis_before,
int64_t* ids, float* dis) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int ty = blockIdx.y * blockDim.y + threadIdx.y;
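// Dynamic shared memory scratch: k2 int64 ids followed by k2 float distances.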
extern __shared__ char s[];
int64_t* ids_ = (int64_t*)s;
float* dis_ = (float*)&s[k2 * sizeof(int64_t)];
if (tx >= k2)
return;
int64_t i = ids_before[ty * k2 + tx];
float d = dis_before[ty * k2 + tx];
bool check = bs[i >> 3] & (0x1 << (i & 0x7));
if (!check) {
ids_[tx] = i;
dis_[tx] = d;
} else {
ids_[tx] = -1;
dis_[tx] = -1.0f;
}
__syncthreads();
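// Thread 0 compacts the surviving candidates into the first k1 output slots for this query.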
if (tx == 0) {
int j = 0, k = 0;
while (j < k1 && k < k2) {
while (k < k2 && ids_[k] == -1) k++;
if (k >= k2)
break;
ids[ty * k1 + j] = ids_[k];
dis[ty * k1 + j] = dis_[k];
j++;
k++;
}
if (j != k1) {
ids_before[0] = -1;  // not enough unfiltered hits: signal the host to retry with a larger k2
}
}
__syncthreads();
}

namespace raft_res_pool {

struct context {
@@ -381,9 +420,42 @@ class RaftIvfIndexNode : public IndexNode {
if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
auto search_params = raft::neighbors::ivf_flat::search_params{};
search_params.n_probes = ivf_raft_cfg.nprobe;
raft::neighbors::ivf_flat::search<float, std::int64_t>(*res_, search_params, *gpu_index_,
data_gpu.data(), rows, ivf_raft_cfg.k,
ids_gpu.data(), dis_gpu.data());
if (bitset.empty()) {
raft::neighbors::ivf_flat::search<float, std::int64_t>(*res_, search_params, *gpu_index_,
data_gpu.data(), rows, ivf_raft_cfg.k,
ids_gpu.data(), dis_gpu.data());
} else {
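// Bitset filtering: oversample by searching for k2 >= k candidates, drop the
// filtered-out ids on the GPU, and keep the best k survivors. k2 is rounded
// up to the next power of two above k and doubled on each retry; it also
// serves as the filter kernel's block size, which caps it at 1024 threads.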
auto k1 = ivf_raft_cfg.k;
auto k2 = k1;
k2 |= k2 >> 1;
k2 |= k2 >> 2;
k2 |= k2 >> 4;
k2 |= k2 >> 8;
k2 |= k2 >> 16;
k2 += 1;
while (k2 <= 1024) {
auto ids_gpu_before = rmm::device_uvector<std::int64_t>(k2 * rows, stream);
auto dis_gpu_before = rmm::device_uvector<float>(k2 * rows, stream);
auto bs_gpu = rmm::device_uvector<uint8_t>(bitset.byte_size(), stream);
RAFT_CUDA_TRY(cudaMemcpyAsync(bs_gpu.data(), bitset.data(), bitset.byte_size(),
cudaMemcpyDefault, stream.value()));

raft::neighbors::ivf_flat::search<float, std::int64_t>(
*res_, search_params, *gpu_index_, data_gpu.data(), rows, k2, ids_gpu_before.data(),
dis_gpu_before.data());
filter<<<dim3(1, rows), k2, k2 * sizeof(std::int64_t) + k2 * sizeof(float), stream.value()>>>(
k1, k2, rows, bs_gpu.data(), ids_gpu_before.data(), dis_gpu_before.data(), ids_gpu.data(),
dis_gpu.data());

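// Read back the sentinel: the kernel overwrites ids_gpu_before[0] with -1
// when fewer than k unfiltered hits were found for some query.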
std::int64_t is_fine = 0;
RAFT_CUDA_TRY(cudaMemcpyAsync(&is_fine, ids_gpu_before.data(), sizeof(std::int64_t),
cudaMemcpyDefault, stream.value()));
stream.synchronize();
if (is_fine != -1)
break;
k2 = k2 << 1;
}
}
} else if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
auto search_params = raft::neighbors::ivf_pq::search_params{};
search_params.n_probes = ivf_raft_cfg.nprobe;
@@ -411,9 +483,44 @@
}
search_params.internal_distance_dtype = internal_distance_dtype.value();
search_params.preferred_shmem_carveout = search_params.preferred_shmem_carveout;
raft::neighbors::ivf_pq::search<float, std::int64_t>(*res_, search_params, *gpu_index_, data_gpu.data(),
rows, ivf_raft_cfg.k, ids_gpu.data(),
dis_gpu.data());
if (bitset.empty()) {
raft::neighbors::ivf_pq::search<float, std::int64_t>(*res_, search_params, *gpu_index_,
data_gpu.data(), rows, ivf_raft_cfg.k,
ids_gpu.data(), dis_gpu.data());
} else {
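// Same oversample-filter-retry strategy as the ivf_flat branch above.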
auto k1 = ivf_raft_cfg.k;
auto k2 = k1;
k2 |= k2 >> 1;
k2 |= k2 >> 2;
k2 |= k2 >> 4;
k2 |= k2 >> 8;
k2 |= k2 >> 16;
k2 += 1;
while (k2 <= 1024) {
auto ids_gpu_before = rmm::device_uvector<std::int64_t>(k2 * rows, stream);
auto dis_gpu_before = rmm::device_uvector<float>(k2 * rows, stream);
auto bs_gpu = rmm::device_uvector<uint8_t>(bitset.byte_size(), stream);
RAFT_CUDA_TRY(cudaMemcpyAsync(bs_gpu.data(), bitset.data(), bitset.byte_size(),
cudaMemcpyDefault, stream.value()));

raft::neighbors::ivf_pq::search<float, std::int64_t>(
*res_, search_params, *gpu_index_, data_gpu.data(), rows, k2, ids_gpu_before.data(),
dis_gpu_before.data());

filter<<<dim3(1, rows), k2, k2 * sizeof(std::int64_t) + k2 * sizeof(float), stream.value()>>>(
k1, k2, rows, bs_gpu.data(), ids_gpu_before.data(), dis_gpu_before.data(), ids_gpu.data(),
dis_gpu.data());

std::int64_t is_fine = 0;
RAFT_CUDA_TRY(cudaMemcpyAsync(&is_fine, ids_gpu_before.data(), sizeof(std::int64_t),
cudaMemcpyDefault, stream.value()));
stream.synchronize();
if (is_fine != -1)
break;
k2 = k2 << 1;
}
}

} else {
static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
}