Skip to content

Commit

Permalink
fix: speed gen in preprocessing
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi committed Aug 28, 2024
1 parent dc14439 commit c29308c
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 42 deletions.
12 changes: 12 additions & 0 deletions crates/quantization/src/quantize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,15 @@ pub fn quantize<const N: u8>(lut: &[f32]) -> (f32, f32, Vec<u8>) {
pub fn dequantize(sum_1: u32, k: f32, b: f32, sum_x: u16) -> f32 {
(sum_1 as f32) * b + (sum_x as f32) * k
}

// FIXME: the result may not fit in an u16
// FIXME: generated code for AVX512 is bad, and that for AVX2 is not good, so rewrite it
#[detect::multiversion(v4, v3, v2, neon, fallback)]
pub fn reduce_sum_of_x(vector: &[u8]) -> u16 {
let n = vector.len();
let mut sum = 0;
for i in 0..n {
sum += vector[i] as u16;
}
sum
}
80 changes: 40 additions & 40 deletions crates/rabitq/src/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,53 +67,44 @@ impl OperatorRabitq for VectL2<f32> {
}

type Preprocessed0 = (f32, f32, f32, f32);
type Preprocessed1 = ((Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>), Vec<u8>);
type Preprocessed1 = Vec<u8>;

fn preprocess(
vector: &[f32],
) -> (
(f32, f32, f32, f32),
((Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>), Vec<u8>),
) {
fn preprocess(vector: &[f32]) -> ((f32, f32, f32, f32), Vec<u8>) {
use quantization::quantize;
let dis_v_2 = f32::reduce_sum_of_x2(vector);
let (k, b, qvector) = quantization::quantize::quantize::<15>(vector);
let qvector_sum = qvector.iter().fold(0_u32, |x, &y| x + y as u32) as f32;
let blut = binarize(&qvector);
let lut = gen(&qvector);
((dis_v_2, b, k, qvector_sum), (blut, lut))
let (k, b, qvector) = quantize::quantize::<15>(vector);
let qvector_sum = quantize::reduce_sum_of_x(&qvector) as f32;
let lut = gen(qvector);
((dis_v_2, b, k, qvector_sum), lut)
}

fn process(
dis_u_2: f32,
factor_ppc: f32,
factor_ip: f32,
factor_err: f32,
code: &[u8],
p0: &(f32, f32, f32, f32),
p1: &((Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>), Vec<u8>),
_dis_u_2: f32,
_factor_ppc: f32,
_factor_ip: f32,
_factor_err: f32,
_code: &[u8],
_p0: &(f32, f32, f32, f32),
_p1: &Vec<u8>,
) -> Distance {
let abdp = asymmetric_binary_dot_product(code, &p1.0) as u16;
let (rough, _) = rabitq_l2(dis_u_2, factor_ppc, factor_ip, factor_err, *p0, abdp);
Distance::from_f32(rough)
unimplemented!()
}

fn process_lowerbound(
dis_u_2: f32,
factor_ppc: f32,
factor_ip: f32,
factor_err: f32,
code: &[u8],
p0: &(f32, f32, f32, f32),
p1: &((Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>), Vec<u8>),
epsilon: f32,
_dis_u_2: f32,
_factor_ppc: f32,
_factor_ip: f32,
_factor_err: f32,
_code: &[u8],
_p0: &(f32, f32, f32, f32),
_p1: &Vec<u8>,
_epsilon: f32,
) -> Distance {
let abdp = asymmetric_binary_dot_product(code, &p1.0) as u16;
let (rough, err) = rabitq_l2(dis_u_2, factor_ppc, factor_ip, factor_err, *p0, abdp);
Distance::from_f32(rough - epsilon * err)
unimplemented!()
}

fn fscan_preprocess(preprocessed: &((Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>), Vec<u8>)) -> Vec<u8> {
preprocessed.1.clone()
fn fscan_preprocess(preprocessed: &Vec<u8>) -> Vec<u8> {
preprocessed.clone()
}

fn fscan_process_lowerbound(
Expand Down Expand Up @@ -227,6 +218,7 @@ pub fn rabitq_l2(
(rough, err)
}

#[allow(unused)]
fn binarize(vector: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>) {
let n = vector.len();
let t0 = {
Expand Down Expand Up @@ -260,15 +252,21 @@ fn binarize(vector: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>) {
(t0, t1, t2, t3)
}

fn gen(qvector: &[u8]) -> Vec<u8> {
fn gen(mut qvector: Vec<u8>) -> Vec<u8> {
let dims = qvector.len() as u32;
let t = dims.div_ceil(4);
qvector.resize(qvector.len().next_multiple_of(4), 0);
let mut lut = vec![0u8; t as usize * 16];
for i in 0..t as usize {
let t0 = qvector.get(4 * i + 0).copied().unwrap_or_default();
let t1 = qvector.get(4 * i + 1).copied().unwrap_or_default();
let t2 = qvector.get(4 * i + 2).copied().unwrap_or_default();
let t3 = qvector.get(4 * i + 3).copied().unwrap_or_default();
unsafe {
// this hint is used to skip bound checks
std::hint::assert_unchecked(4 * i + 3 < qvector.len());
std::hint::assert_unchecked(16 * i + 15 < lut.len());
}
let t0 = qvector[4 * i + 0];
let t1 = qvector[4 * i + 1];
let t2 = qvector[4 * i + 2];
let t3 = qvector[4 * i + 3];
lut[16 * i + 0b0000] = 0;
lut[16 * i + 0b0001] = t0;
lut[16 * i + 0b0010] = t1;
Expand All @@ -289,6 +287,7 @@ fn gen(qvector: &[u8]) -> Vec<u8> {
lut
}

#[allow(unused)]
fn binary_dot_product(x: &[u8], y: &[u8]) -> u32 {
assert_eq!(x.len(), y.len());
let n = x.len();
Expand All @@ -299,6 +298,7 @@ fn binary_dot_product(x: &[u8], y: &[u8]) -> u32 {
res
}

#[allow(unused)]
fn asymmetric_binary_dot_product(x: &[u8], y: &(Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>)) -> u32 {
let mut res = 0;
res += binary_dot_product(x, &y.0) << 0;
Expand Down
23 changes: 21 additions & 2 deletions crates/rabitq/src/quant/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,29 @@ impl<T, R> ErrorFlatReranker<T, R> {
where
R: Fn(u32) -> (Distance, T),
{
let mut top = BinaryHeap::new();
let mut bounary = Distance::INFINITY;
for (Reverse(low_u), AlwaysEqual(u)) in heap.into_iter() {
if low_u < bounary {
let (dis_u, pay_u) = (rerank)(u);
if dis_u < bounary {
top.push((dis_u, AlwaysEqual(u), AlwaysEqual(pay_u)));
if top.len() > 10 {
top.pop();
}
if top.len() == 10 {
bounary = top.peek().unwrap().0;
}
}
}
}
Self {
rerank,
heap: heap.into(),
cache: BinaryHeap::new(),
heap: BinaryHeap::new(),
cache: top
.into_iter()
.map(|(a, b, c)| (Reverse(a), b, c))
.collect(),
}
}
}
Expand Down

0 comments on commit c29308c

Please sign in to comment.