From 490f901ecb6369e40fc096da2f15edaad98a74e0 Mon Sep 17 00:00:00 2001 From: usamoi Date: Wed, 28 Aug 2024 19:41:56 +0800 Subject: [PATCH] [experimental] debug tricks Signed-off-by: usamoi --- crates/common/src/mmap_array.rs | 34 +++++++++++++++++++++++++ crates/rabitq/src/quant/quantization.rs | 21 ++++++++++----- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/crates/common/src/mmap_array.rs b/crates/common/src/mmap_array.rs index 157ab426b..2a8132b3d 100644 --- a/crates/common/src/mmap_array.rs +++ b/crates/common/src/mmap_array.rs @@ -1,6 +1,7 @@ use base::pod::Pod; use std::fs::File; use std::io::{BufWriter, Read, Seek, Write}; +use std::marker::PhantomData; use std::ops::Index; use std::ops::{Deref, Range, RangeInclusive}; use std::path::Path; @@ -65,6 +66,39 @@ where } } +pub struct MmapArrayWriter { + file: BufWriter, + len: usize, + _phantom: PhantomData T>, +} + +impl MmapArrayWriter { + pub fn new(path: impl AsRef) -> Self { + let file = std::fs::OpenOptions::new() + .create_new(true) + .read(true) + .append(true) + .open(path) + .unwrap(); + Self { + file: BufWriter::new(file), + len: 0, + _phantom: PhantomData, + } + } + pub fn write(&mut self, x: T) { + self.file.write_all(base::pod::bytes_of(&x)).unwrap(); + self.len += 1; + } + pub fn finish(mut self) { + self.file.write_all(&[0u8; 4096]).unwrap(); + self.file + .write_all(base::pod::bytes_of(&Information { len: self.len })) + .unwrap(); + self.file.flush().unwrap(); + } +} + impl Deref for MmapArray { type Target = [T]; diff --git a/crates/rabitq/src/quant/quantization.rs b/crates/rabitq/src/quant/quantization.rs index 6922a4ec2..8f0df644e 100644 --- a/crates/rabitq/src/quant/quantization.rs +++ b/crates/rabitq/src/quant/quantization.rs @@ -7,10 +7,13 @@ use base::search::RerankerPop; use common::json::Json; use common::mmap_array::MmapArray; use quantization::utils::InfiniteByteChunks; +use rayon::iter::IntoParallelIterator; +use rayon::iter::ParallelIterator; use serde::{Deserialize, Serialize}; use std::cmp::Reverse; use std::ops::Range; use std::path::Path; +use stoppable_rayon as rayon; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(bound = "")] @@ -45,7 +48,7 @@ impl Quantization { path: impl AsRef, vector_options: VectorOptions, n: u32, - vectors: impl Fn(u32) -> Vec, + vectors: impl Fn(u32) -> Vec + Sync, ) -> Self { std::fs::create_dir(path.as_ref()).unwrap(); fn merge_8([b0, b1, b2, b3, b4, b5, b6, b7]: [u8; 8]) -> u8 { @@ -58,12 +61,17 @@ impl Quantization { b0 | (b1 << 4) } let train = Quantizer::train(vector_options); + let everything = match &train { + Quantizer::Rabitq(x) => (0..n) + .into_par_iter() + .map(|i| x.encode(&vectors(i))) + .collect::>(), + }; let train = Json::create(path.as_ref().join("train"), train); let codes = MmapArray::create(path.as_ref().join("codes"), { match &*train { Quantizer::Rabitq(x) => Box::new((0..n).flat_map(|i| { - let vector = vectors(i); - let (_, _, _, _, codes) = x.encode(&vector); + let (_, _, _, _, codes) = everything[i as usize].clone(); let bytes = x.bytes(); match x.bits() { 1 => InfiniteByteChunks::new(codes.into_iter()) @@ -94,7 +102,8 @@ impl Quantization { let t = x.dims().div_ceil(4); let raw = std::array::from_fn::<_, { BLOCK_SIZE as _ }, _>(|i| { let id = BLOCK_SIZE * block + i as u32; - let (_, _, _, _, e) = x.encode(&vectors(std::cmp::min(id, n - 1))); + let (_, _, _, _, e) = + everything[std::cmp::min(id, n - 1) as usize].clone(); InfiniteByteChunks::new(e.into_iter()) .map(|[b0, b1, b2, b3]| b0 | b1 << 1 | b2 << 2 | b3 << 3) .take(t as usize) @@ -108,8 +117,8 @@ impl Quantization { let meta = MmapArray::create( path.as_ref().join("meta"), match &*train { - Quantizer::Rabitq(x) => Box::new((0..n).flat_map(|i| { - let (a, b, c, d, _) = x.encode(&vectors(i)); + Quantizer::Rabitq(_) => Box::new((0..n).flat_map(|i| { + let (a, b, c, d, _) = everything[i as usize]; [a, b, c, d].into_iter() })), },