diff --git a/src/index/ivf_raft/ivf_raft.cuh b/src/index/ivf_raft/ivf_raft.cuh
index cea390d1f..c202ca5e5 100644
--- a/src/index/ivf_raft/ivf_raft.cuh
+++ b/src/index/ivf_raft/ivf_raft.cuh
@@ -40,6 +40,45 @@ namespace knowhere {
 
+// Compact the top-k2 candidates of each query down to the top-k1 ids whose bit is not
+// set in the bitset `bs`. One thread block handles one query; if fewer than k1 candidates
+// survive, ids_before[0] is set to -1 so that the host retries with a larger k2.
+__global__ void
+filter(const int k1, const int k2, const int nq, const uint8_t* bs, int64_t* ids_before, float* dis_before,
+       int64_t* ids, float* dis) {
+    int tx = blockIdx.x * blockDim.x + threadIdx.x;
+    int ty = blockIdx.y * blockDim.y + threadIdx.y;
+    extern __shared__ char s[];
+    int64_t* ids_ = (int64_t*)s;
+    float* dis_ = (float*)&s[k2 * sizeof(int64_t)];
+    if (tx >= k2)
+        return;
+    int64_t i = ids_before[ty * k2 + tx];
+    float d = dis_before[ty * k2 + tx];
+    bool check = bs[i >> 3] & (0x1 << (i & 0x7));  // a set bit marks a filtered-out id
+    if (!check) {
+        ids_[tx] = i;
+        dis_[tx] = d;
+    } else {
+        ids_[tx] = -1;
+        dis_[tx] = -1.0f;
+    }
+    __syncthreads();
+    if (tx == 0) {
+        int j = 0, k = 0;
+        while (j < k1 && k < k2) {
+            while (k < k2 && ids_[k] == -1) k++;
+            if (k >= k2)
+                break;
+            ids[ty * k1 + j] = ids_[k];
+            dis[ty * k1 + j] = dis_[k];
+            j++;
+            k++;
+        }
+        if (j != k1) {
+            ids_before[0] = -1;  // destroy answer: fewer than k1 ids survived, the host retries with a larger k2
+        }
+    }
+    __syncthreads();
+}
+
 namespace raft_res_pool {
 
 struct context {
@@ -381,9 +420,42 @@ class RaftIvfIndexNode : public IndexNode {
             if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
                 auto search_params = raft::neighbors::ivf_flat::search_params{};
                 search_params.n_probes = ivf_raft_cfg.nprobe;
-                raft::neighbors::ivf_flat::search(*res_, search_params, *gpu_index_,
-                                                  data_gpu.data(), rows, ivf_raft_cfg.k,
-                                                  ids_gpu.data(), dis_gpu.data());
+                if (bitset.empty()) {
+                    raft::neighbors::ivf_flat::search(*res_, search_params, *gpu_index_,
+                                                      data_gpu.data(), rows, ivf_raft_cfg.k,
+                                                      ids_gpu.data(), dis_gpu.data());
+                } else {
+                    auto k1 = ivf_raft_cfg.k;
+                    // k2: smallest power of two greater than k1, capped at 1024 (max threads per block)
+                    auto k2 = k1;
+                    k2 |= k2 >> 1;
+                    k2 |= k2 >> 2;
+                    k2 |= k2 >> 4;
+                    k2 |= k2 >> 8;
+                    k2 |= k2 >> 16;
+                    k2 += 1;
+                    while (k2 <= 1024) {
+                        auto ids_gpu_before = rmm::device_uvector<std::int64_t>(k2 * rows, stream);
+                        auto dis_gpu_before = rmm::device_uvector<float>(k2 * rows, stream);
+                        auto bs_gpu = rmm::device_uvector<uint8_t>(bitset.byte_size(), stream);
+                        RAFT_CUDA_TRY(cudaMemcpyAsync(bs_gpu.data(), bitset.data(), bitset.byte_size(),
+                                                      cudaMemcpyDefault, stream.value()));
+
+                        raft::neighbors::ivf_flat::search(
+                            *res_, search_params, *gpu_index_, data_gpu.data(), rows, k2, ids_gpu_before.data(),
+                            dis_gpu_before.data());
+                        // one block per query: blockDim.x == k2, shared memory holds k2 (id, distance) pairs
+                        filter<<<dim3(1, rows, 1), dim3(k2, 1, 1), k2 * (sizeof(int64_t) + sizeof(float)),
+                                 stream.value()>>>(
+                            k1, k2, rows, bs_gpu.data(), ids_gpu_before.data(), dis_gpu_before.data(), ids_gpu.data(),
+                            dis_gpu.data());
+
+                        std::int64_t is_fine = 0;
+                        RAFT_CUDA_TRY(cudaMemcpyAsync(&is_fine, ids_gpu_before.data(), sizeof(std::int64_t),
+                                                      cudaMemcpyDefault, stream.value()));
+                        stream.synchronize();
+                        if (is_fine != -1)
+                            break;
+                        k2 = k2 << 1;
+                    }
+                }
             } else if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
                 auto search_params = raft::neighbors::ivf_pq::search_params{};
                 search_params.n_probes = ivf_raft_cfg.nprobe;
@@ -411,9 +483,44 @@ class RaftIvfIndexNode : public IndexNode {
                 }
                 search_params.internal_distance_dtype = internal_distance_dtype.value();
                 search_params.preferred_shmem_carveout = search_params.preferred_shmem_carveout;
-                raft::neighbors::ivf_pq::search(*res_, search_params, *gpu_index_, data_gpu.data(),
-                                                rows, ivf_raft_cfg.k, ids_gpu.data(),
-                                                dis_gpu.data());
+                if (bitset.empty()) {
+                    raft::neighbors::ivf_pq::search(*res_, search_params, *gpu_index_,
+                                                    data_gpu.data(), rows, ivf_raft_cfg.k,
+                                                    ids_gpu.data(), dis_gpu.data());
+                } else {
+                    auto k1 = ivf_raft_cfg.k;
+                    auto k2 = k1;
+                    k2 |= k2 >> 1;
+                    k2 |= k2 >> 2;
+                    k2 |= k2 >> 4;
+                    k2 |= k2 >> 8;
+                    k2 |= k2 >> 16;
+                    k2 += 1;
+                    while (k2 <= 1024) {
+                        auto ids_gpu_before = rmm::device_uvector<std::int64_t>(k2 * rows, stream);
+                        auto dis_gpu_before = rmm::device_uvector<float>(k2 * rows, stream);
+                        auto bs_gpu = rmm::device_uvector<uint8_t>(bitset.byte_size(), stream);
+                        RAFT_CUDA_TRY(cudaMemcpyAsync(bs_gpu.data(), bitset.data(), bitset.byte_size(),
+                                                      cudaMemcpyDefault, stream.value()));
+
+                        raft::neighbors::ivf_pq::search(
+                            *res_, search_params, *gpu_index_, data_gpu.data(), rows, k2, ids_gpu_before.data(),
+                            dis_gpu_before.data());
+
+                        filter<<<dim3(1, rows, 1), dim3(k2, 1, 1), k2 * (sizeof(int64_t) + sizeof(float)),
+                                 stream.value()>>>(
+                            k1, k2, rows, bs_gpu.data(), ids_gpu_before.data(), dis_gpu_before.data(), ids_gpu.data(),
+                            dis_gpu.data());
+
+                        std::int64_t is_fine = 0;
+                        RAFT_CUDA_TRY(cudaMemcpyAsync(&is_fine, ids_gpu_before.data(), sizeof(std::int64_t),
+                                                      cudaMemcpyDefault, stream.value()));
+                        stream.synchronize();
+                        if (is_fine != -1)
+                            break;
+                        k2 = k2 << 1;
+                    }
+                }
             } else {
                 static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
             }
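
Note: the retry loop works because `filter` reports failure in-band. When a query ends up with fewer than k1 unfiltered candidates, the kernel writes -1 into `ids_before[0]`; the host reads that single value back, doubles k2 (up to 1024), and searches again. The sketch below is a minimal host-side reference of the same per-query compaction, useful as a CPU oracle in tests. It is illustrative only: `filter_reference` and the sample data are hypothetical names and values, not code from this patch.

```cpp
// Host-side reference of the per-query compaction done by the filter kernel.
// A set bit in `bs` marks an id that must be dropped. Returns false when any
// query keeps fewer than k1 candidates, i.e. the caller should retry the GPU
// search with a larger k2. (Illustrative helper, not part of the patch.)
#include <cstdint>
#include <cstdio>
#include <vector>

bool
filter_reference(int k1, int k2, int nq, const std::vector<std::uint8_t>& bs,
                 const std::vector<std::int64_t>& ids_before, const std::vector<float>& dis_before,
                 std::vector<std::int64_t>& ids, std::vector<float>& dis) {
    ids.assign(static_cast<std::size_t>(nq) * k1, -1);
    dis.assign(static_cast<std::size_t>(nq) * k1, -1.0f);
    bool ok = true;
    for (int q = 0; q < nq; ++q) {
        int j = 0;
        for (int k = 0; k < k2 && j < k1; ++k) {
            std::int64_t id = ids_before[q * k2 + k];
            if (bs[id >> 3] & (0x1 << (id & 0x7))) {
                continue;  // id is masked out by the bitset
            }
            ids[q * k1 + j] = id;
            dis[q * k1 + j] = dis_before[q * k2 + k];
            ++j;
        }
        ok = ok && (j == k1);  // too few survivors for this query
    }
    return ok;
}

int
main() {
    // one query, top-2 wanted, top-4 fetched; ids 1 and 2 are filtered out by the bitset
    std::vector<std::uint8_t> bs = {0b00000110};
    std::vector<std::int64_t> ids_before = {1, 2, 5, 7};
    std::vector<float> dis_before = {0.1f, 0.2f, 0.3f, 0.4f};
    std::vector<std::int64_t> ids;
    std::vector<float> dis;
    bool ok = filter_reference(2, 4, 1, bs, ids_before, dis_before, ids, dis);
    std::printf("ok=%d ids=[%ld, %ld]\n", ok, (long)ids[0], (long)ids[1]);  // ok=1 ids=[5, 7]
    return 0;
}
```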