Skip to content

Commit

Permalink
refactor: introduce Distance type for comparing
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi committed Aug 22, 2024
1 parent 86bbed2 commit b221777
Show file tree
Hide file tree
Showing 48 changed files with 287 additions and 294 deletions.
67 changes: 67 additions & 0 deletions crates/base/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,70 @@ pub enum DistanceKind {
Hamming,
Jaccard,
}

#[derive(
Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash,
)]
#[repr(transparent)]
#[serde(transparent)]
pub struct Distance(i32);

impl Distance {
pub const ZERO: Self = Self(0);
pub const INFINITY: Self = Self(2139095040);
pub const NEG_INFINITY: Self = Self(-2139095041);

pub fn to_f32(self) -> f32 {
self.into()
}
}

impl From<f32> for Distance {
#[inline(always)]
fn from(value: f32) -> Self {
let bits = value.to_bits() as i32;
Self(bits ^ (((bits >> 31) as u32) >> 1) as i32)
}
}

impl From<Distance> for f32 {
#[inline(always)]
fn from(Distance(bits): Distance) -> Self {
f32::from_bits((bits ^ (((bits >> 31) as u32) >> 1) as i32) as u32)
}
}

#[test]
fn distance_conversions() {
assert_eq!(Distance::from(0.0f32), Distance::ZERO);
assert_eq!(Distance::from(f32::INFINITY), Distance::INFINITY);
assert_eq!(Distance::from(f32::NEG_INFINITY), Distance::NEG_INFINITY);
for i in -100..100 {
let val = (i as f32) * 0.1;
assert_eq!(f32::from(Distance::from(val)).to_bits(), val.to_bits());
}
assert_eq!(
f32::from(Distance::from(0.0f32)).to_bits(),
0.0f32.to_bits()
);
assert_eq!(
f32::from(Distance::from(-0.0f32)).to_bits(),
(-0.0f32).to_bits()
);
assert_eq!(
f32::from(Distance::from(f32::NAN)).to_bits(),
f32::NAN.to_bits()
);
assert_eq!(
f32::from(Distance::from(-f32::NAN)).to_bits(),
(-f32::NAN).to_bits()
);
assert_eq!(
f32::from(Distance::from(f32::INFINITY)).to_bits(),
f32::INFINITY.to_bits()
);
assert_eq!(
f32::from(Distance::from(-f32::INFINITY)).to_bits(),
(-f32::INFINITY).to_bits()
);
}
3 changes: 1 addition & 2 deletions crates/base/src/operator/bvector_dot.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distance::*;
use crate::operator::*;
use crate::scalar::*;
use crate::vector::*;

#[derive(Debug, Clone, Copy)]
Expand All @@ -11,7 +10,7 @@ impl Operator for BVectorDot {

const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;

fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> Distance {
lhs.operator_dot(rhs)
}
}
3 changes: 1 addition & 2 deletions crates/base/src/operator/bvector_hamming.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distance::*;
use crate::operator::*;
use crate::scalar::*;
use crate::vector::*;

#[derive(Debug, Clone, Copy)]
Expand All @@ -11,7 +10,7 @@ impl Operator for BVectorHamming {

const DISTANCE_KIND: DistanceKind = DistanceKind::Hamming;

fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> Distance {
lhs.operator_hamming(rhs)
}
}
3 changes: 1 addition & 2 deletions crates/base/src/operator/bvector_jaccard.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distance::*;
use crate::operator::*;
use crate::scalar::*;
use crate::vector::*;

#[derive(Debug, Clone, Copy)]
Expand All @@ -11,7 +10,7 @@ impl Operator for BVectorJaccard {

const DISTANCE_KIND: DistanceKind = DistanceKind::Jaccard;

fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> Distance {
lhs.operator_jaccard(rhs)
}
}
3 changes: 1 addition & 2 deletions crates/base/src/operator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@ pub use vecf32_dot::Vecf32Dot;
pub use vecf32_l2::Vecf32L2;

use crate::distance::*;
use crate::scalar::*;
use crate::vector::*;

pub trait Operator: Copy + 'static + Send + Sync {
type VectorOwned: VectorOwned;

const DISTANCE_KIND: DistanceKind;

fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32;
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> Distance;
}

pub type Owned<T> = <T as Operator>::VectorOwned;
Expand Down
3 changes: 1 addition & 2 deletions crates/base/src/operator/svecf32_dot.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distance::*;
use crate::operator::*;
use crate::scalar::*;
use crate::vector::*;

#[derive(Debug, Clone, Copy)]
Expand All @@ -11,7 +10,7 @@ impl Operator for SVecf32Dot {

const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;

fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> Distance {
lhs.operator_dot(rhs)
}
}
3 changes: 1 addition & 2 deletions crates/base/src/operator/svecf32_l2.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distance::*;
use crate::operator::*;
use crate::scalar::*;
use crate::vector::*;

#[derive(Debug, Clone, Copy)]
Expand All @@ -11,7 +10,7 @@ impl Operator for SVecf32L2 {

const DISTANCE_KIND: DistanceKind = DistanceKind::L2;

fn distance(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
fn distance(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> Distance {
lhs.operator_l2(rhs)
}
}
3 changes: 1 addition & 2 deletions crates/base/src/operator/vecf16_dot.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distance::*;
use crate::operator::*;
use crate::scalar::*;
use crate::vector::*;

#[derive(Debug, Clone, Copy)]
Expand All @@ -11,7 +10,7 @@ impl Operator for Vecf16Dot {

const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;

fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> F32 {
fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> Distance {
lhs.operator_dot(rhs)
}
}
3 changes: 1 addition & 2 deletions crates/base/src/operator/vecf16_l2.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distance::*;
use crate::operator::*;
use crate::scalar::*;
use crate::vector::*;

#[derive(Debug, Clone, Copy)]
Expand All @@ -11,7 +10,7 @@ impl Operator for Vecf16L2 {

const DISTANCE_KIND: DistanceKind = DistanceKind::L2;

fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> F32 {
fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> Distance {
lhs.operator_l2(rhs)
}
}
3 changes: 1 addition & 2 deletions crates/base/src/operator/vecf32_dot.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distance::*;
use crate::operator::*;
use crate::scalar::*;
use crate::vector::*;

#[derive(Debug, Clone, Copy)]
Expand All @@ -11,7 +10,7 @@ impl Operator for Vecf32Dot {

const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;

fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> F32 {
fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> Distance {
lhs.operator_dot(rhs)
}
}
3 changes: 1 addition & 2 deletions crates/base/src/operator/vecf32_l2.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distance::*;
use crate::operator::*;
use crate::scalar::*;
use crate::vector::*;

#[derive(Debug, Clone, Copy)]
Expand All @@ -11,7 +10,7 @@ impl Operator for Vecf32L2 {

const DISTANCE_KIND: DistanceKind = DistanceKind::L2;

fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> F32 {
fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> Distance {
lhs.operator_l2(rhs)
}
}
10 changes: 5 additions & 5 deletions crates/base/src/pod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// This module is a workaround for orphan rules

use crate::scalar::{F16, F32, I8};
use crate::distance::Distance;
use crate::scalar::{F16, F32};
use crate::search::Payload;

/// # Safety
///
Expand All @@ -26,13 +28,11 @@ unsafe impl Pod for isize {}
unsafe impl Pod for f32 {}
unsafe impl Pod for f64 {}

unsafe impl Pod for I8 {}
unsafe impl Pod for F16 {}
unsafe impl Pod for F32 {}

unsafe impl Pod for (F32, u32) {}

unsafe impl Pod for crate::search::Payload {}
unsafe impl Pod for Payload {}
unsafe impl Pod for Distance {}

pub fn bytes_of<T: Pod>(t: &T) -> &[u8] {
unsafe { core::slice::from_raw_parts(std::ptr::addr_of!(*t) as *const u8, size_of::<T>()) }
Expand Down
76 changes: 0 additions & 76 deletions crates/base/src/scalar/i8.rs

This file was deleted.

2 changes: 0 additions & 2 deletions crates/base/src/scalar/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
mod f32;
mod half_f16;
mod i8;

use std::iter::Sum;

pub use f32::F32;
pub use half_f16::F16;
pub use i8::I8;

pub trait ScalarLike:
Copy
Expand Down
8 changes: 4 additions & 4 deletions crates/base/src/search.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::always_equal::AlwaysEqual;
use crate::scalar::F32;
use crate::distance::Distance;
use crate::vector::VectorOwned;
use serde::{Deserialize, Serialize};
use std::any::Any;
Expand Down Expand Up @@ -69,7 +69,7 @@ impl Payload {

#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Element {
pub distance: F32,
pub distance: Distance,
pub payload: AlwaysEqual<Payload>,
}

Expand All @@ -90,7 +90,7 @@ pub trait Source {
}

pub trait RerankerPop<T> {
fn pop(&mut self) -> Option<(F32, u32, T)>;
fn pop(&mut self) -> Option<(Distance, u32, T)>;
}

pub trait RerankerPush {
Expand All @@ -100,7 +100,7 @@ pub trait RerankerPush {
pub trait FlatReranker<T>: RerankerPop<T> {}

impl<'a, T> RerankerPop<T> for Box<dyn FlatReranker<T> + 'a> {
fn pop(&mut self) -> Option<(F32, u32, T)> {
fn pop(&mut self) -> Option<(Distance, u32, T)> {
self.as_mut().pop()
}
}
Loading

0 comments on commit b221777

Please sign in to comment.