diff --git a/crates/base/src/distance.rs b/crates/base/src/distance.rs index ed9c8fd75..0dc7bba73 100644 --- a/crates/base/src/distance.rs +++ b/crates/base/src/distance.rs @@ -8,3 +8,70 @@ pub enum DistanceKind { Hamming, Jaccard, } + +#[derive( + Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash, +)] +#[repr(transparent)] +#[serde(transparent)] +pub struct Distance(i32); + +impl Distance { + pub const ZERO: Self = Self(0); + pub const INFINITY: Self = Self(2139095040); + pub const NEG_INFINITY: Self = Self(-2139095041); + + pub fn to_f32(self) -> f32 { + self.into() + } +} + +impl From for Distance { + #[inline(always)] + fn from(value: f32) -> Self { + let bits = value.to_bits() as i32; + Self(bits ^ (((bits >> 31) as u32) >> 1) as i32) + } +} + +impl From for f32 { + #[inline(always)] + fn from(Distance(bits): Distance) -> Self { + f32::from_bits((bits ^ (((bits >> 31) as u32) >> 1) as i32) as u32) + } +} + +#[test] +fn distance_conversions() { + assert_eq!(Distance::from(0.0f32), Distance::ZERO); + assert_eq!(Distance::from(f32::INFINITY), Distance::INFINITY); + assert_eq!(Distance::from(f32::NEG_INFINITY), Distance::NEG_INFINITY); + for i in -100..100 { + let val = (i as f32) * 0.1; + assert_eq!(f32::from(Distance::from(val)).to_bits(), val.to_bits()); + } + assert_eq!( + f32::from(Distance::from(0.0f32)).to_bits(), + 0.0f32.to_bits() + ); + assert_eq!( + f32::from(Distance::from(-0.0f32)).to_bits(), + (-0.0f32).to_bits() + ); + assert_eq!( + f32::from(Distance::from(f32::NAN)).to_bits(), + f32::NAN.to_bits() + ); + assert_eq!( + f32::from(Distance::from(-f32::NAN)).to_bits(), + (-f32::NAN).to_bits() + ); + assert_eq!( + f32::from(Distance::from(f32::INFINITY)).to_bits(), + f32::INFINITY.to_bits() + ); + assert_eq!( + f32::from(Distance::from(-f32::INFINITY)).to_bits(), + (-f32::INFINITY).to_bits() + ); +} diff --git a/crates/base/src/operator/bvector_dot.rs b/crates/base/src/operator/bvector_dot.rs index 12e07ce92..16571839c 100644 --- a/crates/base/src/operator/bvector_dot.rs +++ b/crates/base/src/operator/bvector_dot.rs @@ -1,6 +1,5 @@ use crate::distance::*; use crate::operator::*; -use crate::scalar::*; use crate::vector::*; #[derive(Debug, Clone, Copy)] @@ -11,7 +10,7 @@ impl Operator for BVectorDot { const DISTANCE_KIND: DistanceKind = DistanceKind::Dot; - fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 { + fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> Distance { lhs.operator_dot(rhs) } } diff --git a/crates/base/src/operator/bvector_hamming.rs b/crates/base/src/operator/bvector_hamming.rs index 3e2514bb2..fce912f3f 100644 --- a/crates/base/src/operator/bvector_hamming.rs +++ b/crates/base/src/operator/bvector_hamming.rs @@ -1,6 +1,5 @@ use crate::distance::*; use crate::operator::*; -use crate::scalar::*; use crate::vector::*; #[derive(Debug, Clone, Copy)] @@ -11,7 +10,7 @@ impl Operator for BVectorHamming { const DISTANCE_KIND: DistanceKind = DistanceKind::Hamming; - fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 { + fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> Distance { lhs.operator_hamming(rhs) } } diff --git a/crates/base/src/operator/bvector_jaccard.rs b/crates/base/src/operator/bvector_jaccard.rs index d2387eb73..71e445395 100644 --- a/crates/base/src/operator/bvector_jaccard.rs +++ b/crates/base/src/operator/bvector_jaccard.rs @@ -1,6 +1,5 @@ use crate::distance::*; use crate::operator::*; -use crate::scalar::*; use crate::vector::*; #[derive(Debug, Clone, Copy)] @@ -11,7 +10,7 @@ impl Operator for BVectorJaccard { const DISTANCE_KIND: DistanceKind = DistanceKind::Jaccard; - fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 { + fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> Distance { lhs.operator_jaccard(rhs) } } diff --git a/crates/base/src/operator/mod.rs b/crates/base/src/operator/mod.rs index b12bacd4e..6a5f954d1 100644 --- a/crates/base/src/operator/mod.rs +++ b/crates/base/src/operator/mod.rs @@ -19,7 +19,6 @@ pub use vecf32_dot::Vecf32Dot; pub use vecf32_l2::Vecf32L2; use crate::distance::*; -use crate::scalar::*; use crate::vector::*; pub trait Operator: Copy + 'static + Send + Sync { @@ -27,7 +26,7 @@ pub trait Operator: Copy + 'static + Send + Sync { const DISTANCE_KIND: DistanceKind; - fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32; + fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> Distance; } pub type Owned = ::VectorOwned; diff --git a/crates/base/src/operator/svecf32_dot.rs b/crates/base/src/operator/svecf32_dot.rs index 518cb02cc..c8d60fdad 100644 --- a/crates/base/src/operator/svecf32_dot.rs +++ b/crates/base/src/operator/svecf32_dot.rs @@ -1,6 +1,5 @@ use crate::distance::*; use crate::operator::*; -use crate::scalar::*; use crate::vector::*; #[derive(Debug, Clone, Copy)] @@ -11,7 +10,7 @@ impl Operator for SVecf32Dot { const DISTANCE_KIND: DistanceKind = DistanceKind::Dot; - fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 { + fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> Distance { lhs.operator_dot(rhs) } } diff --git a/crates/base/src/operator/svecf32_l2.rs b/crates/base/src/operator/svecf32_l2.rs index c40093a6a..1c4d6f9fb 100644 --- a/crates/base/src/operator/svecf32_l2.rs +++ b/crates/base/src/operator/svecf32_l2.rs @@ -1,6 +1,5 @@ use crate::distance::*; use crate::operator::*; -use crate::scalar::*; use crate::vector::*; #[derive(Debug, Clone, Copy)] @@ -11,7 +10,7 @@ impl Operator for SVecf32L2 { const DISTANCE_KIND: DistanceKind = DistanceKind::L2; - fn distance(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { + fn distance(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> Distance { lhs.operator_l2(rhs) } } diff --git a/crates/base/src/operator/vecf16_dot.rs b/crates/base/src/operator/vecf16_dot.rs index a82536d91..3446d6b1c 100644 --- a/crates/base/src/operator/vecf16_dot.rs +++ b/crates/base/src/operator/vecf16_dot.rs @@ -1,6 +1,5 @@ use crate::distance::*; use crate::operator::*; -use crate::scalar::*; use crate::vector::*; #[derive(Debug, Clone, Copy)] @@ -11,7 +10,7 @@ impl Operator for Vecf16Dot { const DISTANCE_KIND: DistanceKind = DistanceKind::Dot; - fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> F32 { + fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> Distance { lhs.operator_dot(rhs) } } diff --git a/crates/base/src/operator/vecf16_l2.rs b/crates/base/src/operator/vecf16_l2.rs index f691c7c71..a792b78be 100644 --- a/crates/base/src/operator/vecf16_l2.rs +++ b/crates/base/src/operator/vecf16_l2.rs @@ -1,6 +1,5 @@ use crate::distance::*; use crate::operator::*; -use crate::scalar::*; use crate::vector::*; #[derive(Debug, Clone, Copy)] @@ -11,7 +10,7 @@ impl Operator for Vecf16L2 { const DISTANCE_KIND: DistanceKind = DistanceKind::L2; - fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> F32 { + fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> Distance { lhs.operator_l2(rhs) } } diff --git a/crates/base/src/operator/vecf32_dot.rs b/crates/base/src/operator/vecf32_dot.rs index 1639b9fa7..86f2318f9 100644 --- a/crates/base/src/operator/vecf32_dot.rs +++ b/crates/base/src/operator/vecf32_dot.rs @@ -1,6 +1,5 @@ use crate::distance::*; use crate::operator::*; -use crate::scalar::*; use crate::vector::*; #[derive(Debug, Clone, Copy)] @@ -11,7 +10,7 @@ impl Operator for Vecf32Dot { const DISTANCE_KIND: DistanceKind = DistanceKind::Dot; - fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> F32 { + fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> Distance { lhs.operator_dot(rhs) } } diff --git a/crates/base/src/operator/vecf32_l2.rs b/crates/base/src/operator/vecf32_l2.rs index a73e62f25..ce7e5bbe3 100644 --- a/crates/base/src/operator/vecf32_l2.rs +++ b/crates/base/src/operator/vecf32_l2.rs @@ -1,6 +1,5 @@ use crate::distance::*; use crate::operator::*; -use crate::scalar::*; use crate::vector::*; #[derive(Debug, Clone, Copy)] @@ -11,7 +10,7 @@ impl Operator for Vecf32L2 { const DISTANCE_KIND: DistanceKind = DistanceKind::L2; - fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> F32 { + fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> Distance { lhs.operator_l2(rhs) } } diff --git a/crates/base/src/pod.rs b/crates/base/src/pod.rs index 8b9c58f0c..dbd14706f 100644 --- a/crates/base/src/pod.rs +++ b/crates/base/src/pod.rs @@ -1,6 +1,8 @@ // This module is a workaround for orphan rules -use crate::scalar::{F16, F32, I8}; +use crate::distance::Distance; +use crate::scalar::{F16, F32}; +use crate::search::Payload; /// # Safety /// @@ -26,13 +28,11 @@ unsafe impl Pod for isize {} unsafe impl Pod for f32 {} unsafe impl Pod for f64 {} -unsafe impl Pod for I8 {} unsafe impl Pod for F16 {} unsafe impl Pod for F32 {} -unsafe impl Pod for (F32, u32) {} - -unsafe impl Pod for crate::search::Payload {} +unsafe impl Pod for Payload {} +unsafe impl Pod for Distance {} pub fn bytes_of(t: &T) -> &[u8] { unsafe { core::slice::from_raw_parts(std::ptr::addr_of!(*t) as *const u8, size_of::()) } diff --git a/crates/base/src/scalar/i8.rs b/crates/base/src/scalar/i8.rs deleted file mode 100644 index 408d105f2..000000000 --- a/crates/base/src/scalar/i8.rs +++ /dev/null @@ -1,76 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::cmp::Ordering; -use std::fmt::{Debug, Display}; - -use super::F32; - -#[derive(Clone, Copy, Default, Serialize, Deserialize)] -#[repr(transparent)] -#[serde(transparent)] -pub struct I8(pub i8); - -impl Debug for I8 { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Debug::fmt(&self.0, f) - } -} - -impl Display for I8 { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Display::fmt(&self.0, f) - } -} - -impl PartialEq for I8 { - fn eq(&self, other: &Self) -> bool { - self.0.cmp(&other.0) == Ordering::Equal - } -} - -impl Eq for I8 {} - -impl PartialOrd for I8 { - #[inline(always)] - fn partial_cmp(&self, other: &Self) -> Option { - Some(Ord::cmp(self, other)) - } -} - -impl Ord for I8 { - #[inline(always)] - fn cmp(&self, other: &Self) -> Ordering { - self.0.cmp(&other.0) - } -} - -impl From for I8 { - fn from(value: i8) -> Self { - Self(value) - } -} - -impl From for i8 { - fn from(I8(value): I8) -> Self { - value - } -} - -impl From for I8 { - fn from(F32(value): F32) -> Self { - // Because F32 may be out of range of i8 [-128, 127], so we can't use to_int_unchecked here. - Self(value as i8) - } -} - -impl From for F32 { - fn from(val: I8) -> Self { - F32(val.0 as f32) - } -} - -impl I8 { - #[inline(always)] - pub fn to_f32(self) -> F32 { - F32(self.0 as f32) - } -} diff --git a/crates/base/src/scalar/mod.rs b/crates/base/src/scalar/mod.rs index 593b1a4a5..05da93f97 100644 --- a/crates/base/src/scalar/mod.rs +++ b/crates/base/src/scalar/mod.rs @@ -1,12 +1,10 @@ mod f32; mod half_f16; -mod i8; use std::iter::Sum; pub use f32::F32; pub use half_f16::F16; -pub use i8::I8; pub trait ScalarLike: Copy diff --git a/crates/base/src/search.rs b/crates/base/src/search.rs index 44ac963e5..316fd48d4 100644 --- a/crates/base/src/search.rs +++ b/crates/base/src/search.rs @@ -1,5 +1,5 @@ use crate::always_equal::AlwaysEqual; -use crate::scalar::F32; +use crate::distance::Distance; use crate::vector::VectorOwned; use serde::{Deserialize, Serialize}; use std::any::Any; @@ -69,7 +69,7 @@ impl Payload { #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct Element { - pub distance: F32, + pub distance: Distance, pub payload: AlwaysEqual, } @@ -90,7 +90,7 @@ pub trait Source { } pub trait RerankerPop { - fn pop(&mut self) -> Option<(F32, u32, T)>; + fn pop(&mut self) -> Option<(Distance, u32, T)>; } pub trait RerankerPush { @@ -100,7 +100,7 @@ pub trait RerankerPush { pub trait FlatReranker: RerankerPop {} impl<'a, T> RerankerPop for Box + 'a> { - fn pop(&mut self) -> Option<(F32, u32, T)> { + fn pop(&mut self) -> Option<(Distance, u32, T)> { self.as_mut().pop() } } diff --git a/crates/base/src/vector/bvector.rs b/crates/base/src/vector/bvector.rs index 38881c999..fed49d3c3 100644 --- a/crates/base/src/vector/bvector.rs +++ b/crates/base/src/vector/bvector.rs @@ -1,5 +1,6 @@ use std::ops::{Bound, RangeBounds}; +use crate::distance::Distance; use crate::scalar::F32; use crate::vector::{Vecf32Owned, VectorBorrowed, VectorKind, VectorOwned}; use num_traits::Float; @@ -151,28 +152,28 @@ impl<'a> VectorBorrowed for BVectorBorrowed<'a> { } #[inline(always)] - fn operator_dot(self, rhs: Self) -> F32 { - dot(self, rhs) * (-1.0) + fn operator_dot(self, rhs: Self) -> Distance { + Distance::from(-dot(self, rhs).0) } #[inline(always)] - fn operator_l2(self, _: Self) -> F32 { + fn operator_l2(self, _: Self) -> Distance { unimplemented!() } #[inline(always)] - fn operator_cos(self, _: Self) -> F32 { + fn operator_cos(self, _: Self) -> Distance { unimplemented!() } #[inline(always)] - fn operator_hamming(self, rhs: Self) -> F32 { - hamming(self, rhs) + fn operator_hamming(self, rhs: Self) -> Distance { + Distance::from(hamming(self, rhs).0) } #[inline(always)] - fn operator_jaccard(self, rhs: Self) -> F32 { - F32(1.0) - jaccard(self, rhs) + fn operator_jaccard(self, rhs: Self) -> Distance { + Distance::from(1.0 - jaccard(self, rhs).0) } #[inline(always)] diff --git a/crates/base/src/vector/mod.rs b/crates/base/src/vector/mod.rs index 628f6a1ae..37472cce8 100644 --- a/crates/base/src/vector/mod.rs +++ b/crates/base/src/vector/mod.rs @@ -8,6 +8,7 @@ pub use svecf32::{SVecf32Borrowed, SVecf32Owned}; pub use vecf16::{Vecf16Borrowed, Vecf16Owned}; pub use vecf32::{Vecf32Borrowed, Vecf32Owned}; +use crate::distance::Distance; use crate::scalar::ScalarLike; use crate::scalar::F32; use serde::{Deserialize, Serialize}; @@ -46,15 +47,15 @@ pub trait VectorBorrowed: Copy + PartialEq + PartialOrd { fn norm(&self) -> F32; - fn operator_dot(self, rhs: Self) -> F32; + fn operator_dot(self, rhs: Self) -> Distance; - fn operator_l2(self, rhs: Self) -> F32; + fn operator_l2(self, rhs: Self) -> Distance; - fn operator_cos(self, rhs: Self) -> F32; + fn operator_cos(self, rhs: Self) -> Distance; - fn operator_hamming(self, rhs: Self) -> F32; + fn operator_hamming(self, rhs: Self) -> Distance; - fn operator_jaccard(self, rhs: Self) -> F32; + fn operator_jaccard(self, rhs: Self) -> Distance; fn function_normalize(&self) -> Self::Owned; diff --git a/crates/base/src/vector/svecf32.rs b/crates/base/src/vector/svecf32.rs index cf11ccb3b..b59f3cdf6 100644 --- a/crates/base/src/vector/svecf32.rs +++ b/crates/base/src/vector/svecf32.rs @@ -1,3 +1,4 @@ +use crate::distance::Distance; use crate::scalar::F32; use crate::vector::{VectorBorrowed, VectorKind, VectorOwned}; use num_traits::{Float, Zero}; @@ -186,27 +187,27 @@ impl<'a> VectorBorrowed for SVecf32Borrowed<'a> { } #[inline(always)] - fn operator_dot(self, rhs: Self) -> F32 { - dot(self, rhs) * (-1.0) + fn operator_dot(self, rhs: Self) -> Distance { + Distance::from(-dot(self, rhs).0) } #[inline(always)] - fn operator_l2(self, rhs: Self) -> F32 { - sl2(self, rhs) + fn operator_l2(self, rhs: Self) -> Distance { + Distance::from(sl2(self, rhs).0) } #[inline(always)] - fn operator_cos(self, rhs: Self) -> F32 { - F32(1.0) - dot(self, rhs) / (self.norm() * rhs.norm()) + fn operator_cos(self, rhs: Self) -> Distance { + Distance::from(-(dot(self, rhs) / (self.norm() * rhs.norm())).0) } #[inline(always)] - fn operator_hamming(self, _: Self) -> F32 { + fn operator_hamming(self, _: Self) -> Distance { unimplemented!() } #[inline(always)] - fn operator_jaccard(self, _: Self) -> F32 { + fn operator_jaccard(self, _: Self) -> Distance { unimplemented!() } diff --git a/crates/base/src/vector/vecf16.rs b/crates/base/src/vector/vecf16.rs index 1a1a9741e..c6fb01e62 100644 --- a/crates/base/src/vector/vecf16.rs +++ b/crates/base/src/vector/vecf16.rs @@ -1,4 +1,5 @@ use super::{VectorBorrowed, VectorKind, VectorOwned}; +use crate::distance::Distance; use crate::scalar::{ScalarLike, F16, F32}; use num_traits::{Float, Zero}; use serde::{Deserialize, Serialize}; @@ -110,27 +111,27 @@ impl<'a> VectorBorrowed for Vecf16Borrowed<'a> { } #[inline(always)] - fn operator_dot(self, rhs: Self) -> F32 { - dot(self.slice(), rhs.slice()) * (-1.0) + fn operator_dot(self, rhs: Self) -> Distance { + Distance::from(-dot(self.slice(), rhs.slice()).0) } #[inline(always)] - fn operator_l2(self, rhs: Self) -> F32 { - sl2(self.slice(), rhs.slice()) + fn operator_l2(self, rhs: Self) -> Distance { + Distance::from(sl2(self.slice(), rhs.slice()).0) } #[inline(always)] - fn operator_cos(self, rhs: Self) -> F32 { - F32(1.0) - dot(self.slice(), rhs.slice()) / (self.norm() * rhs.norm()) + fn operator_cos(self, rhs: Self) -> Distance { + Distance::from(1.0 - (dot(self.slice(), rhs.slice()) / (self.norm() * rhs.norm())).0) } #[inline(always)] - fn operator_hamming(self, _: Self) -> F32 { + fn operator_hamming(self, _: Self) -> Distance { unimplemented!() } #[inline(always)] - fn operator_jaccard(self, _: Self) -> F32 { + fn operator_jaccard(self, _: Self) -> Distance { unimplemented!() } diff --git a/crates/base/src/vector/vecf32.rs b/crates/base/src/vector/vecf32.rs index 592ec73fe..0398d11fd 100644 --- a/crates/base/src/vector/vecf32.rs +++ b/crates/base/src/vector/vecf32.rs @@ -1,4 +1,5 @@ use super::{VectorBorrowed, VectorKind, VectorOwned}; +use crate::distance::Distance; use crate::scalar::F32; use num_traits::{Float, Zero}; use serde::{Deserialize, Serialize}; @@ -110,27 +111,27 @@ impl<'a> VectorBorrowed for Vecf32Borrowed<'a> { } #[inline(always)] - fn operator_dot(self, rhs: Self) -> F32 { - dot(self.slice(), rhs.slice()) * (-1.0) + fn operator_dot(self, rhs: Self) -> Distance { + Distance::from(-dot(self.slice(), rhs.slice()).0) } #[inline(always)] - fn operator_l2(self, rhs: Self) -> F32 { - sl2(self.slice(), rhs.slice()) + fn operator_l2(self, rhs: Self) -> Distance { + Distance::from(sl2(self.slice(), rhs.slice()).0) } #[inline(always)] - fn operator_cos(self, rhs: Self) -> F32 { - F32(1.0) - dot(self.slice(), rhs.slice()) / (self.norm() * rhs.norm()) + fn operator_cos(self, rhs: Self) -> Distance { + Distance::from(1.0 - (dot(self.slice(), rhs.slice()) / (self.norm() * rhs.norm())).0) } #[inline(always)] - fn operator_hamming(self, _: Self) -> F32 { + fn operator_hamming(self, _: Self) -> Distance { unimplemented!() } #[inline(always)] - fn operator_jaccard(self, _: Self) -> F32 { + fn operator_jaccard(self, _: Self) -> Distance { unimplemented!() } diff --git a/crates/base/src/worker.rs b/crates/base/src/worker.rs index 6e09fece1..7761c9c7d 100644 --- a/crates/base/src/worker.rs +++ b/crates/base/src/worker.rs @@ -1,5 +1,5 @@ +use crate::distance::Distance; use crate::index::*; -use crate::scalar::F32; use crate::search::*; use crate::vector::*; @@ -32,7 +32,7 @@ pub trait ViewVbaseOperations { &'a self, vector: &'a OwnedVector, opts: &'a SearchOptions, - ) -> Result + 'a>, VbaseError>; + ) -> Result + 'a>, VbaseError>; } pub trait ViewListOperations { diff --git a/crates/graph/src/prune.rs b/crates/graph/src/prune.rs index e5ff619c8..b0a96f8ed 100644 --- a/crates/graph/src/prune.rs +++ b/crates/graph/src/prune.rs @@ -1,10 +1,10 @@ -use base::scalar::F32; +use base::distance::Distance; pub fn prune( - dist: impl Fn(u32, u32) -> F32, + dist: impl Fn(u32, u32) -> Distance, u: u32, - edges: &mut Vec<(F32, u32)>, - add: &[(F32, u32)], + edges: &mut Vec<(Distance, u32)>, + add: &[(Distance, u32)], m: u32, ) { let mut trace = add.to_vec(); @@ -30,14 +30,13 @@ pub fn prune( } pub fn robust_prune( - dist: impl Fn(u32, u32) -> F32, + dist: impl Fn(u32, u32) -> Distance, u: u32, - edges: &mut Vec<(F32, u32)>, - add: &[(F32, u32)], + edges: &mut Vec<(Distance, u32)>, + add: &[(Distance, u32)], alpha: f32, m: u32, ) { - let alpha = F32(alpha); // V ← (V ∪ Nout(p)) \ {p} let mut trace = add.to_vec(); trace.extend(edges.as_slice()); @@ -54,7 +53,7 @@ pub fn robust_prune( let check = res .iter() .map(|&(_, v)| dist(u, v)) - .all(|dist| alpha * dist > dis_u); + .all(|dist| f32::from(dist) * alpha > f32::from(dis_u)); if check { res.push((dis_u, u)); } diff --git a/crates/graph/src/search.rs b/crates/graph/src/search.rs index 024cca752..7db1d9dbf 100644 --- a/crates/graph/src/search.rs +++ b/crates/graph/src/search.rs @@ -1,7 +1,7 @@ use crate::visited::VisitedGuard; use crate::visited::VisitedPool; use base::always_equal::AlwaysEqual; -use base::scalar::F32; +use base::distance::Distance; use base::search::RerankerPop; use base::search::RerankerPush; use std::cmp::Reverse; @@ -52,17 +52,17 @@ impl Results { } pub fn search( - dist: impl Fn(u32) -> F32, + dist: impl Fn(u32) -> Distance, read_outs: impl Fn(u32) -> E, visited: &mut VisitedGuard, s: u32, ef_construction: u32, -) -> Vec<(F32, u32)> +) -> Vec<(Distance, u32)> where E: Iterator, { let mut visited = visited.fetch_checker(); - let mut candidates = BinaryHeap::>::new(); + let mut candidates = BinaryHeap::>::new(); let mut results = Results::new(ef_construction as _); { let dis_s = dist(s); @@ -89,17 +89,17 @@ where } pub fn search_returning_trace( - dist: impl Fn(u32) -> F32, + dist: impl Fn(u32) -> Distance, read_outs: impl Fn(u32) -> E, visited: &mut VisitedGuard, s: u32, ef_construction: u32, -) -> (Vec<(F32, u32)>, Vec<(F32, u32)>) +) -> (Vec<(Distance, u32)>, Vec<(Distance, u32)>) where E: Iterator, { let mut visited = visited.fetch_checker(); - let mut reranker = BinaryHeap::>::new(); + let mut reranker = BinaryHeap::>::new(); let mut results = Results::new(ef_construction as _); let mut trace = Vec::new(); { @@ -131,7 +131,7 @@ pub fn vbase_internal<'a, G, E, T>( visited: &'a VisitedPool, s: u32, mut reranker: G, -) -> impl Iterator + 'a +) -> impl Iterator + 'a where G: RerankerPush + RerankerPop<(E, T)> + 'a, E: Iterator, @@ -160,7 +160,7 @@ pub fn vbase_generic<'a, G, E, T>( s: u32, reranker: G, ef_search: u32, -) -> impl Iterator + 'a +) -> impl Iterator + 'a where G: RerankerPush + RerankerPop<(E, T)> + 'a, E: Iterator, diff --git a/crates/hnsw/src/lib.rs b/crates/hnsw/src/lib.rs index 68b9411e2..12c9a62ce 100644 --- a/crates/hnsw/src/lib.rs +++ b/crates/hnsw/src/lib.rs @@ -2,16 +2,15 @@ #![allow(clippy::type_complexity)] use base::always_equal::AlwaysEqual; +use base::distance::Distance; use base::index::*; use base::operator::*; -use base::scalar::F32; use base::search::*; use base::vector::VectorBorrowed; use common::json::Json; use common::mmap_array::MmapArray; use common::remap::RemappedCollection; use graph::visited::VisitedPool; -use num_traits::Float; use parking_lot::RwLock; use quantization::operator::OperatorQuantization; use quantization::Quantization; @@ -34,9 +33,9 @@ pub struct Hnsw { quantization: Quantization, payloads: MmapArray, base_graph_outs: MmapArray, - base_graph_weights: MmapArray, + base_graph_weights: MmapArray, hyper_graph_outs: MmapArray, - hyper_graph_weights: MmapArray, + hyper_graph_weights: MmapArray, m: Json, s: Option, visited: VisitedPool, @@ -214,9 +213,9 @@ fn from_main( |u| remapped.skip(u), |u, level| { if level == 0 { - Box::new(base_edges(main, u)) as Box> + Box::new(base_edges(main, u)) as Box> } else { - Box::new(hyper_edges(main, u, level)) as Box> + Box::new(hyper_edges(main, u, level)) as Box> } }, remapped.len(), @@ -320,7 +319,7 @@ fn open(path: impl AsRef) -> Hnsw { } fn fast_search( - dist: impl Fn(u32) -> F32, + dist: impl Fn(u32) -> Distance, read_outs: impl Fn(u32, u8) -> E, levels: RangeInclusive, u: u32, @@ -346,7 +345,7 @@ where u } -fn fresh(n: u32, m: u32) -> Vec>>> { +fn fresh(n: u32, m: u32) -> Vec>>> { let mut g = Vec::with_capacity(n as usize); for u in 0..n { let l = hierarchy_for_a_vertex(m, u); @@ -358,14 +357,14 @@ fn fresh(n: u32, m: u32) -> Vec>>> { } fn patch_deletions( - dist: impl Fn(u32, u32) -> F32 + Copy + Sync, + dist: impl Fn(u32, u32) -> Distance + Copy + Sync, skip: impl Fn(u32) -> bool + Sync, read_edges: impl Fn(u32, u8) -> E + Sync, n: u32, m: u32, - g: &mut [Vec>>], + g: &mut [Vec>>], ) where - E: Iterator, + E: Iterator, { (0..n).into_par_iter().for_each(|u| { rayon::check(); @@ -391,12 +390,12 @@ fn patch_deletions( } fn patch_insertions( - dist: impl Fn(u32, u32) -> F32 + Copy + Sync, + dist: impl Fn(u32, u32) -> Distance + Copy + Sync, skip: impl Fn(u32) -> bool + Sync, n: u32, ef_construction: u32, m: u32, - g: &mut [Vec>>], + g: &mut [Vec>>], ) { #[repr(C)] #[derive(Debug, Clone, Copy)] @@ -561,13 +560,13 @@ fn patch_insertions( }); } -fn finish(g: &mut [Vec>>], m: u32) { +fn finish(g: &mut [Vec>>], m: u32) { for u in 0..g.len() as u32 { let l = hierarchy_for_a_vertex(m, u); for j in 0..l { g[u as usize][j as usize].get_mut().resize( capacity_for_a_hierarchy(m, j) as usize, - (F32::infinity(), u32::MAX), + (Distance::INFINITY, u32::MAX), ); } } @@ -591,7 +590,10 @@ fn capacity_for_a_hierarchy(m: u32, level: u8) -> u32 { } } -fn base_edges(hnsw: &Hnsw, u: u32) -> impl Iterator + '_ { +fn base_edges( + hnsw: &Hnsw, + u: u32, +) -> impl Iterator + '_ { let m = *hnsw.m; let offset = 2 * m as usize * u as usize; let edges_outs = hnsw.base_graph_outs[offset..offset + 2 * m as usize] @@ -617,7 +619,7 @@ fn hyper_edges( hnsw: &Hnsw, u: u32, level: u8, -) -> impl Iterator + '_ { +) -> impl Iterator + '_ { let m = *hnsw.m; let offset = { let mut offset = 0; diff --git a/crates/index/src/lib.rs b/crates/index/src/lib.rs index 739c78a8f..9a92b5b74 100644 --- a/crates/index/src/lib.rs +++ b/crates/index/src/lib.rs @@ -12,9 +12,9 @@ use self::segment::sealed::SealedSegment; use crate::optimizing::Optimizing; use crate::utils::tournament_tree::LoserTree; use arc_swap::ArcSwap; +use base::distance::Distance; use base::index::*; use base::operator::*; -use base::scalar::F32; use base::search::*; use base::vector::*; use common::clean::clean; @@ -388,7 +388,7 @@ impl IndexView { &'a self, vector: Borrowed<'a, O>, opts: &'a SearchOptions, - ) -> Result + 'a, VbaseError> { + ) -> Result + 'a, VbaseError> { if self.options.vector.dims != vector.dims() { return Err(VbaseError::InvalidVector); } diff --git a/crates/inverted/src/lib.rs b/crates/inverted/src/lib.rs index 9c9eff3bb..12086a724 100644 --- a/crates/inverted/src/lib.rs +++ b/crates/inverted/src/lib.rs @@ -4,6 +4,7 @@ pub mod operator; use self::operator::OperatorInvertedIndex; use base::always_equal::AlwaysEqual; +use base::distance::Distance; use base::index::{IndexOptions, SearchOptions}; use base::operator::{Borrowed, Owned}; use base::scalar::{ScalarLike, F32}; @@ -55,15 +56,15 @@ impl InvertedIndex { doc_score[self.indexes[i] as usize] += self.scores[i] * val; } } - let mut candidates: BinaryHeap<(F32, AlwaysEqual)> = doc_score + let mut candidates: BinaryHeap<(Distance, AlwaysEqual)> = doc_score .iter() .enumerate() - .map(|(i, score)| (*score, AlwaysEqual(i as u32))) + .map(|(i, score)| (Distance::from(-score.0), AlwaysEqual(i as u32))) .collect::>() .into(); Box::new(std::iter::from_fn(move || { - candidates.pop().map(|(score, AlwaysEqual(u))| Element { - distance: -score, + candidates.pop().map(|(distance, AlwaysEqual(u))| Element { + distance, payload: AlwaysEqual(self.payload(u)), }) })) diff --git a/crates/k_means/src/lloyd.rs b/crates/k_means/src/lloyd.rs index 82c9d4e68..bbccb906a 100644 --- a/crates/k_means/src/lloyd.rs +++ b/crates/k_means/src/lloyd.rs @@ -3,7 +3,8 @@ use common::vec2::Vec2; use num_traits::{Float, Zero}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -use stoppable_rayon::iter::{IntoParallelIterator, ParallelIterator}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use stoppable_rayon as rayon; pub struct LloydKMeans { dims: usize, diff --git a/crates/pyvectors/src/indexing.rs b/crates/pyvectors/src/indexing.rs index b1aaf9fab..dffbe286c 100644 --- a/crates/pyvectors/src/indexing.rs +++ b/crates/pyvectors/src/indexing.rs @@ -1,7 +1,6 @@ -use base::distance::DistanceKind; +use base::distance::{Distance, DistanceKind}; use base::index::{IndexOptions, SearchOptions}; use base::operator::*; -use base::scalar::F32; use base::search::{Collection, Element, Pointer, Source, Vectors}; use base::vector::*; use std::path::Path; @@ -91,7 +90,7 @@ impl Indexing { &'a self, vector: BorrowedVector<'a>, opts: &'a SearchOptions, - ) -> impl Iterator + 'a { + ) -> impl Iterator + 'a { match (self, vector) { (Self::Vecf32L2(x), BorrowedVector::Vecf32(vector)) => x.vbase(vector, opts), (Self::Vecf32Dot(x), BorrowedVector::Vecf32(vector)) => x.vbase(vector, opts), diff --git a/crates/pyvectors/src/lib.rs b/crates/pyvectors/src/lib.rs index de399440b..7a45003ef 100644 --- a/crates/pyvectors/src/lib.rs +++ b/crates/pyvectors/src/lib.rs @@ -105,7 +105,7 @@ impl Indexing { let (distances, labels) = self .0 .vbase(BorrowedVector::Vecf32(dataset.vector(i)), &search_options) - .map(|(distance, label)| (distance.0, label.as_u64() as i64)) + .map(|(distance, label)| (f32::from(distance), label.as_u64() as i64)) .chain(std::iter::repeat((f32::INFINITY, i64::MAX))) .take(k as usize) .unzip::<_, _, Vec<_>, Vec<_>>(); diff --git a/crates/quantization/src/lib.rs b/crates/quantization/src/lib.rs index 00b08e7ca..e4be3f006 100644 --- a/crates/quantization/src/lib.rs +++ b/crates/quantization/src/lib.rs @@ -18,6 +18,7 @@ use self::product::ProductQuantizer; use self::scalar::ScalarQuantizer; use crate::operator::OperatorQuantization; use base::always_equal::AlwaysEqual; +use base::distance::Distance; use base::index::*; use base::operator::*; use base::scalar::*; @@ -243,7 +244,7 @@ impl Quantization { vectors: &impl Vectors>, preprocessed: &QuantizationPreprocessed, u: u32, - ) -> F32 { + ) -> Distance { match (&*self.train, preprocessed) { (Quantizer::Trivial(x), QuantizationPreprocessed::Trivial(lhs)) => { let rhs = vectors.vector(u); @@ -271,7 +272,7 @@ impl Quantization { &self, preprocessed: &QuantizationPreprocessed, rhs: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, + heap: &mut Vec<(Reverse, AlwaysEqual)>, sq_fast_scan: bool, pq_fast_scan: bool, ) { @@ -301,8 +302,8 @@ impl Quantization { pub fn flat_rerank<'a, T: 'a>( &'a self, - heap: Vec<(Reverse, AlwaysEqual)>, - r: impl Fn(u32) -> (F32, T) + 'a, + heap: Vec<(Reverse, AlwaysEqual)>, + r: impl Fn(u32) -> (Distance, T) + 'a, sq_rerank_size: u32, pq_rerank_size: u32, ) -> Box + 'a> { @@ -314,7 +315,7 @@ impl Quantization { } } - pub fn graph_rerank<'a, T: 'a, R: Fn(u32) -> (F32, T) + 'a>( + pub fn graph_rerank<'a, T: 'a, R: Fn(u32) -> (Distance, T) + 'a>( &'a self, vector: Borrowed<'a, O>, r: R, diff --git a/crates/quantization/src/operator.rs b/crates/quantization/src/operator.rs index 387b19a13..590716634 100644 --- a/crates/quantization/src/operator.rs +++ b/crates/quantization/src/operator.rs @@ -1,6 +1,7 @@ use crate::product::operator::OperatorProductQuantization; use crate::scalar::operator::OperatorScalarQuantization; use crate::trivial::operator::OperatorTrivialQuantization; +use base::distance::Distance; use base::operator::*; use base::scalar::F32; use num_traits::Zero; @@ -14,11 +15,11 @@ pub trait OperatorQuantizationProcess: Operator { bits: u32, preprocessed: &Self::QuantizationPreprocessed, rhs: impl Fn(usize) -> usize, - ) -> F32; + ) -> Distance; const SUPPORT_FAST_SCAN: bool; fn fast_scan(preprocessed: &Self::QuantizationPreprocessed) -> Vec; - fn fast_scan_resolve(x: F32) -> F32; + fn fast_scan_resolve(x: F32) -> Distance; } macro_rules! unimpl_operator_quantization_process { @@ -32,7 +33,7 @@ macro_rules! unimpl_operator_quantization_process { _: u32, preprocessed: &Self::QuantizationPreprocessed, _: impl Fn(usize) -> usize, - ) -> F32 { + ) -> Distance { match *preprocessed {} } @@ -42,7 +43,7 @@ macro_rules! unimpl_operator_quantization_process { match *preprocessed {} } - fn fast_scan_resolve(_: F32) -> F32 { + fn fast_scan_resolve(_: F32) -> Distance { unimplemented!() } } @@ -58,7 +59,7 @@ impl OperatorQuantizationProcess for Vecf32Dot { bits: u32, preprocessed: &Self::QuantizationPreprocessed, rhs: impl Fn(usize) -> usize, - ) -> F32 { + ) -> Distance { let width = dims.div_ceil(ratio); let xy = { let mut xy = F32::zero(); @@ -67,7 +68,7 @@ impl OperatorQuantizationProcess for Vecf32Dot { } xy }; - F32(0.0) - xy + Distance::from((F32(0.0) - xy).0) } const SUPPORT_FAST_SCAN: bool = true; @@ -76,8 +77,8 @@ impl OperatorQuantizationProcess for Vecf32Dot { preprocessed.clone() } - fn fast_scan_resolve(x: F32) -> F32 { - x * F32(-1.0) + fn fast_scan_resolve(x: F32) -> Distance { + Distance::from(-x.0) } } @@ -90,13 +91,13 @@ impl OperatorQuantizationProcess for Vecf32L2 { bits: u32, preprocessed: &Self::QuantizationPreprocessed, rhs: impl Fn(usize) -> usize, - ) -> F32 { + ) -> Distance { let width = dims.div_ceil(ratio); let mut d2 = F32::zero(); for i in 0..width as usize { d2 += preprocessed[i * (1 << bits) + rhs(i)]; } - d2 + Distance::from(d2.0) } const SUPPORT_FAST_SCAN: bool = true; @@ -105,8 +106,8 @@ impl OperatorQuantizationProcess for Vecf32L2 { preprocessed.clone() } - fn fast_scan_resolve(x: F32) -> F32 { - x + fn fast_scan_resolve(x: F32) -> Distance { + Distance::from(x.0) } } @@ -119,7 +120,7 @@ impl OperatorQuantizationProcess for Vecf16Dot { bits: u32, preprocessed: &Self::QuantizationPreprocessed, rhs: impl Fn(usize) -> usize, - ) -> F32 { + ) -> Distance { let width = dims.div_ceil(ratio); let xy = { let mut xy = F32::zero(); @@ -128,7 +129,7 @@ impl OperatorQuantizationProcess for Vecf16Dot { } xy }; - F32(0.0) - xy + Distance::from(-xy.0) } const SUPPORT_FAST_SCAN: bool = true; @@ -137,8 +138,8 @@ impl OperatorQuantizationProcess for Vecf16Dot { preprocessed.clone() } - fn fast_scan_resolve(x: F32) -> F32 { - x * F32(-1.0) + fn fast_scan_resolve(x: F32) -> Distance { + Distance::from(-x.0) } } @@ -151,13 +152,13 @@ impl OperatorQuantizationProcess for Vecf16L2 { bits: u32, preprocessed: &Self::QuantizationPreprocessed, rhs: impl Fn(usize) -> usize, - ) -> F32 { + ) -> Distance { let width = dims.div_ceil(ratio); let mut d2 = F32::zero(); for i in 0..width as usize { d2 += preprocessed[i * (1 << bits) + rhs(i)]; } - d2 + Distance::from(d2.0) } const SUPPORT_FAST_SCAN: bool = true; @@ -166,8 +167,8 @@ impl OperatorQuantizationProcess for Vecf16L2 { preprocessed.clone() } - fn fast_scan_resolve(x: F32) -> F32 { - x + fn fast_scan_resolve(x: F32) -> Distance { + Distance::from(x.0) } } diff --git a/crates/quantization/src/product/mod.rs b/crates/quantization/src/product/mod.rs index 260c28cab..1dde4a165 100644 --- a/crates/quantization/src/product/mod.rs +++ b/crates/quantization/src/product/mod.rs @@ -4,9 +4,9 @@ use self::operator::OperatorProductQuantization; use crate::reranker::flat::WindowFlatReranker; use crate::reranker::graph::GraphReranker; use base::always_equal::AlwaysEqual; +use base::distance::Distance; use base::index::*; use base::operator::*; -use base::scalar::*; use base::search::*; use base::vector::VectorBorrowed; use common::sample::sample_subvector_transform; @@ -104,7 +104,7 @@ impl ProductQuantizer { ) } - pub fn process(&self, preprocessed: &O::QuantizationPreprocessed, rhs: &[u8]) -> F32 { + pub fn process(&self, preprocessed: &O::QuantizationPreprocessed, rhs: &[u8]) -> Distance { let dims = self.dims; let ratio = self.ratio; match self.bits { @@ -126,7 +126,7 @@ impl ProductQuantizer { &self, preprocessed: &O::QuantizationPreprocessed, rhs: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, + heap: &mut Vec<(Reverse, AlwaysEqual)>, codes: &[u8], packed_codes: &[u8], fast_scan: bool, @@ -193,16 +193,21 @@ impl ProductQuantizer { })); } - pub fn flat_rerank<'a, T: 'a, R: Fn(u32) -> (F32, T) + 'a>( + pub fn flat_rerank<'a, T: 'a, R: Fn(u32) -> (Distance, T) + 'a>( &'a self, - heap: Vec<(Reverse, AlwaysEqual)>, + heap: Vec<(Reverse, AlwaysEqual)>, r: R, rerank_size: u32, ) -> impl RerankerPop + 'a { WindowFlatReranker::new(heap, r, rerank_size) } - pub fn graph_rerank<'a, T: 'a, C: Fn(u32) -> &'a [u8] + 'a, R: Fn(u32) -> (F32, T) + 'a>( + pub fn graph_rerank< + 'a, + T: 'a, + C: Fn(u32) -> &'a [u8] + 'a, + R: Fn(u32) -> (Distance, T) + 'a, + >( &'a self, vector: Borrowed<'a, O>, c: C, diff --git a/crates/quantization/src/reranker/flat.rs b/crates/quantization/src/reranker/flat.rs index c7a2fff3f..202dee755 100644 --- a/crates/quantization/src/reranker/flat.rs +++ b/crates/quantization/src/reranker/flat.rs @@ -1,17 +1,17 @@ use base::always_equal::AlwaysEqual; -use base::scalar::F32; +use base::distance::Distance; use base::search::*; use std::cmp::Reverse; use std::collections::BinaryHeap; pub struct DisabledFlatReranker { - heap: BinaryHeap<(Reverse, AlwaysEqual, AlwaysEqual)>, + heap: BinaryHeap<(Reverse, AlwaysEqual, AlwaysEqual)>, } impl DisabledFlatReranker { - pub fn new(heap: Vec<(Reverse, AlwaysEqual)>, rerank: R) -> Self + pub fn new(heap: Vec<(Reverse, AlwaysEqual)>, rerank: R) -> Self where - R: Fn(u32) -> (F32, T), + R: Fn(u32) -> (Distance, T), { Self { heap: heap @@ -26,7 +26,7 @@ impl DisabledFlatReranker { } impl RerankerPop for DisabledFlatReranker { - fn pop(&mut self) -> Option<(F32, u32, T)> { + fn pop(&mut self) -> Option<(Distance, u32, T)> { let (Reverse(dis_u), AlwaysEqual(u), AlwaysEqual(pay_u)) = self.heap.pop()?; Some((dis_u, u, pay_u)) } @@ -35,15 +35,15 @@ impl RerankerPop for DisabledFlatReranker { pub struct WindowFlatReranker { rerank: R, size: u32, - heap: BinaryHeap<(Reverse, AlwaysEqual)>, - cache: BinaryHeap<(Reverse, AlwaysEqual, AlwaysEqual)>, + heap: BinaryHeap<(Reverse, AlwaysEqual)>, + cache: BinaryHeap<(Reverse, AlwaysEqual, AlwaysEqual)>, } impl WindowFlatReranker where - R: Fn(u32) -> (F32, T), + R: Fn(u32) -> (Distance, T), { - pub fn new(heap: Vec<(Reverse, AlwaysEqual)>, rerank: R, size: u32) -> Self { + pub fn new(heap: Vec<(Reverse, AlwaysEqual)>, rerank: R, size: u32) -> Self { Self { heap: heap.into(), rerank, @@ -55,9 +55,9 @@ where impl RerankerPop for WindowFlatReranker where - R: Fn(u32) -> (F32, T), + R: Fn(u32) -> (Distance, T), { - fn pop(&mut self) -> Option<(F32, u32, T)> { + fn pop(&mut self) -> Option<(Distance, u32, T)> { while !self.heap.is_empty() && self.cache.len() < self.size as _ { let (_, AlwaysEqual(u)) = self.heap.pop().unwrap(); let (dis_u, pay_u) = (self.rerank)(u); diff --git a/crates/quantization/src/reranker/graph.rs b/crates/quantization/src/reranker/graph.rs index 15df3717d..1a513e48a 100644 --- a/crates/quantization/src/reranker/graph.rs +++ b/crates/quantization/src/reranker/graph.rs @@ -1,18 +1,18 @@ use base::always_equal::AlwaysEqual; -use base::scalar::F32; +use base::distance::Distance; use base::search::*; use std::cmp::Reverse; use std::collections::BinaryHeap; pub struct GraphReranker<'a, T, R> { - compute: Option F32 + 'a>>, + compute: Option Distance + 'a>>, rerank: R, - heap: BinaryHeap<(Reverse, AlwaysEqual)>, - cache: BinaryHeap<(Reverse, AlwaysEqual, AlwaysEqual)>, + heap: BinaryHeap<(Reverse, AlwaysEqual)>, + cache: BinaryHeap<(Reverse, AlwaysEqual, AlwaysEqual)>, } impl<'a, T, R> GraphReranker<'a, T, R> { - pub fn new(compute: Option F32 + 'a>>, rerank: R) -> Self { + pub fn new(compute: Option Distance + 'a>>, rerank: R) -> Self { Self { compute, rerank, @@ -24,9 +24,9 @@ impl<'a, T, R> GraphReranker<'a, T, R> { impl<'a, T, R> RerankerPop for GraphReranker<'a, T, R> where - R: Fn(u32) -> (F32, T), + R: Fn(u32) -> (Distance, T), { - fn pop(&mut self) -> Option<(F32, u32, T)> { + fn pop(&mut self) -> Option<(Distance, u32, T)> { if self.compute.is_some() { let (_, AlwaysEqual(u)) = self.heap.pop()?; let (dis_u, pay_u) = (self.rerank)(u); @@ -40,7 +40,7 @@ where impl<'a, T, R> RerankerPush for GraphReranker<'a, T, R> where - R: Fn(u32) -> (F32, T), + R: Fn(u32) -> (Distance, T), { fn push(&mut self, u: u32) { if let Some(compute) = self.compute.as_ref() { diff --git a/crates/quantization/src/scalar/mod.rs b/crates/quantization/src/scalar/mod.rs index a1a625a29..a05bf6069 100644 --- a/crates/quantization/src/scalar/mod.rs +++ b/crates/quantization/src/scalar/mod.rs @@ -4,6 +4,7 @@ use self::operator::OperatorScalarQuantization; use crate::reranker::flat::WindowFlatReranker; use crate::reranker::graph::GraphReranker; use base::always_equal::AlwaysEqual; +use base::distance::Distance; use base::index::*; use base::operator::*; use base::scalar::*; @@ -97,7 +98,7 @@ impl ScalarQuantizer { O::scalar_quantization_preprocess(self.dims, self.bits, &self.max, &self.min, lhs) } - pub fn process(&self, preprocessed: &O::QuantizationPreprocessed, rhs: &[u8]) -> F32 { + pub fn process(&self, preprocessed: &O::QuantizationPreprocessed, rhs: &[u8]) -> Distance { let dims = self.dims; match self.bits { 1 => O::quantization_process(dims, 1, 1, preprocessed, |i| { @@ -118,7 +119,7 @@ impl ScalarQuantizer { &self, preprocessed: &O::QuantizationPreprocessed, rhs: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, + heap: &mut Vec<(Reverse, AlwaysEqual)>, codes: &[u8], packed_codes: &[u8], fast_scan: bool, @@ -184,16 +185,21 @@ impl ScalarQuantizer { })); } - pub fn flat_rerank<'a, T: 'a, R: Fn(u32) -> (F32, T) + 'a>( + pub fn flat_rerank<'a, T: 'a, R: Fn(u32) -> (Distance, T) + 'a>( &'a self, - heap: Vec<(Reverse, AlwaysEqual)>, + heap: Vec<(Reverse, AlwaysEqual)>, r: R, rerank_size: u32, ) -> impl RerankerPop + 'a { WindowFlatReranker::new(heap, r, rerank_size) } - pub fn graph_rerank<'a, T: 'a, C: Fn(u32) -> &'a [u8] + 'a, R: Fn(u32) -> (F32, T) + 'a>( + pub fn graph_rerank< + 'a, + T: 'a, + C: Fn(u32) -> &'a [u8] + 'a, + R: Fn(u32) -> (Distance, T) + 'a, + >( &'a self, vector: Borrowed<'a, O>, c: C, diff --git a/crates/quantization/src/trivial/mod.rs b/crates/quantization/src/trivial/mod.rs index ea13e5412..a3ee1fb7b 100644 --- a/crates/quantization/src/trivial/mod.rs +++ b/crates/quantization/src/trivial/mod.rs @@ -4,11 +4,10 @@ use self::operator::OperatorTrivialQuantization; use crate::reranker::flat::DisabledFlatReranker; use crate::reranker::graph::GraphReranker; use base::always_equal::AlwaysEqual; +use base::distance::Distance; use base::index::*; use base::operator::*; -use base::scalar::*; use base::search::*; -use num_traits::Zero; use serde::Deserialize; use serde::Serialize; use std::cmp::Reverse; @@ -43,7 +42,7 @@ impl TrivialQuantizer { &self, preprocessed: &O::TrivialQuantizationPreprocessed, rhs: Borrowed<'_, O>, - ) -> F32 { + ) -> Distance { O::trivial_quantization_process(preprocessed, rhs) } @@ -51,20 +50,20 @@ impl TrivialQuantizer { &self, _preprocessed: &O::TrivialQuantizationPreprocessed, rhs: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, + heap: &mut Vec<(Reverse, AlwaysEqual)>, ) { - heap.extend(rhs.map(|u| (Reverse(F32::zero()), AlwaysEqual(u)))); + heap.extend(rhs.map(|u| (Reverse(Distance::ZERO), AlwaysEqual(u)))); } pub fn flat_rerank<'a, T: 'a>( &'a self, - heap: Vec<(Reverse, AlwaysEqual)>, - r: impl Fn(u32) -> (F32, T) + 'a, + heap: Vec<(Reverse, AlwaysEqual)>, + r: impl Fn(u32) -> (Distance, T) + 'a, ) -> impl RerankerPop + 'a { DisabledFlatReranker::new(heap, r) } - pub fn graph_rerank<'a, T: 'a, R: Fn(u32) -> (F32, T) + 'a>( + pub fn graph_rerank<'a, T: 'a, R: Fn(u32) -> (Distance, T) + 'a>( &'a self, _: Borrowed<'a, O>, r: R, diff --git a/crates/quantization/src/trivial/operator.rs b/crates/quantization/src/trivial/operator.rs index ff7e12d85..6af5d5172 100644 --- a/crates/quantization/src/trivial/operator.rs +++ b/crates/quantization/src/trivial/operator.rs @@ -1,5 +1,5 @@ +use base::distance::Distance; use base::operator::*; -use base::scalar::*; use base::vector::VectorBorrowed; use base::vector::VectorOwned; @@ -13,7 +13,7 @@ pub trait OperatorTrivialQuantization: Operator { fn trivial_quantization_process( preprocessed: &Self::TrivialQuantizationPreprocessed, rhs: Borrowed<'_, Self>, - ) -> F32; + ) -> Distance; } impl OperatorTrivialQuantization for O { @@ -28,7 +28,7 @@ impl OperatorTrivialQuantization for O { fn trivial_quantization_process( preprocessed: &Self::TrivialQuantizationPreprocessed, rhs: Borrowed<'_, Self>, - ) -> F32 { + ) -> Distance { O::distance(preprocessed.as_borrowed(), rhs) } } diff --git a/crates/rabitq/src/lib.rs b/crates/rabitq/src/lib.rs index d8d2fc6a6..081c74d1f 100644 --- a/crates/rabitq/src/lib.rs +++ b/crates/rabitq/src/lib.rs @@ -13,8 +13,7 @@ use base::always_equal::AlwaysEqual; use base::index::{IndexOptions, RabitqIndexingOptions, SearchOptions}; use base::operator::{Borrowed, Owned}; use base::scalar::F32; -use base::search::RerankerPop; -use base::search::{Collection, Element, Payload, Source, Vectors}; +use base::search::{Collection, Element, Payload, RerankerPop, Source, Vectors}; use common::json::Json; use common::mmap_array::MmapArray; use common::remap::RemappedCollection; diff --git a/crates/rabitq/src/quant/error_based.rs b/crates/rabitq/src/quant/error_based.rs index 3ab1e6019..551c8b2f4 100644 --- a/crates/rabitq/src/quant/error_based.rs +++ b/crates/rabitq/src/quant/error_based.rs @@ -1,7 +1,6 @@ use base::always_equal::AlwaysEqual; -use base::scalar::F32; +use base::distance::Distance; use base::search::RerankerPop; -use num_traits::Float; use std::cmp::Reverse; use std::collections::BinaryHeap; @@ -9,18 +8,18 @@ const WINDOW_SIZE: usize = 16; pub struct ErrorBasedReranker { rerank: R, - cache: BinaryHeap<(Reverse, AlwaysEqual, AlwaysEqual)>, - distance_threshold: F32, - heap: Vec<(Reverse, AlwaysEqual)>, + cache: BinaryHeap<(Reverse, AlwaysEqual, AlwaysEqual)>, + distance_threshold: Distance, + heap: Vec<(Reverse, AlwaysEqual)>, ranked: bool, } impl ErrorBasedReranker { - pub fn new(heap: Vec<(Reverse, AlwaysEqual)>, rerank: R) -> Self { + pub fn new(heap: Vec<(Reverse, AlwaysEqual)>, rerank: R) -> Self { Self { rerank, cache: BinaryHeap::new(), - distance_threshold: F32::infinity(), + distance_threshold: Distance::INFINITY, heap, ranked: false, } @@ -29,12 +28,12 @@ impl ErrorBasedReranker { impl RerankerPop for ErrorBasedReranker where - R: Fn(u32) -> (F32, T), + R: Fn(u32) -> (Distance, T), { - fn pop(&mut self) -> Option<(F32, u32, T)> { + fn pop(&mut self) -> Option<(Distance, u32, T)> { if !self.ranked { self.ranked = true; - let mut recent_max_accurate = F32::neg_infinity(); + let mut recent_max_accurate = Distance::NEG_INFINITY; let mut count = 0; for &(Reverse(lowerbound), AlwaysEqual(u)) in self.heap.iter() { if lowerbound < self.distance_threshold { @@ -47,7 +46,7 @@ where if count == WINDOW_SIZE { self.distance_threshold = recent_max_accurate; count = 0; - recent_max_accurate = F32::neg_infinity(); + recent_max_accurate = Distance::NEG_INFINITY; } } } diff --git a/crates/rabitq/src/quant/quantization.rs b/crates/rabitq/src/quant/quantization.rs index a83ec5abd..1bef4cb06 100644 --- a/crates/rabitq/src/quant/quantization.rs +++ b/crates/rabitq/src/quant/quantization.rs @@ -1,6 +1,7 @@ use super::quantizer::RabitqQuantizer; use crate::operator::OperatorRabitq; use base::always_equal::AlwaysEqual; +use base::distance::Distance; use base::index::VectorOptions; use base::scalar::F32; use base::search::RerankerPop; @@ -141,7 +142,7 @@ impl Quantization { } } - pub fn process(&self, preprocessed: &QuantizationPreprocessed, u: u32) -> F32 { + pub fn process(&self, preprocessed: &QuantizationPreprocessed, u: u32) -> Distance { match (&*self.train, preprocessed) { (Quantizer::Rabitq(x), QuantizationPreprocessed::Rabitq(lhs)) => { let bytes = x.bytes() as usize; @@ -161,7 +162,7 @@ impl Quantization { &self, preprocessed: &QuantizationPreprocessed, rhs: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, + heap: &mut Vec<(Reverse, AlwaysEqual)>, rq_epsilon: F32, rq_fast_scan: bool, ) { @@ -181,8 +182,8 @@ impl Quantization { pub fn rerank<'a, T: 'a>( &'a self, - heap: Vec<(Reverse, AlwaysEqual)>, - r: impl Fn(u32) -> (F32, T) + 'a, + heap: Vec<(Reverse, AlwaysEqual)>, + r: impl Fn(u32) -> (Distance, T) + 'a, ) -> impl RerankerPop + 'a { use Quantizer::*; match &*self.train { diff --git a/crates/rabitq/src/quant/quantizer.rs b/crates/rabitq/src/quant/quantizer.rs index ba329f925..95bef2b92 100644 --- a/crates/rabitq/src/quant/quantizer.rs +++ b/crates/rabitq/src/quant/quantizer.rs @@ -1,6 +1,7 @@ use super::error_based::ErrorBasedReranker; use crate::operator::OperatorRabitq; use base::always_equal::AlwaysEqual; +use base::distance::Distance; use base::index::VectorOptions; use base::scalar::F32; use base::search::RerankerPop; @@ -76,9 +77,9 @@ impl RabitqQuantizer { p0: &O::QuantizationPreprocessed0, p1: &O::QuantizationPreprocessed1, (a, b, c, d, e): (F32, F32, F32, F32, &[u8]), - ) -> F32 { + ) -> Distance { let (est, _) = O::rabitq_quantization_process(a, b, c, d, e, p0, p1); - est + Distance::from(est.0) } pub fn process_lowerbound( @@ -87,16 +88,16 @@ impl RabitqQuantizer { p1: &O::QuantizationPreprocessed1, (a, b, c, d, e): (F32, F32, F32, F32, &[u8]), epsilon: F32, - ) -> F32 { + ) -> Distance { let (est, err) = O::rabitq_quantization_process(a, b, c, d, e, p0, p1); - est - err * epsilon + Distance::from((est - err * epsilon).0) } pub fn push_batch( &self, (p0, p1): &(O::QuantizationPreprocessed0, O::QuantizationPreprocessed1), rhs: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, + heap: &mut Vec<(Reverse, AlwaysEqual)>, codes: &[u8], packed_codes: &[u8], meta: &[F32], @@ -146,7 +147,7 @@ impl RabitqQuantizer { let param = res[(u - i) as usize]; let (est, err) = O::rabitq_quantization_process_1(a, b, c, d, p0, param); - est - err * epsilon + Distance::from((est - err * epsilon).0) }), AlwaysEqual(u), ) @@ -200,8 +201,8 @@ impl RabitqQuantizer { pub fn rerank<'a, T: 'a>( &'a self, - heap: Vec<(Reverse, AlwaysEqual)>, - r: impl Fn(u32) -> (F32, T) + 'a, + heap: Vec<(Reverse, AlwaysEqual)>, + r: impl Fn(u32) -> (Distance, T) + 'a, ) -> impl RerankerPop + 'a { ErrorBasedReranker::new(heap, r) } diff --git a/crates/service/src/instance.rs b/crates/service/src/instance.rs index 208514d64..8c060a6b1 100644 --- a/crates/service/src/instance.rs +++ b/crates/service/src/instance.rs @@ -1,7 +1,6 @@ use base::distance::*; use base::index::*; use base::operator::*; -use base::scalar::F32; use base::search::*; use base::vector::*; use base::worker::*; @@ -213,7 +212,7 @@ impl ViewVbaseOperations for InstanceView { &'a self, vector: &'a OwnedVector, opts: &'a SearchOptions, - ) -> Result + 'a>, VbaseError> { + ) -> Result + 'a>, VbaseError> { match (self, vector) { (InstanceView::Vecf32Dot(x), OwnedVector::Vecf32(vector)) => { Ok(Box::new(x.vbase(vector.as_borrowed(), opts)?)) diff --git a/src/datatype/operators_bvector.rs b/src/datatype/operators_bvector.rs index abc5f45f0..b11f60e27 100644 --- a/src/datatype/operators_bvector.rs +++ b/src/datatype/operators_bvector.rs @@ -1,6 +1,5 @@ use crate::datatype::memory_bvector::{BVectorInput, BVectorOutput}; use crate::error::*; -use base::scalar::*; use base::vector::*; use std::num::NonZero; diff --git a/src/datatype/operators_svecf32.rs b/src/datatype/operators_svecf32.rs index 0bed3ef41..e5401eee5 100644 --- a/src/datatype/operators_svecf32.rs +++ b/src/datatype/operators_svecf32.rs @@ -1,6 +1,5 @@ use crate::datatype::memory_svecf32::{SVecf32Input, SVecf32Output}; use crate::error::*; -use base::scalar::*; use base::vector::*; use std::num::NonZero; diff --git a/src/datatype/operators_vecf16.rs b/src/datatype/operators_vecf16.rs index 77e2ea167..f405231c0 100644 --- a/src/datatype/operators_vecf16.rs +++ b/src/datatype/operators_vecf16.rs @@ -1,6 +1,5 @@ use crate::datatype::memory_vecf16::{Vecf16Input, Vecf16Output}; use crate::error::*; -use base::scalar::*; use base::vector::*; use std::num::NonZero; diff --git a/src/datatype/operators_vecf32.rs b/src/datatype/operators_vecf32.rs index b687c57c7..8b4cee48a 100644 --- a/src/datatype/operators_vecf32.rs +++ b/src/datatype/operators_vecf32.rs @@ -1,6 +1,5 @@ use crate::datatype::memory_vecf32::{Vecf32Input, Vecf32Output}; use crate::error::*; -use base::scalar::*; use base::vector::*; use std::num::NonZero; diff --git a/src/index/am_options.rs b/src/index/am_options.rs index d4c32b569..68452046b 100644 --- a/src/index/am_options.rs +++ b/src/index/am_options.rs @@ -242,10 +242,10 @@ impl Opfamily { (B::BVector(x), _) => O::BVector(x.own()), } } - pub fn process(self, x: F32) -> F32 { + pub fn process(self, x: Distance) -> F32 { match self.pg_distance { - PgDistanceKind::Cos => x + F32(1.0), - _ => x, + PgDistanceKind::Cos => F32(f32::from(x)) + F32(1.0), + _ => F32(f32::from(x)), } } } diff --git a/src/ipc/mod.rs b/src/ipc/mod.rs index 3d8d2d11a..49c454d72 100644 --- a/src/ipc/mod.rs +++ b/src/ipc/mod.rs @@ -6,8 +6,8 @@ use crate::error::*; use crate::gucs::internal::{Transport, TRANSPORT}; use crate::ipc::transport::Packet; use crate::utils::cells::PgRefCell; +use base::distance::Distance; use base::index::*; -use base::scalar::F32; use base::search::*; use base::vector::*; use serde::{Deserialize, Serialize}; @@ -326,7 +326,7 @@ defines! { unary flush(handle: Handle) -> (); unary insert(handle: Handle, vector: OwnedVector, pointer: Pointer) -> (); unary delete(handle: Handle, pointer: Pointer) -> (); - stream vbase(handle: Handle, vector: OwnedVector, opts: SearchOptions) -> (F32, Pointer); + stream vbase(handle: Handle, vector: OwnedVector, opts: SearchOptions) -> (Distance, Pointer); stream list(handle: Handle) -> Pointer; unary stat(handle: Handle) -> IndexStat; unary alter(handle: Handle, key: String, value: String) -> ();