Skip to content

Commit

Permalink
[experiment] feat: reranking in heap table
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi committed Dec 27, 2024
1 parent 3e23a14 commit 6c796f2
Show file tree
Hide file tree
Showing 9 changed files with 176 additions and 154 deletions.
3 changes: 3 additions & 0 deletions src/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,4 +366,7 @@ impl Relation {
)
}
}
pub fn raw(&self) -> pgrx::pg_sys::Relation {
self.raw
}
}
1 change: 0 additions & 1 deletion src/vchordrq/algorithm/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ pub fn build<V: Vector, T: HeapRelation<V>, R: Reporter>(
let mut chain = Err(metadata);
for i in (0..slices.len()).rev() {
chain = Ok(vectors.push(&VectorTuple {
payload: None,
slice: slices[i].to_vec(),
chain,
}));
Expand Down
28 changes: 2 additions & 26 deletions src/vchordrq/algorithm/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,35 +35,13 @@ pub fn insert<V: Vector>(
} else {
None
};
let h0_vector = {
let (metadata, slices) = V::vector_split(vector);
let mut chain = Err(metadata);
for i in (0..slices.len()).rev() {
let tuple = rkyv::to_bytes::<_, 8192>(&VectorTuple::<V> {
slice: slices[i].to_vec(),
payload: Some(payload.as_u64()),
chain,
})
.unwrap();
chain = Ok(append(
relation.clone(),
meta_tuple.vectors_first,
&tuple,
true,
true,
true,
));
}
chain.ok().unwrap()
};
let h0_payload = payload.as_u64();
let mut list = {
let Some((_, original)) = vectors::vector_dist::<V>(
let Some((_, original)) = vectors::vector_dist_by_mean::<V>(
relation.clone(),
vector,
meta_tuple.mean,
None,
None,
is_residual,
) else {
panic!("data corruption")
Expand Down Expand Up @@ -117,11 +95,10 @@ pub fn insert<V: Vector>(
{
while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) {
let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap();
let Some((Some(dis_u), original)) = vectors::vector_dist::<V>(
let Some((Some(dis_u), original)) = vectors::vector_dist_by_mean::<V>(
relation.clone(),
vector,
mean,
None,
Some(distance_kind),
is_residual,
) else {
Expand All @@ -145,7 +122,6 @@ pub fn insert<V: Vector>(
V::rabitq_code(dims, vector)
};
let tuple = rkyv::to_bytes::<_, 8192>(&Height0Tuple {
mean: h0_vector,
payload: h0_payload,
dis_u_2: code.dis_u_2,
factor_ppc: code.factor_ppc,
Expand Down
28 changes: 11 additions & 17 deletions src/vchordrq/algorithm/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ use std::collections::BinaryHeap;

pub fn scan<V: Vector>(
relation: impl RelationRead + Clone,
vector: V,
raw_vector: V,
distance_kind: DistanceKind,
probes: Vec<u32>,
epsilon: f32,
fetch_vector: impl Fn(u64) -> Option<V> + Copy + 'static,
) -> impl Iterator<Item = (Distance, Pointer)> {
let vector = vector.as_borrowed();
let vector = raw_vector.as_borrowed();
let meta_guard = relation.read(0);
let meta_tuple = meta_guard
.get(1)
Expand All @@ -36,12 +37,11 @@ pub fn scan<V: Vector>(
None
};
let mut lists: Vec<_> = vec![{
let Some((_, original)) = vectors::vector_dist::<V>(
let Some((_, original)) = vectors::vector_dist_by_mean::<V>(
relation.clone(),
vector.as_borrowed(),
meta_tuple.mean,
None,
None,
is_residual,
) else {
panic!("data corruption")
Expand Down Expand Up @@ -98,11 +98,10 @@ pub fn scan<V: Vector>(
std::iter::from_fn(|| {
while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) {
let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap();
let Some((Some(dis_u), original)) = vectors::vector_dist::<V>(
let Some((Some(dis_u), original)) = vectors::vector_dist_by_mean::<V>(
relation.clone(),
vector.as_borrowed(),
mean,
None,
Some(distance_kind),
is_residual,
) else {
Expand Down Expand Up @@ -156,11 +155,7 @@ pub fn scan<V: Vector>(
),
epsilon,
);
results.push((
Reverse(lowerbounds),
AlwaysEqual(h0_tuple.mean),
AlwaysEqual(h0_tuple.payload),
));
results.push((Reverse(lowerbounds), AlwaysEqual(h0_tuple.payload)));
}
current = h0_guard.get_opaque().next;
}
Expand All @@ -169,14 +164,13 @@ pub fn scan<V: Vector>(
let mut cache = BinaryHeap::<(Reverse<Distance>, _)>::new();
std::iter::from_fn(move || {
while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) {
let (_, AlwaysEqual(mean), AlwaysEqual(pay_u)) = heap.pop().unwrap();
let Some((Some(dis_u), _)) = vectors::vector_dist::<V>(
relation.clone(),
vector.as_borrowed(),
mean,
Some(pay_u),
let (_, AlwaysEqual(pay_u)) = heap.pop().unwrap();
let Some((Some(dis_u), _)) = vectors::vector_dist_by_fetch::<V>(
raw_vector.as_borrowed(),
pay_u,
Some(distance_kind),
false,
fetch_vector,
) else {
continue;
};
Expand Down
36 changes: 31 additions & 5 deletions src/vchordrq/algorithm/tuples.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::rabitq::{self, Code, Lut};
use crate::vchordrq::types::OwnedVector;
use base::distance::DistanceKind;
use base::distance::{Distance, DistanceKind};
use base::simd::ScalarLike;
use base::vector::{VectOwned, VectorOwned};
use base::vector::{VectOwned, VectorBorrowed, VectorOwned};
use half::f16;
use rkyv::{Archive, ArchiveUnsized, CheckBytes, Deserialize, Serialize};

Expand Down Expand Up @@ -51,6 +51,11 @@ pub trait Vector: VectorOwned {
left: Self::Metadata,
right: Self::Metadata,
) -> f32;
fn distance(
distance_kind: DistanceKind,
lhs: Self::Borrowed<'_>,
rhs: Self::Borrowed<'_>,
) -> Distance;

fn random_projection(vector: Self::Borrowed<'_>) -> Self;

Expand Down Expand Up @@ -120,6 +125,18 @@ impl Vector for VectOwned<f32> {
) -> f32 {
accumulator.1
}
fn distance(
distance_kind: DistanceKind,
lhs: Self::Borrowed<'_>,
rhs: Self::Borrowed<'_>,
) -> Distance {
match distance_kind {
DistanceKind::L2 => lhs.operator_l2(rhs),
DistanceKind::Dot => lhs.operator_dot(rhs),
DistanceKind::Hamming => unreachable!(),
DistanceKind::Jaccard => unreachable!(),
}
}

fn random_projection(vector: Self::Borrowed<'_>) -> Self {
Self::new(crate::projection::project(vector.slice()))
Expand Down Expand Up @@ -201,6 +218,18 @@ impl Vector for VectOwned<f16> {
) -> f32 {
accumulator.1
}
fn distance(
distance_kind: DistanceKind,
lhs: Self::Borrowed<'_>,
rhs: Self::Borrowed<'_>,
) -> Distance {
match distance_kind {
DistanceKind::L2 => lhs.operator_l2(rhs),
DistanceKind::Dot => lhs.operator_dot(rhs),
DistanceKind::Hamming => unreachable!(),
DistanceKind::Jaccard => unreachable!(),
}
}

fn random_projection(vector: Self::Borrowed<'_>) -> Self {
Self::new(f16::vector_from_f32(&crate::projection::project(
Expand Down Expand Up @@ -246,7 +275,6 @@ pub struct MetaTuple {
#[archive(check_bytes)]
pub struct VectorTuple<V: Vector> {
pub slice: Vec<V::Element>,
pub payload: Option<u64>,
pub chain: Result<(u32, u16), V::Metadata>,
}

Expand All @@ -268,8 +296,6 @@ pub struct Height1Tuple {
#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)]
#[archive(check_bytes)]
pub struct Height0Tuple {
// raw vector
pub mean: (u32, u16),
// for height 0 tuple, it's pointers to heap relation
pub payload: u64,
// RaBitQ algorithm
Expand Down
119 changes: 33 additions & 86 deletions src/vchordrq/algorithm/vacuum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,106 +7,53 @@ pub fn vacuum<V: Vector>(
delay: impl Fn(),
callback: impl Fn(Pointer) -> bool,
) {
// step 1: vacuum height_0_tuple
{
let meta_guard = relation.read(0);
let meta_tuple = meta_guard
.get(1)
.map(rkyv::check_archived_root::<MetaTuple>)
.expect("data corruption")
.expect("data corruption");
let mut firsts = vec![meta_tuple.first];
let make_firsts = |firsts| {
let mut results = Vec::new();
for first in firsts {
let mut current = first;
while current != u32::MAX {
let h1_guard = relation.read(current);
for i in 1..=h1_guard.len() {
let h1_tuple = h1_guard
.get(i)
.map(rkyv::check_archived_root::<Height1Tuple>)
.expect("data corruption")
.expect("data corruption");
results.push(h1_tuple.first);
}
current = h1_guard.get_opaque().next;
}
}
results
};
for _ in (1..meta_tuple.height_of_root).rev() {
firsts = make_firsts(firsts);
}
let meta_guard = relation.read(0);
let meta_tuple = meta_guard
.get(1)
.map(rkyv::check_archived_root::<MetaTuple>)
.expect("data corruption")
.expect("data corruption");
let mut firsts = vec![meta_tuple.first];
let make_firsts = |firsts| {
let mut results = Vec::new();
for first in firsts {
let mut current = first;
while current != u32::MAX {
delay();
let mut h0_guard = relation.write(current, false);
let mut reconstruct_removes = Vec::new();
for i in 1..=h0_guard.len() {
let h0_tuple = h0_guard
let h1_guard = relation.read(current);
for i in 1..=h1_guard.len() {
let h1_tuple = h1_guard
.get(i)
.map(rkyv::check_archived_root::<Height0Tuple>)
.map(rkyv::check_archived_root::<Height1Tuple>)
.expect("data corruption")
.expect("data corruption");
if callback(Pointer::new(h0_tuple.payload)) {
reconstruct_removes.push(i);
}
results.push(h1_tuple.first);
}
h0_guard.reconstruct(&reconstruct_removes);
current = h0_guard.get_opaque().next;
current = h1_guard.get_opaque().next;
}
}
results
};
for _ in (1..meta_tuple.height_of_root).rev() {
firsts = make_firsts(firsts);
}
// step 2: vacuum vector_tuple
{
let mut current = {
let meta_guard = relation.read(0);
let meta_tuple = meta_guard
.get(1)
.map(rkyv::check_archived_root::<MetaTuple>)
.expect("data corruption")
.expect("data corruption");
meta_tuple.vectors_first
};
for first in firsts {
let mut current = first;
while current != u32::MAX {
delay();
let read = relation.read(current);
let flag = 'flag: {
for i in 1..=read.len() {
let Some(vector_tuple) = read.get(i) else {
continue;
};
let vector_tuple =
unsafe { rkyv::archived_root::<VectorTuple<V>>(vector_tuple) };
if let Some(payload) = vector_tuple.payload.as_ref().copied() {
if callback(Pointer::new(payload)) {
break 'flag true;
}
}
}
false
};
if flag {
drop(read);
let mut write = relation.write(current, true);
for i in 1..=write.len() {
let Some(vector_tuple) = write.get(i) else {
continue;
};
let vector_tuple =
unsafe { rkyv::archived_root::<VectorTuple<V>>(vector_tuple) };
if let Some(payload) = vector_tuple.payload.as_ref().copied() {
if callback(Pointer::new(payload)) {
write.free(i);
}
}
let mut h0_guard = relation.write(current, false);
let mut reconstruct_removes = Vec::new();
for i in 1..=h0_guard.len() {
let h0_tuple = h0_guard
.get(i)
.map(rkyv::check_archived_root::<Height0Tuple>)
.expect("data corruption")
.expect("data corruption");
if callback(Pointer::new(h0_tuple.payload)) {
reconstruct_removes.push(i);
}
current = write.get_opaque().next;
} else {
current = read.get_opaque().next;
}
h0_guard.reconstruct(&reconstruct_removes);
current = h0_guard.get_opaque().next;
}
}
}
Loading

0 comments on commit 6c796f2

Please sign in to comment.