diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index c00aab1..da02fed 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -156,6 +156,9 @@ CREATE OPERATOR FAMILY vector_cosine_ops USING vchordrq; CREATE OPERATOR FAMILY halfvec_l2_ops USING vchordrq; CREATE OPERATOR FAMILY halfvec_ip_ops USING vchordrq; CREATE OPERATOR FAMILY halfvec_cosine_ops USING vchordrq; +CREATE OPERATOR FAMILY scalar8_l2_ops USING vchordrq; +CREATE OPERATOR FAMILY scalar8_ip_ops USING vchordrq; +CREATE OPERATOR FAMILY scalar8_cosine_ops USING vchordrq; CREATE OPERATOR FAMILY vector_l2_ops USING Vchordrqfscan; CREATE OPERATOR FAMILY vector_ip_ops USING Vchordrqfscan; @@ -199,6 +202,24 @@ CREATE OPERATOR CLASS halfvec_cosine_ops OPERATOR 2 <<=>> (halfvec, sphere_halfvec) FOR SEARCH, FUNCTION 1 _vchordrq_support_halfvec_cosine_ops(); +CREATE OPERATOR CLASS scalar8_l2_ops + FOR TYPE scalar8 USING vchordrq FAMILY scalar8_l2_ops AS + OPERATOR 1 <-> (scalar8, scalar8) FOR ORDER BY float_ops, + OPERATOR 2 <<->> (scalar8, sphere_scalar8) FOR SEARCH, + FUNCTION 1 _vchordrq_support_scalar8_l2_ops(); + +CREATE OPERATOR CLASS scalar8_ip_ops + FOR TYPE scalar8 USING vchordrq FAMILY scalar8_ip_ops AS + OPERATOR 1 <#> (scalar8, scalar8) FOR ORDER BY float_ops, + OPERATOR 2 <<#>> (scalar8, sphere_scalar8) FOR SEARCH, + FUNCTION 1 _vchordrq_support_scalar8_ip_ops(); + +CREATE OPERATOR CLASS scalar8_cosine_ops + FOR TYPE scalar8 USING vchordrq FAMILY scalar8_cosine_ops AS + OPERATOR 1 <=> (scalar8, scalar8) FOR ORDER BY float_ops, + OPERATOR 2 <<=>> (scalar8, sphere_scalar8) FOR SEARCH, + FUNCTION 1 _vchordrq_support_scalar8_cosine_ops(); + CREATE OPERATOR CLASS vector_l2_ops FOR TYPE vector USING Vchordrqfscan FAMILY vector_l2_ops AS OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, diff --git a/src/vchordrq/algorithm/insert.rs b/src/vchordrq/algorithm/insert.rs index ee47e9a..36502c6 100644 --- a/src/vchordrq/algorithm/insert.rs +++ b/src/vchordrq/algorithm/insert.rs @@ -1,5 +1,5 @@ use crate::postgres::Relation; -use crate::vchordrq::algorithm::rabitq::fscan_process_lowerbound; +use crate::vchordrq::algorithm::rabitq::process_lowerbound; use crate::vchordrq::algorithm::tuples::*; use crate::vchordrq::algorithm::vectors; use base::always_equal::AlwaysEqual; @@ -31,7 +31,7 @@ pub fn insert( let vector = vector.as_borrowed(); let is_residual = meta_tuple.is_residual; let default_lut = if !is_residual { - Some(V::rabitq_fscan_preprocess(vector)) + Some(V::rabitq_preprocess(vector)) } else { None }; @@ -74,7 +74,7 @@ pub fn insert( let mut results = Vec::new(); { let lut = if is_residual { - &V::rabitq_fscan_preprocess( + &V::rabitq_preprocess( V::residual(vector, list.1.as_ref().map(|x| x.as_borrowed()).unwrap()) .as_borrowed(), ) @@ -91,7 +91,7 @@ pub fn insert( .map(rkyv::check_archived_root::) .expect("data corruption") .expect("data corruption"); - let lowerbounds = fscan_process_lowerbound( + let lowerbounds = process_lowerbound( distance_kind, dims, lut, diff --git a/src/vchordrq/algorithm/rabitq.rs b/src/vchordrq/algorithm/rabitq.rs index b7b3858..4f0c7c1 100644 --- a/src/vchordrq/algorithm/rabitq.rs +++ b/src/vchordrq/algorithm/rabitq.rs @@ -61,11 +61,11 @@ pub fn code(dims: u32, vector: &[f32]) -> Code { pub type Lut = (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)); -pub fn fscan_preprocess(vector: &[f32]) -> Lut { +pub fn preprocess(vector: &[f32]) -> Lut { use base::simd::quantize; let dis_v_2 = f32::reduce_sum_of_x2(vector); let (k, b, qvector) = quantize::quantize(vector, 15.0); - let qvector_sum = if vector.len() <= 4369 { + let qvector_sum = if qvector.len() <= 4369 { base::simd::u8::reduce_sum_of_x_as_u16(&qvector) as f32 } else { base::simd::u8::reduce_sum_of_x_as_u32(&qvector) as f32 @@ -73,7 +73,7 @@ pub fn fscan_preprocess(vector: &[f32]) -> Lut { (dis_v_2, b, k, qvector_sum, binarize(&qvector)) } -pub fn fscan_process_lowerbound( +pub fn process_lowerbound( distance_kind: DistanceKind, _dims: u32, lut: &Lut, @@ -104,7 +104,7 @@ pub fn fscan_process_lowerbound( } } -fn binarize(vector: &[u8]) -> (Vec, Vec, Vec, Vec) { +pub fn binarize(vector: &[u8]) -> (Vec, Vec, Vec, Vec) { let n = vector.len(); let mut t0 = vec![0u64; n.div_ceil(64)]; let mut t1 = vec![0u64; n.div_ceil(64)]; diff --git a/src/vchordrq/algorithm/scan.rs b/src/vchordrq/algorithm/scan.rs index df4d93d..90abed7 100644 --- a/src/vchordrq/algorithm/scan.rs +++ b/src/vchordrq/algorithm/scan.rs @@ -1,5 +1,5 @@ use crate::postgres::Relation; -use crate::vchordrq::algorithm::rabitq::fscan_process_lowerbound; +use crate::vchordrq::algorithm::rabitq::process_lowerbound; use crate::vchordrq::algorithm::tuples::*; use crate::vchordrq::algorithm::vectors; use base::always_equal::AlwaysEqual; @@ -32,7 +32,7 @@ pub fn scan( let vector = V::random_projection(vector); let is_residual = meta_tuple.is_residual; let default_lut = if !is_residual { - Some(V::rabitq_fscan_preprocess(vector.as_borrowed())) + Some(V::rabitq_preprocess(vector.as_borrowed())) } else { None }; @@ -53,7 +53,7 @@ pub fn scan( let mut results = Vec::new(); for list in lists { let lut = if is_residual { - &V::rabitq_fscan_preprocess( + &V::rabitq_preprocess( V::residual( vector.as_borrowed(), list.1.as_ref().map(|x| x.as_borrowed()).unwrap(), @@ -73,7 +73,7 @@ pub fn scan( .map(rkyv::check_archived_root::) .expect("data corruption") .expect("data corruption"); - let lowerbounds = fscan_process_lowerbound( + let lowerbounds = process_lowerbound( distance_kind, dims, lut, @@ -125,7 +125,7 @@ pub fn scan( let mut results = Vec::new(); for list in lists { let lut = if is_residual { - &V::rabitq_fscan_preprocess( + &V::rabitq_preprocess( V::residual( vector.as_borrowed(), list.1.as_ref().map(|x| x.as_borrowed()).unwrap(), @@ -145,7 +145,7 @@ pub fn scan( .map(rkyv::check_archived_root::) .expect("data corruption") .expect("data corruption"); - let lowerbounds = fscan_process_lowerbound( + let lowerbounds = process_lowerbound( distance_kind, dims, lut, diff --git a/src/vchordrq/algorithm/tuples.rs b/src/vchordrq/algorithm/tuples.rs index 40a795c..18c90ee 100644 --- a/src/vchordrq/algorithm/tuples.rs +++ b/src/vchordrq/algorithm/tuples.rs @@ -1,7 +1,9 @@ use super::rabitq::{self, Code, Lut}; +use crate::types::scalar8::Scalar8Owned; use crate::vchordrq::types::OwnedVector; use base::distance::DistanceKind; use base::simd::ScalarLike; +use base::vector::VectorBorrowed; use base::vector::{VectOwned, VectorOwned}; use half::f16; use rkyv::{Archive, ArchiveUnsized, CheckBytes, Deserialize, Serialize}; @@ -56,7 +58,7 @@ pub trait Vector: VectorOwned { fn residual(vector: Self::Borrowed<'_>, center: Self::Borrowed<'_>) -> Self; - fn rabitq_fscan_preprocess(vector: Self::Borrowed<'_>) -> Lut; + fn rabitq_preprocess(vector: Self::Borrowed<'_>) -> Lut; fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code; @@ -129,8 +131,8 @@ impl Vector for VectOwned { Self::new(ScalarLike::vector_sub(vector.slice(), center.slice())) } - fn rabitq_fscan_preprocess(vector: Self::Borrowed<'_>) -> Lut { - rabitq::fscan_preprocess(vector.slice()) + fn rabitq_preprocess(vector: Self::Borrowed<'_>) -> Lut { + rabitq::preprocess(vector.slice()) } fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code { @@ -212,8 +214,8 @@ impl Vector for VectOwned { Self::new(ScalarLike::vector_sub(vector.slice(), center.slice())) } - fn rabitq_fscan_preprocess(vector: Self::Borrowed<'_>) -> Lut { - rabitq::fscan_preprocess(&f16::vector_to_f32(vector.slice())) + fn rabitq_preprocess(vector: Self::Borrowed<'_>) -> Lut { + rabitq::preprocess(&f16::vector_to_f32(vector.slice())) } fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code { @@ -229,6 +231,147 @@ impl Vector for VectOwned { } } +impl Vector for Scalar8Owned { + type Metadata = (f32, f32, f32, f32); + + type Element = u8; + + fn metadata_from_archived( + archived: &::Archived, + ) -> Self::Metadata { + (archived.0, archived.1, archived.2, archived.3) + } + + fn vector_split(vector: Self::Borrowed<'_>) -> (Self::Metadata, Vec<&[Self::Element]>) { + let code = vector.code(); + ( + ( + vector.sum_of_x2(), + vector.k(), + vector.b(), + vector.sum_of_code(), + ), + match code.len() { + 0..=3840 => vec![code], + 3841..=5120 => vec![&code[..2560], &code[2560..]], + 5121.. => code.chunks(7680).collect(), + }, + ) + } + + fn vector_merge(metadata: Self::Metadata, slice: &[Self::Element]) -> Self { + Scalar8Owned::new( + metadata.0, + metadata.1, + metadata.2, + metadata.3, + slice.to_vec(), + ) + } + + fn from_owned(vector: OwnedVector) -> Self { + match vector { + OwnedVector::Scalar8(x) => x, + _ => unreachable!(), + } + } + + type DistanceAccumulator = (DistanceKind, u32, u32); + + fn distance_begin(distance_kind: DistanceKind) -> Self::DistanceAccumulator { + (distance_kind, 0, 0) + } + + fn distance_next( + accumulator: &mut Self::DistanceAccumulator, + left: &[Self::Element], + right: &[Self::Element], + ) { + match accumulator.0 { + DistanceKind::L2 => accumulator.1 += base::simd::u8::reduce_sum_of_xy(left, right), + DistanceKind::Dot => accumulator.1 += base::simd::u8::reduce_sum_of_xy(left, right), + DistanceKind::Hamming => unreachable!(), + DistanceKind::Jaccard => unreachable!(), + } + accumulator.2 += left.len() as u32; + } + + fn distance_end( + accumulator: Self::DistanceAccumulator, + (sum_of_x2_u, k_u, b_u, sum_of_code_u): Self::Metadata, + (sum_of_x2_v, k_v, b_v, sum_of_code_v): Self::Metadata, + ) -> f32 { + match accumulator.0 { + DistanceKind::L2 => { + let xy = k_u * k_v * accumulator.1 as f32 + + b_u * b_v * accumulator.2 as f32 + + k_u * b_v * sum_of_code_u + + b_u * k_v * sum_of_code_v; + sum_of_x2_u + sum_of_x2_v - 2.0 * xy + } + DistanceKind::Dot => { + let xy = k_u * k_v * accumulator.1 as f32 + + b_u * b_v * accumulator.2 as f32 + + k_u * b_v * sum_of_code_u + + b_u * k_v * sum_of_code_v; + -xy + } + DistanceKind::Hamming => unreachable!(), + DistanceKind::Jaccard => unreachable!(), + } + } + + fn random_projection(vector: Self::Borrowed<'_>) -> Self { + vector.own() + } + + fn residual(_: Self::Borrowed<'_>, _: Self::Borrowed<'_>) -> Self { + unimplemented!() + } + + fn rabitq_preprocess(vector: Self::Borrowed<'_>) -> Lut { + let dis_v_2 = vector.sum_of_x2(); + let k = vector.k() * 17.0; + let b = vector.b(); + let qvector = vector + .code() + .iter() + .map(|&x| ((x as u32 + 8) / 17) as u8) + .collect::>(); + let qvector_sum = if qvector.len() <= 4369 { + base::simd::u8::reduce_sum_of_x_as_u16(&qvector) as f32 + } else { + base::simd::u8::reduce_sum_of_x_as_u32(&qvector) as f32 + }; + (dis_v_2, b, k, qvector_sum, rabitq::binarize(&qvector)) + } + + fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code { + let dequantized = vector + .code() + .iter() + .map(|&x| vector.k() * x as f32 + vector.b()) + .collect::>(); + rabitq::code(dims, &dequantized) + } + + fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec { + vector + .code() + .iter() + .map(|&x| vector.k() * x as f32 + vector.b()) + .collect() + } + + fn build_from_vecf32(x: &[f32]) -> Self { + let sum_of_x2 = f32::reduce_sum_of_x2(x); + let (k, b, code) = + base::simd::quantize::quantize(f32::vector_to_f32_borrowed(x).as_ref(), 255.0); + let sum_of_code = base::simd::u8::reduce_sum_of_x_as_u32(&code) as f32; + Self::new(sum_of_x2, k, b, sum_of_code, code) + } +} + #[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] #[archive(check_bytes)] pub struct MetaTuple { diff --git a/src/vchordrq/index/am.rs b/src/vchordrq/index/am.rs index e9182c7..14b2abe 100644 --- a/src/vchordrq/index/am.rs +++ b/src/vchordrq/index/am.rs @@ -1,4 +1,5 @@ use crate::postgres::Relation; +use crate::types::scalar8::Scalar8Owned; use crate::vchordrq::algorithm; use crate::vchordrq::algorithm::build::{HeapRelation, Reporter}; use crate::vchordrq::algorithm::tuples::Vector; @@ -234,6 +235,9 @@ pub unsafe extern "C" fn ambuild( if let Err(errors) = Validate::validate(&vchordrq_options) { pgrx::error!("error while validating options: {}", errors); } + if matches!(vector_options.v, VectorKind::Scalar8) && vchordrq_options.residual_quantization { + pgrx::error!("error while validating options: could not apply residual vector quantization on quantized vectors"); + } let opfamily = unsafe { am_options::opfamily(index) }; let heap_relation = Heap { heap, @@ -258,6 +262,13 @@ pub unsafe extern "C" fn ambuild( index_relation.clone(), reporter.clone(), ), + VectorKind::Scalar8 => algorithm::build::build::( + vector_options, + vchordrq_options, + heap_relation.clone(), + index_relation.clone(), + reporter.clone(), + ), } if let Some(leader) = unsafe { VchordrqLeader::enter(heap, index, (*index_info).ii_Concurrent) } { @@ -324,6 +335,23 @@ pub unsafe extern "C" fn ambuild( }, ); } + VectorKind::Scalar8 => { + HeapRelation::::traverse( + &heap_relation, + true, + |(pointer, vector)| { + algorithm::insert::insert::( + unsafe { Relation::new(index) }, + pointer, + vector, + opfamily.distance_kind(), + true, + ); + indtuples += 1; + reporter.tuples_done(indtuples); + }, + ); + } } } unsafe { pgrx::pgbox::PgBox::::alloc0().into_pg() } @@ -683,6 +711,29 @@ unsafe fn parallel_build( } }); } + VectorKind::Scalar8 => { + HeapRelation::::traverse(&heap_relation, true, |(pointer, vector)| { + algorithm::insert::insert::( + index_relation.clone(), + pointer, + vector, + opfamily.distance_kind(), + true, + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } + } + }); + } } unsafe { pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); @@ -763,6 +814,13 @@ pub unsafe extern "C" fn aminsert( opfamily.distance_kind(), false, ), + VectorKind::Scalar8 => algorithm::insert::insert::( + unsafe { Relation::new(index) }, + pointer, + Scalar8Owned::from_owned(vector), + opfamily.distance_kind(), + false, + ), } } false @@ -906,6 +964,13 @@ pub unsafe extern "C" fn ambulkdelete( }, callback, ), + VectorKind::Scalar8 => algorithm::vacuum::vacuum::( + unsafe { Relation::new((*info).index) }, + || unsafe { + pgrx::pg_sys::vacuum_delay_point(); + }, + callback, + ), } stats } diff --git a/src/vchordrq/index/am_options.rs b/src/vchordrq/index/am_options.rs index 25fcd0c..2002c10 100644 --- a/src/vchordrq/index/am_options.rs +++ b/src/vchordrq/index/am_options.rs @@ -2,6 +2,8 @@ use crate::datatype::memory_pgvector_halfvec::PgvectorHalfvecInput; use crate::datatype::memory_pgvector_halfvec::PgvectorHalfvecOutput; use crate::datatype::memory_pgvector_vector::PgvectorVectorInput; use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; +use crate::datatype::memory_scalar8::Scalar8Input; +use crate::datatype::memory_scalar8::Scalar8Output; use crate::datatype::typmod::Typmod; use crate::vchordrq::types::VchordrqIndexingOptions; use crate::vchordrq::types::VectorOptions; @@ -62,6 +64,9 @@ fn convert_name_to_vd(name: &str) -> Option<(VectorKind, PgDistanceKind)> { Some("halfvec_l2") => Some((VectorKind::Vecf16, PgDistanceKind::L2)), Some("halfvec_ip") => Some((VectorKind::Vecf16, PgDistanceKind::Dot)), Some("halfvec_cosine") => Some((VectorKind::Vecf16, PgDistanceKind::Cos)), + Some("scalar8_l2") => Some((VectorKind::Scalar8, PgDistanceKind::L2)), + Some("scalar8_ip") => Some((VectorKind::Scalar8, PgDistanceKind::Dot)), + Some("scalar8_cosine") => Some((VectorKind::Scalar8, PgDistanceKind::Cos)), _ => None, } } @@ -140,6 +145,10 @@ impl Opfamily { let vector = unsafe { PgvectorHalfvecInput::from_datum(datum, false).unwrap() }; self.preprocess(BorrowedVector::Vecf16(vector.as_borrowed())) } + VectorKind::Scalar8 => { + let vector = unsafe { Scalar8Input::from_datum(datum, false).unwrap() }; + self.preprocess(BorrowedVector::Scalar8(vector.as_borrowed())) + } }; Some(vector) } @@ -161,6 +170,10 @@ impl Opfamily { .get_by_index::(NonZero::new(1).unwrap()) .unwrap() .map(|vector| self.preprocess(BorrowedVector::Vecf16(vector.as_borrowed()))), + VectorKind::Scalar8 => tuple + .get_by_index::(NonZero::new(1).unwrap()) + .unwrap() + .map(|vector| self.preprocess(BorrowedVector::Scalar8(vector.as_borrowed()))), }; let radius = tuple.get_by_index::(NonZero::new(2).unwrap()).unwrap(); (center, radius) @@ -175,6 +188,9 @@ impl Opfamily { (B::Vecf16(x), PgDistanceKind::L2) => O::Vecf16(x.own()), (B::Vecf16(x), PgDistanceKind::Dot) => O::Vecf16(x.own()), (B::Vecf16(x), PgDistanceKind::Cos) => O::Vecf16(x.function_normalize()), + (B::Scalar8(x), PgDistanceKind::L2) => O::Scalar8(x.own()), + (B::Scalar8(x), PgDistanceKind::Dot) => O::Scalar8(x.own()), + (B::Scalar8(x), PgDistanceKind::Cos) => O::Scalar8(x.function_normalize()), } } pub fn process(self, x: Distance) -> f32 { diff --git a/src/vchordrq/index/am_scan.rs b/src/vchordrq/index/am_scan.rs index 1b78ff0..21da8ad 100644 --- a/src/vchordrq/index/am_scan.rs +++ b/src/vchordrq/index/am_scan.rs @@ -1,5 +1,6 @@ use super::am_options::Opfamily; use crate::postgres::Relation; +use crate::types::scalar8::Scalar8Owned; use crate::vchordrq::algorithm::scan::scan; use crate::vchordrq::algorithm::tuples::Vector; use crate::vchordrq::gucs::executing::epsilon; @@ -113,6 +114,25 @@ pub fn scan_next(scanner: &mut Scanner, relation: Relation) -> Option<(Pointer, opfamily: *opfamily, }; } + VectorKind::Scalar8 => { + let vbase = scan::( + relation, + Scalar8Owned::from_owned(vector.clone()), + opfamily.distance_kind(), + probes(), + epsilon(), + ); + *scanner = Scanner::Vbase { + vbase: if let Some(max_scan_tuples) = max_scan_tuples() { + Box::new(vbase.take(max_scan_tuples as usize)) + } else { + Box::new(vbase) + }, + threshold: *threshold, + recheck: *recheck, + opfamily: *opfamily, + }; + } } } else { *scanner = Scanner::Empty {}; diff --git a/src/vchordrq/index/functions.rs b/src/vchordrq/index/functions.rs index 05f348f..316fc59 100644 --- a/src/vchordrq/index/functions.rs +++ b/src/vchordrq/index/functions.rs @@ -1,5 +1,6 @@ use super::am_options; use crate::postgres::Relation; +use crate::types::scalar8::Scalar8Owned; use crate::vchordrq::algorithm::prewarm::prewarm; use crate::vchordrq::types::VectorKind; use base::vector::VectOwned; @@ -26,6 +27,7 @@ fn _vchordrq_prewarm(indexrelid: Oid, height: i32) -> String { let message = match opfamily.vector_kind() { VectorKind::Vecf32 => prewarm::>(relation, height), VectorKind::Vecf16 => prewarm::>(relation, height), + VectorKind::Scalar8 => prewarm::(relation, height), }; unsafe { pgrx::pg_sys::index_close(index, pgrx::pg_sys::ShareLock as _); diff --git a/src/vchordrq/index/opclass.rs b/src/vchordrq/index/opclass.rs index a2dc861..e214749 100644 --- a/src/vchordrq/index/opclass.rs +++ b/src/vchordrq/index/opclass.rs @@ -27,3 +27,18 @@ fn _vchordrq_support_halfvec_ip_ops() -> String { fn _vchordrq_support_halfvec_cosine_ops() -> String { "halfvec_cosine_ops".to_string() } + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchordrq_support_scalar8_l2_ops() -> String { + "scalar8_l2_ops".to_string() +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchordrq_support_scalar8_ip_ops() -> String { + "scalar8_ip_ops".to_string() +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchordrq_support_scalar8_cosine_ops() -> String { + "scalar8_cosine_ops".to_string() +} diff --git a/src/vchordrq/types.rs b/src/vchordrq/types.rs index 0e1bdc0..33ff3e5 100644 --- a/src/vchordrq/types.rs +++ b/src/vchordrq/types.rs @@ -1,3 +1,4 @@ +use crate::types::scalar8::{Scalar8Borrowed, Scalar8Owned}; use base::distance::DistanceKind; use base::vector::{VectBorrowed, VectOwned}; use half::f16; @@ -103,12 +104,14 @@ impl VchordrqIndexingOptions { pub enum OwnedVector { Vecf32(VectOwned), Vecf16(VectOwned), + Scalar8(Scalar8Owned), } #[derive(Debug, Clone, Copy)] pub enum BorrowedVector<'a> { Vecf32(VectBorrowed<'a, f32>), Vecf16(VectBorrowed<'a, f16>), + Scalar8(Scalar8Borrowed<'a>), } #[repr(u8)] @@ -116,6 +119,7 @@ pub enum BorrowedVector<'a> { pub enum VectorKind { Vecf32, Vecf16, + Scalar8, } #[derive(Debug, Clone, Serialize, Deserialize, Validate)] @@ -132,12 +136,14 @@ pub struct VectorOptions { } impl VectorOptions { - pub fn validate_self(&self) -> Result<(), ValidationError> { + fn validate_self(&self) -> Result<(), ValidationError> { match (self.v, self.d, self.dims) { (VectorKind::Vecf32, DistanceKind::L2, 1..65536) => Ok(()), (VectorKind::Vecf32, DistanceKind::Dot, 1..65536) => Ok(()), (VectorKind::Vecf16, DistanceKind::L2, 1..65536) => Ok(()), (VectorKind::Vecf16, DistanceKind::Dot, 1..65536) => Ok(()), + (VectorKind::Scalar8, DistanceKind::L2, 1..65536) => Ok(()), + (VectorKind::Scalar8, DistanceKind::Dot, 1..65536) => Ok(()), _ => Err(ValidationError::new("not valid vector options")), } }