Skip to content

Commit

Permalink
[experimental] debug tricks
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi committed Aug 28, 2024
1 parent 60ac98b commit 490f901
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 6 deletions.
34 changes: 34 additions & 0 deletions crates/common/src/mmap_array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use base::pod::Pod;
use std::fs::File;
use std::io::{BufWriter, Read, Seek, Write};
use std::marker::PhantomData;
use std::ops::Index;
use std::ops::{Deref, Range, RangeInclusive};
use std::path::Path;
Expand Down Expand Up @@ -65,6 +66,39 @@ where
}
}

/// Streaming writer that appends values of type `T` to a newly created file,
/// one element per [`MmapArrayWriter::write`] call, and is completed with
/// [`MmapArrayWriter::finish`].
pub struct MmapArrayWriter<T> {
/// Buffered handle to the output file.
file: BufWriter<File>,
/// Number of elements written so far.
len: usize,
/// Zero-sized marker that associates the writer with `T` without storing a `T`.
_phantom: PhantomData<fn(T) -> T>,
}

impl<T: Pod> MmapArrayWriter<T> {
    /// Creates the file at `path` and returns a writer with zero elements recorded.
    ///
    /// The file is opened with `create_new`, `read` and `append` set.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be opened with those options (for example,
    /// because `create_new` refuses a path that already exists).
    pub fn new(path: impl AsRef<Path>) -> Self {
        let mut options = std::fs::OpenOptions::new();
        options.create_new(true).read(true).append(true);
        let handle = options.open(path).unwrap();
        Self {
            file: BufWriter::new(handle),
            len: 0,
            _phantom: PhantomData,
        }
    }

    /// Appends the raw bytes of `x` to the file and bumps the element count.
    ///
    /// # Panics
    ///
    /// Panics if the underlying write fails.
    pub fn write(&mut self, x: T) {
        let raw = base::pod::bytes_of(&x);
        self.file.write_all(raw).unwrap();
        self.len += 1;
    }

    /// Consumes the writer, emitting the trailer and flushing the buffer.
    ///
    /// The trailer is 4096 zero bytes followed by the bytes of an
    /// `Information` record holding the element count.
    /// NOTE(review): presumably this mirrors the layout produced by
    /// `MmapArray::create` — confirm against that function.
    ///
    /// # Panics
    ///
    /// Panics if any write or the final flush fails.
    pub fn finish(mut self) {
        let padding = [0u8; 4096];
        self.file.write_all(&padding).unwrap();
        let info = Information { len: self.len };
        self.file.write_all(base::pod::bytes_of(&info)).unwrap();
        self.file.flush().unwrap();
    }
}

impl<T> Deref for MmapArray<T> {
type Target = [T];

Expand Down
21 changes: 15 additions & 6 deletions crates/rabitq/src/quant/quantization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ use base::search::RerankerPop;
use common::json::Json;
use common::mmap_array::MmapArray;
use quantization::utils::InfiniteByteChunks;
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
use serde::{Deserialize, Serialize};
use std::cmp::Reverse;
use std::ops::Range;
use std::path::Path;
use stoppable_rayon as rayon;

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "")]
Expand Down Expand Up @@ -45,7 +48,7 @@ impl<O: OperatorRabitq> Quantization<O> {
path: impl AsRef<Path>,
vector_options: VectorOptions,
n: u32,
vectors: impl Fn(u32) -> Vec<f32>,
vectors: impl Fn(u32) -> Vec<f32> + Sync,
) -> Self {
std::fs::create_dir(path.as_ref()).unwrap();
fn merge_8([b0, b1, b2, b3, b4, b5, b6, b7]: [u8; 8]) -> u8 {
Expand All @@ -58,12 +61,17 @@ impl<O: OperatorRabitq> Quantization<O> {
b0 | (b1 << 4)
}
let train = Quantizer::train(vector_options);
let everything = match &train {
Quantizer::Rabitq(x) => (0..n)
.into_par_iter()
.map(|i| x.encode(&vectors(i)))
.collect::<Vec<_>>(),
};
let train = Json::create(path.as_ref().join("train"), train);
let codes = MmapArray::create(path.as_ref().join("codes"), {
match &*train {
Quantizer::Rabitq(x) => Box::new((0..n).flat_map(|i| {
let vector = vectors(i);
let (_, _, _, _, codes) = x.encode(&vector);
let (_, _, _, _, codes) = everything[i as usize].clone();
let bytes = x.bytes();
match x.bits() {
1 => InfiniteByteChunks::new(codes.into_iter())
Expand Down Expand Up @@ -94,7 +102,8 @@ impl<O: OperatorRabitq> Quantization<O> {
let t = x.dims().div_ceil(4);
let raw = std::array::from_fn::<_, { BLOCK_SIZE as _ }, _>(|i| {
let id = BLOCK_SIZE * block + i as u32;
let (_, _, _, _, e) = x.encode(&vectors(std::cmp::min(id, n - 1)));
let (_, _, _, _, e) =
everything[std::cmp::min(id, n - 1) as usize].clone();
InfiniteByteChunks::new(e.into_iter())
.map(|[b0, b1, b2, b3]| b0 | b1 << 1 | b2 << 2 | b3 << 3)
.take(t as usize)
Expand All @@ -108,8 +117,8 @@ impl<O: OperatorRabitq> Quantization<O> {
let meta = MmapArray::create(
path.as_ref().join("meta"),
match &*train {
Quantizer::Rabitq(x) => Box::new((0..n).flat_map(|i| {
let (a, b, c, d, _) = x.encode(&vectors(i));
Quantizer::Rabitq(_) => Box::new((0..n).flat_map(|i| {
let (a, b, c, d, _) = everything[i as usize];
[a, b, c, d].into_iter()
})),
},
Expand Down

0 comments on commit 490f901

Please sign in to comment.