Skip to content

Commit

Permalink
add always_equal wrap to heap rerank for item id
Browse files Browse the repository at this point in the history
Signed-off-by: Keming <[email protected]>
  • Loading branch information
kemingy committed Sep 25, 2024
1 parent ea13ebc commit 86f0532
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
24 changes: 24 additions & 0 deletions src/ord32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,27 @@ impl From<Ord32> for f32 {
x.to_f32()
}
}

#[derive(Debug, Clone, Copy, Default)]
#[repr(transparent)]
pub struct AlwaysEqual<T>(pub T);

impl<T> PartialEq for AlwaysEqual<T> {
fn eq(&self, _: &Self) -> bool {
true
}
}

impl<T> Eq for AlwaysEqual<T> {}

impl<T> PartialOrd for AlwaysEqual<T> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl<T> Ord for AlwaysEqual<T> {
fn cmp(&self, _: &Self) -> std::cmp::Ordering {
std::cmp::Ordering::Equal
}
}
12 changes: 8 additions & 4 deletions src/rerank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use faer::{Col, ColRef, MatRef};

use crate::consts::WINDOW_SIZE;
use crate::metrics::METRICS;
use crate::ord32::Ord32;
use crate::ord32::{AlwaysEqual, Ord32};
use crate::utils::l2_squared_distance;

pub enum ReRanker {
Expand Down Expand Up @@ -52,7 +52,7 @@ pub trait ReRankerTrait {
pub struct HeapReRanker {
threshold: f32,
topk: usize,
heap: BinaryHeap<(Ord32, u32)>,
heap: BinaryHeap<(Ord32, AlwaysEqual<u32>)>,
query: Col<f32>,
}

Expand All @@ -75,7 +75,8 @@ impl ReRankerTrait for HeapReRanker {
let accurate = l2_squared_distance(&base.col(u as usize), &self.query.as_ref());
precise += 1;
if accurate < self.threshold {
self.heap.push((accurate.into(), map_ids[u as usize]));
self.heap
.push((accurate.into(), AlwaysEqual(map_ids[u as usize])));
if self.heap.len() > self.topk {
self.heap.pop();
}
Expand All @@ -90,7 +91,10 @@ impl ReRankerTrait for HeapReRanker {
}

fn get_result(&self) -> Vec<(f32, u32)> {
self.heap.iter().map(|&(a, b)| (a.into(), b)).collect()
self.heap
.iter()
.map(|&(a, AlwaysEqual(b))| (a.into(), b))
.collect()
}
}

Expand Down

0 comments on commit 86f0532

Please sign in to comment.