From 60ac98bc29e76f8ef06cddcc339382c986728d3e Mon Sep 17 00:00:00 2001
From: usamoi
Date: Wed, 28 Aug 2024 19:21:18 +0800
Subject: [PATCH] [experimental] use centroids generated by faiss

Signed-off-by: usamoi
---
 crates/rabitq/src/lib.rs | 61 ++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 57 insertions(+), 4 deletions(-)

diff --git a/crates/rabitq/src/lib.rs b/crates/rabitq/src/lib.rs
index 54271ffd6..90e73a2b3 100644
--- a/crates/rabitq/src/lib.rs
+++ b/crates/rabitq/src/lib.rs
@@ -17,7 +17,7 @@ use common::json::Json;
 use common::mmap_array::MmapArray;
 use common::remap::RemappedCollection;
 use common::vec2::Vec2;
-use k_means::{k_means, k_means_lookup, k_means_lookup_many};
+use k_means::{k_means_lookup, k_means_lookup_many};
 use rayon::iter::{IntoParallelIterator, ParallelIterator};
 use std::fs::create_dir;
 use std::path::Path;
@@ -108,7 +108,7 @@ fn from_nothing(
     create_dir(path.as_ref()).unwrap();
     let RabitqIndexingOptions {
         nlist,
-        spherical_centroids,
+        spherical_centroids: _,
     } = options.indexing.clone().unwrap_rabitq();
     let projection = {
         use nalgebra::{DMatrix, QR};
@@ -130,9 +130,62 @@ fn from_nothing(
         }
         projection
     };
-    let samples = O::sample(collection);
     rayon::check();
-    let centroids: Vec2<f32> = k_means(nlist as usize, samples, spherical_centroids);
+    // Experimental: instead of training centroids in-process with
+    // `k_means(nlist as usize, samples, spherical_centroids)`, load centroids that were
+    // generated offline by faiss from a hard-coded `.fvecs` file.
+    let centroids: Vec2<f32> = {
+        /// Reads every vector from an `.fvecs` file.
+        ///
+        /// Each record is a little-endian `u32` dimension followed by that many
+        /// little-endian `f32` components.
+        ///
+        /// # Errors
+        ///
+        /// Returns any I/O error from opening or reading the file, including
+        /// `UnexpectedEof` when the file ends in the middle of a record.
+        fn read_vecs(path: impl AsRef<Path>) -> std::io::Result<Vec<Vec<f32>>> {
+            use std::io::{BufRead, Read};
+
+            let file = std::fs::File::open(path)?;
+            let mut reader = std::io::BufReader::new(file);
+            let mut buf = [0u8; 4];
+            let mut vecs = Vec::new();
+            // `fill_buf` yields an empty slice only at end of file, so the loop ends
+            // exactly on a record boundary. A partially present header then fails in
+            // `read_exact` instead of being decoded from a short `read`.
+            while !reader.fill_buf()?.is_empty() {
+                reader.read_exact(&mut buf)?;
+                let dim = u32::from_le_bytes(buf) as usize;
+                let mut vec = Vec::with_capacity(dim);
+                for _ in 0..dim {
+                    reader.read_exact(&mut buf)?;
+                    vec.push(f32::from_le_bytes(buf));
+                }
+                vecs.push(vec);
+            }
+            Ok(vecs)
+        }
+        /// Loads the vectors of an `.fvecs` file and hands them to `Vec2::from_vec`
+        /// with shape `(vector count, dimension of the first vector)`.
+        ///
+        /// # Panics
+        ///
+        /// Panics if the file cannot be read, contains no vectors, or contains
+        /// vectors of differing dimensions.
+        fn load_centroids_from_fvecs(path: impl AsRef<Path>) -> Vec2<f32> {
+            let path = path.as_ref();
+            let fvecs = read_vecs(path).expect("read centroids error");
+            let nlist = fvecs.len();
+            let dims = fvecs.first().expect("centroids file contains no vectors").len();
+            assert!(
+                fvecs.iter().all(|v| v.len() == dims),
+                "centroids file contains vectors of differing dimensions"
+            );
+            Vec2::from_vec((nlist, dims), fvecs.into_iter().flatten().collect())
+        }
+        load_centroids_from_fvecs("/usamoi/repos/RaBitQ/gist/gist_centroid_4096.fvecs")
+    };
     rayon::check();
     let ls = (0..collection.len())
         .into_par_iter()