diff --git a/crates/indexing/src/sealed.rs b/crates/indexing/src/sealed.rs index 67c002e8..1e5f335f 100644 --- a/crates/indexing/src/sealed.rs +++ b/crates/indexing/src/sealed.rs @@ -110,7 +110,9 @@ impl SealedIndexing { Some(QuantizationOptions::Product(_)) => Self::HnswPq(Hnsw::open(path)), Some(QuantizationOptions::Rabitq(_)) => Self::HnswRq(Hnsw::open(path)), }, - IndexingOptions::InvertedIndex(_) => Self::InvertedIndex(SparseInvertedIndex::open(path)), + IndexingOptions::InvertedIndex(_) => { + Self::InvertedIndex(SparseInvertedIndex::open(path)) + } } } diff --git a/crates/quantization/src/scalar.rs b/crates/quantization/src/scalar.rs index cf17bacf..dd9c616c 100644 --- a/crates/quantization/src/scalar.rs +++ b/crates/quantization/src/scalar.rs @@ -18,21 +18,22 @@ use base::search::RerankerPop; use base::search::RerankerPush; use base::search::Vectors; use base::vector::*; -use common::vec2::Vec2; +use rayon::iter::IntoParallelIterator; +use rayon::iter::ParallelIterator; use serde::Deserialize; use serde::Serialize; use std::cmp::Reverse; use std::marker::PhantomData; use std::ops::Range; +use stoppable_rayon as rayon; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(bound = "")] pub struct ScalarQuantizer { dims: u32, bits: u32, - max: Vec, min: Vec, - centroids: Vec2, + max: Vec, _phantom: PhantomData O>, } @@ -40,8 +41,8 @@ impl Quantizer for ScalarQuantizer { fn train( vector_options: VectorOptions, options: Option, - vectors: &impl Vectors, - transform: impl Fn(Borrowed<'_, O>) -> O::Vector + Copy, + vectors: &(impl Vectors + Sync), + transform: impl Fn(Borrowed<'_, O>) -> O::Vector + Copy + Sync, ) -> Self { let options = if let Some(QuantizationOptions::Scalar(x)) = options { x @@ -50,32 +51,46 @@ impl Quantizer for ScalarQuantizer { }; let dims = vector_options.dims; let bits = options.bits; - let mut max = vec![f32::NEG_INFINITY; dims as usize]; - let mut min = vec![f32::INFINITY; dims as usize]; let n = vectors.len(); - for i in 0..n { - let vector = transform(vectors.vector(i)); - let vector = vector.as_borrowed(); - for j in 0..dims { - min[j as usize] = min[j as usize].min(O::get(vector, j).to_f32()); - max[j as usize] = max[j as usize].max(O::get(vector, j).to_f32()); - } - } - let mut centroids = Vec2::zeros((1 << bits, dims as usize)); - for i in 0..dims { - let bas = min[i as usize]; - let del = max[i as usize] - min[i as usize]; - for j in 0_usize..(1 << bits) { - let val = j as f32 / ((1 << bits) - 1) as f32; - centroids[(j, i as usize)] = bas + val * del; - } - } + let (min, max) = (0..n) + .into_par_iter() + .fold( + || { + ( + vec![f32::INFINITY; dims as usize], + vec![f32::NEG_INFINITY; dims as usize], + ) + }, + |(mut min, mut max), i| { + let vector = transform(vectors.vector(i)); + let vector = vector.as_borrowed(); + for j in 0..dims { + min[j as usize] = min[j as usize].min(O::get(vector, j).to_f32()); + max[j as usize] = max[j as usize].max(O::get(vector, j).to_f32()); + } + (min, max) + }, + ) + .reduce( + || { + ( + vec![f32::INFINITY; dims as usize], + vec![f32::NEG_INFINITY; dims as usize], + ) + }, + |(mut min, mut max), (rmin, rmax)| { + for j in 0..dims { + min[j as usize] = min[j as usize].min(rmin[j as usize]); + max[j as usize] = max[j as usize].max(rmax[j as usize]); + } + (min, max) + }, + ); Self { dims, bits, - max, min, - centroids, + max, _phantom: PhantomData, } } @@ -150,7 +165,7 @@ impl Quantizer for ScalarQuantizer { type Lut = Vec; fn preprocess(&self, vector: Borrowed<'_, O>) -> Self::Lut { - O::preprocess(self.dims, self.bits, &self.max, &self.min, vector) + O::preprocess(self.dims, self.bits, &self.min, &self.max, vector) } fn process(&self, lut: &Self::Lut, code: &[u8], _: Borrowed<'_, O>) -> Distance { @@ -165,7 +180,7 @@ impl Quantizer for ScalarQuantizer { ); fn fscan_preprocess(&self, vector: Borrowed<'_, O>) -> Self::FLut { - O::fscan_preprocess(self.dims, self.bits, &self.max, &self.min, vector) + O::fscan_preprocess(self.dims, self.bits, &self.min, &self.max, vector) } fn fscan_process(&self, flut: &Self::FLut, code: &[u8]) -> [Distance; 32] { @@ -291,8 +306,8 @@ pub trait OperatorScalarQuantization: Operator { fn preprocess( dims: u32, bits: u32, - max: &[f32], min: &[f32], + max: &[f32], vector: Borrowed<'_, Self>, ) -> Vec; fn process(dims: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance; @@ -300,8 +315,8 @@ pub trait OperatorScalarQuantization: Operator { fn fscan_preprocess( dims: u32, bits: u32, - max: &[f32], min: &[f32], + max: &[f32], vector: Borrowed<'_, Self>, ) -> (u32, f32, f32, Vec); fn fscan_process(flut: &(u32, f32, f32, Vec), code: &[u8]) -> [Distance; 32]; @@ -316,22 +331,46 @@ impl OperatorScalarQuantization for VectDot { fn preprocess( dims: u32, bits: u32, - max: &[f32], min: &[f32], + max: &[f32], vector: Borrowed<'_, Self>, ) -> Vec { - let mut xy = Vec::with_capacity(dims as _); - for i in 0..dims { - let bas = min[i as usize]; - let del = max[i as usize] - min[i as usize]; - xy.extend((0..1 << bits).map(|k| { - let x = vector.slice()[i as usize].to_f32(); - let val = k as f32 / ((1 << bits) - 1) as f32; - let y = bas + val * del; - x * y - })); + #[inline(never)] + fn internal( + dims: usize, + min: &[f32], + max: &[f32], + vector: &[S], + ) -> Vec { + assert!(dims <= 65535); + assert!(dims == min.len()); + assert!(dims == max.len()); + assert!(dims == vector.len()); + let mut table = Vec::::with_capacity(dims * (1 << BITS)); + for i in 0..dims { + let bas = min[i]; + let del = (max[i] - min[i]) / ((1 << BITS) - 1) as f32; + for j in 0..1 << BITS { + let x = vector[i].to_f32(); + let y = bas + (j as f32) * del; + let value = x * y; + unsafe { + table.as_mut_ptr().add(i * (1 << BITS) + j).write(value); + } + } + } + unsafe { + table.set_len(dims * (1 << BITS)); + } + table + } + match bits { + 1 => internal::<1, _>(dims as _, min, max, vector.slice()), + 2 => internal::<2, _>(dims as _, min, max, vector.slice()), + 4 => internal::<4, _>(dims as _, min, max, vector.slice()), + 8 => internal::<8, _>(dims as _, min, max, vector.slice()), + _ => unreachable!(), } - xy } fn process(dims: u32, bits: u32, lut: &[f32], rhs: &[u8]) -> Distance { fn internal(dims: u32, t: &[f32], f: impl Fn(usize) -> usize) -> Distance { @@ -359,11 +398,11 @@ impl OperatorScalarQuantization for VectDot { fn fscan_preprocess( dims: u32, bits: u32, - max: &[f32], min: &[f32], + max: &[f32], vector: Borrowed<'_, Self>, ) -> (u32, f32, f32, Vec) { - let (k, b, t) = quantize::<255>(&Self::preprocess(dims, bits, max, min, vector)); + let (k, b, t) = quantize::<255>(&Self::preprocess(dims, bits, min, max, vector)); (dims, k, b, t) } fn fscan_process(flut: &(u32, f32, f32, Vec), codes: &[u8]) -> [Distance; 32] { @@ -382,23 +421,46 @@ impl OperatorScalarQuantization for VectL2 { fn preprocess( dims: u32, bits: u32, - max: &[f32], min: &[f32], + max: &[f32], vector: Borrowed<'_, Self>, ) -> Vec { - let mut d2 = Vec::with_capacity(dims as _); - for i in 0..dims { - let bas = min[i as usize]; - let del = max[i as usize] - min[i as usize]; - d2.extend((0..1 << bits).map(|k| { - let x = vector.slice()[i as usize].to_f32(); - let val = k as f32 / ((1 << bits) - 1) as f32; - let y = bas + val * del; - let d = x - y; - d * d - })); + #[inline(never)] + fn internal( + dims: usize, + min: &[f32], + max: &[f32], + vector: &[S], + ) -> Vec { + assert!(dims <= 65535); + assert!(dims == min.len()); + assert!(dims == max.len()); + assert!(dims == vector.len()); + let mut table = Vec::::with_capacity(dims * (1 << BITS)); + for i in 0..dims { + let bas = min[i]; + let del = (max[i] - min[i]) / ((1 << BITS) - 1) as f32; + for j in 0..1 << BITS { + let x = vector[i].to_f32(); + let y = bas + (j as f32) * del; + let value = (x - y) * (x - y); + unsafe { + table.as_mut_ptr().add(i * (1 << BITS) + j).write(value); + } + } + } + unsafe { + table.set_len(dims * (1 << BITS)); + } + table + } + match bits { + 1 => internal::<1, _>(dims as _, min, max, vector.slice()), + 2 => internal::<2, _>(dims as _, min, max, vector.slice()), + 4 => internal::<4, _>(dims as _, min, max, vector.slice()), + 8 => internal::<8, _>(dims as _, min, max, vector.slice()), + _ => unreachable!(), } - d2 } fn process(dims: u32, bits: u32, lut: &[f32], rhs: &[u8]) -> Distance { fn internal(dims: u32, t: &[f32], f: impl Fn(usize) -> usize) -> Distance { @@ -426,11 +488,11 @@ impl OperatorScalarQuantization for VectL2 { fn fscan_preprocess( dims: u32, bits: u32, - max: &[f32], min: &[f32], + max: &[f32], vector: Borrowed<'_, Self>, ) -> (u32, f32, f32, Vec) { - let (k, b, t) = quantize::<255>(&Self::preprocess(dims, bits, max, min, vector)); + let (k, b, t) = quantize::<255>(&Self::preprocess(dims, bits, min, max, vector)); (dims, k, b, t) } fn fscan_process(flut: &(u32, f32, f32, Vec), codes: &[u8]) -> [Distance; 32] {