Skip to content

Commit

Permalink
feat: disable kmeans++ in IVF (#579)
Browse files Browse the repository at this point in the history
* feat: disable kmeans++ in IVF

Signed-off-by: usamoi <[email protected]>

* test: remove flasky test

Signed-off-by: usamoi <[email protected]>

---------

Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi authored Aug 30, 2024
1 parent a71e9ef commit f24bef2
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 51 deletions.
2 changes: 1 addition & 1 deletion crates/ivf/src/ivf_naive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ fn from_nothing<O: Op>(
} = options.indexing.clone().unwrap_ivf();
let samples = O::sample(collection, nlist);
rayon::check();
let centroids = k_means(nlist as usize, samples, true, spherical_centroids);
let centroids = k_means(nlist as usize, samples, true, spherical_centroids, false);
rayon::check();
let ls = (0..collection.len())
.into_par_iter()
Expand Down
2 changes: 1 addition & 1 deletion crates/ivf/src/ivf_residual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ fn from_nothing<O: Op>(
} = options.indexing.clone().unwrap_ivf();
let samples = O::sample(collection, nlist);
rayon::check();
let centroids = k_means(nlist as usize, samples, true, spherical_centroids);
let centroids = k_means(nlist as usize, samples, true, spherical_centroids, false);
rayon::check();
let ls = (0..collection.len())
.into_par_iter()
Expand Down
14 changes: 7 additions & 7 deletions crates/k_means/src/elkan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub struct ElkanKMeans<S> {
lowerbound: Square,
upperbound: Vec<f32>,
assign: Vec<usize>,
rand: StdRng,
rng: StdRng,
samples: Vec2<S>,
first: bool,
}
Expand All @@ -25,13 +25,13 @@ impl<S: ScalarLike> ElkanKMeans<S> {
let n = samples.shape_0();
let dims = samples.shape_1();

let mut rand = StdRng::from_entropy();
let mut rng = StdRng::from_entropy();
let mut centroids = Vec2::zeros((c, dims));
let mut lowerbound = Square::new(n, c);
let mut upperbound = vec![0.0f32; n];
let mut assign = vec![0usize; n];

centroids[(0,)].copy_from_slice(&samples[(rand.gen_range(0..n),)]);
centroids[(0,)].copy_from_slice(&samples[(rng.gen_range(0..n),)]);

let mut weight = vec![f32::INFINITY; n];
let mut dis = vec![0.0f32; n];
Expand All @@ -51,10 +51,10 @@ impl<S: ScalarLike> ElkanKMeans<S> {
break;
}
let index = 'a: {
let mut choice = sum * rand.gen_range(0.0..1.0);
let mut choice = sum * rng.gen_range(0.0..1.0);
for j in 0..(n - 1) {
choice -= weight[j];
if choice <= 0.0f32 {
if choice < 0.0f32 {
break 'a j;
}
}
Expand Down Expand Up @@ -85,7 +85,7 @@ impl<S: ScalarLike> ElkanKMeans<S> {
lowerbound,
upperbound,
assign,
rand,
rng,
samples,
first: true,
}
Expand All @@ -95,7 +95,7 @@ impl<S: ScalarLike> ElkanKMeans<S> {
let c = self.c;
let dims = self.dims;
let samples = &self.samples;
let rand = &mut self.rand;
let rand = &mut self.rng;
let assign = &mut self.assign;
let centroids = &mut self.centroids;
let lowerbound = &mut self.lowerbound;
Expand Down
3 changes: 2 additions & 1 deletion crates/k_means/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub fn k_means<S: ScalarLike>(
mut samples: Vec2<S>,
prefer_multithreading: bool,
is_spherical: bool,
prefer_kmeanspp: bool,
) -> Vec2<S> {
assert!(c > 0);
let n = samples.shape_0();
Expand All @@ -38,7 +39,7 @@ pub fn k_means<S: ScalarLike>(
return Vec2::from_vec((c, 1), centroids);
}
if prefer_multithreading {
let mut lloyd_k_means = LloydKMeans::new(c, samples, is_spherical);
let mut lloyd_k_means = LloydKMeans::new(c, samples, is_spherical, prefer_kmeanspp);
for _ in 0..25 {
rayon::check();
if lloyd_k_means.iterate() {
Expand Down
62 changes: 32 additions & 30 deletions crates/k_means/src/lloyd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,48 +13,50 @@ pub struct LloydKMeans<S> {
is_spherical: bool,
centroids: Vec<Vec<S>>,
assign: Vec<usize>,
rand: StdRng,
rng: StdRng,
samples: Vec2<S>,
}

const DELTA: f32 = f16::EPSILON.to_f32_const();

impl<S: ScalarLike> LloydKMeans<S> {
pub fn new(c: usize, samples: Vec2<S>, is_spherical: bool) -> Self {
pub fn new(c: usize, samples: Vec2<S>, is_spherical: bool, prefer_kmeanspp: bool) -> Self {
let n = samples.shape_0();
let dims = samples.shape_1();

let mut rand = StdRng::from_entropy();
let mut rng = StdRng::from_entropy();
let mut centroids = Vec::with_capacity(c);

centroids.push(samples[(rand.gen_range(0..n),)].to_vec());

let mut weight = vec![f32::INFINITY; n];
for i in 0..c {
let dis_2 = (0..n)
.into_par_iter()
.map(|j| S::reduce_sum_of_d2(&samples[(j,)], &centroids[i]))
.collect::<Vec<_>>();
for j in 0..n {
if dis_2[j] < weight[j] {
weight[j] = dis_2[j];
if prefer_kmeanspp {
centroids.push(samples[(rng.gen_range(0..n),)].to_vec());
let mut weight = vec![f32::INFINITY; n];
for i in 1..c {
let dis_2 = (0..n)
.into_par_iter()
.map(|j| S::reduce_sum_of_d2(&samples[(j,)], &centroids[i - 1]))
.collect::<Vec<_>>();
for j in 0..n {
if dis_2[j] < weight[j] {
weight[j] = dis_2[j];
}
}
let sum = f32::reduce_sum_of_x(&weight);
let index = 'a: {
let mut choice = sum * rng.gen_range(0.0..1.0);
for j in 0..(n - 1) {
choice -= weight[j];
if choice < 0.0f32 {
break 'a j;
}
}
n - 1
};
centroids.push(samples[(index,)].to_vec());
}
let sum = f32::reduce_sum_of_x(&weight);
if i + 1 == c {
break;
} else {
for index in rand::seq::index::sample(&mut rng, n, c).into_iter() {
centroids.push(samples[(index,)].to_vec());
}
let index = 'a: {
let mut choice = sum * rand.gen_range(0.0..1.0);
for j in 0..(n - 1) {
choice -= weight[j];
if choice <= 0.0f32 {
break 'a j;
}
}
n - 1
};
centroids.push(samples[(index,)].to_vec());
}

let assign = (0..n)
Expand All @@ -77,15 +79,15 @@ impl<S: ScalarLike> LloydKMeans<S> {
is_spherical,
centroids,
assign,
rand,
rng,
samples,
}
}

pub fn iterate(&mut self) -> bool {
let dims = self.dims;
let c = self.c;
let rand = &mut self.rand;
let rand = &mut self.rng;
let samples = &self.samples;
let n = samples.shape_0();

Expand Down
7 changes: 3 additions & 4 deletions crates/k_means/src/quick_centers.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use base::scalar::ScalarLike;
use common::vec2::Vec2;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand::Rng;

pub fn quick_centers<S: ScalarLike>(c: usize, samples: Vec2<S>) -> Vec2<S> {
let n = samples.shape_0();
let dims = samples.shape_1();
assert!(c >= n);
let mut rand = StdRng::from_entropy();
let mut rng = rand::thread_rng();
let mut centroids = Vec2::zeros((c, dims));
centroids
.as_mut_slice()
.fill_with(|| S::from_f32(rand.gen_range(0.0..1.0f32)));
.fill_with(|| S::from_f32(rng.gen_range(0.0..1.0f32)));
for i in 0..n {
centroids[(i,)].copy_from_slice(&samples[(i,)]);
}
Expand Down
2 changes: 1 addition & 1 deletion crates/quantization/src/product/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl<O: OperatorProductQuantization> ProductQuantizer<O> {
)
.to_vec()
});
k_means(1 << bits, subsamples, false, false)
k_means(1 << bits, subsamples, false, false, true)
})
.collect::<Vec<_>>();
let mut centroids = Vec2::zeros((1 << bits, dims as usize));
Expand Down
2 changes: 1 addition & 1 deletion crates/rabitq/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ fn from_nothing<O: Op>(
rayon::check();
let samples = O::sample(collection, nlist);
rayon::check();
let centroids: Vec2<f32> = k_means(nlist as usize, samples, true, spherical_centroids);
let centroids: Vec2<f32> = k_means(nlist as usize, samples, true, spherical_centroids, false);
rayon::check();
let ls = (0..collection.len())
.into_par_iter()
Expand Down
5 changes: 0 additions & 5 deletions tests/sqllogictest/bvector.slt
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@ SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '[0,1,0,1,0,1,0,1,0,1]'::
----
10

query I
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '[0,1,0,1,0,1,0,1,0,1]'::bvector limit 10) t2;
----
10

statement ok
DROP TABLE t;

Expand Down

0 comments on commit f24bef2

Please sign in to comment.