Skip to content

Commit

Permalink
Merge pull request #129 from outbrain/zstd
Browse files Browse the repository at this point in the history
Zstd support
  • Loading branch information
SkBlaz authored Apr 9, 2024
2 parents 7c8cefd + a41f9b5 commit cd270c5
Show file tree
Hide file tree
Showing 15 changed files with 331 additions and 183 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ log = "0.4.18"
env_logger = "0.10.0"
rustc-hash = "1.1.0"
half = "2.3.1"
zstd = "0.13.1"

[build-dependencies]
cbindgen = "0.23.0"
Expand Down
55 changes: 29 additions & 26 deletions src/block_ffm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use crate::model_instance;
use crate::optimizer;
use crate::port_buffer;
use crate::port_buffer::PortBuffer;
use crate::regressor;
use crate::quantization;
use crate::regressor;
use crate::regressor::{BlockCache, FFM_CONTRA_BUF_LEN};

const FFM_STACK_BUF_LEN: usize = 170393;
Expand Down Expand Up @@ -458,8 +458,11 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
contra_fields,
features_present,
ffm,
} = next_cache else {
log::warn!("Unable to downcast cache to BlockFFMCache, executing forward pass without cache");
} = next_cache
else {
log::warn!(
"Unable to downcast cache to BlockFFMCache, executing forward pass without cache"
);
self.forward(further_blocks, fb, pb);
return;
};
Expand Down Expand Up @@ -667,15 +670,18 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
caches: &mut [BlockCache],
) {
let Some((next_cache, further_caches)) = caches.split_first_mut() else {
log::warn!("Expected BlockFFMCache caches, but non available, skipping cache preparation");
log::warn!(
"Expected BlockFFMCache caches, but non available, skipping cache preparation"
);
return;
};

let BlockCache::FFM {
contra_fields,
features_present,
ffm,
} = next_cache else {
} = next_cache
else {
log::warn!("Unable to downcast cache to BlockFFMCache, skipping cache preparation");
return;
};
Expand Down Expand Up @@ -829,32 +835,29 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
fn write_weights_to_buf(
&self,
output_bufwriter: &mut dyn io::Write,
use_quantization: bool
use_quantization: bool,
) -> Result<(), Box<dyn Error>> {

if use_quantization {

let quantized_weights = quantization::quantize_ffm_weights(&self.weights);
block_helpers::write_weights_to_buf(&quantized_weights, output_bufwriter, false)?;
} else {
if use_quantization {
let quantized_weights = quantization::quantize_ffm_weights(&self.weights);
block_helpers::write_weights_to_buf(&quantized_weights, output_bufwriter, false)?;
} else {
block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)?;
}
}
block_helpers::write_weights_to_buf(&self.optimizer, output_bufwriter, false)?;
Ok(())
}

fn read_weights_from_buf(
&mut self,
input_bufreader: &mut dyn io::Read,
use_quantization: bool
use_quantization: bool,
) -> Result<(), Box<dyn Error>> {

if use_quantization {
quantization::dequantize_ffm_weights(input_bufreader, &mut self.weights);
} else {
if use_quantization {
quantization::dequantize_ffm_weights(input_bufreader, &mut self.weights);
} else {
block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false)?;
}
}

block_helpers::read_weights_from_buf(&mut self.optimizer, input_bufreader, false)?;
Ok(())
}
Expand All @@ -877,18 +880,18 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
&self,
input_bufreader: &mut dyn io::Read,
forward: &mut Box<dyn BlockTrait>,
use_quantization: bool
use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
let forward = forward
.as_any()
.downcast_mut::<BlockFFM<optimizer::OptimizerSGD>>()
.unwrap();

if use_quantization {
quantization::dequantize_ffm_weights(input_bufreader, &mut forward.weights);
} else {
if use_quantization {
quantization::dequantize_ffm_weights(input_bufreader, &mut forward.weights);
} else {
block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader, false)?;
}
}
block_helpers::skip_weights_from_buf::<OptimizerData<L>>(
self.ffm_weights_len as usize,
input_bufreader,
Expand Down Expand Up @@ -1937,7 +1940,7 @@ mod tests {
contra_field_index: mi.ffm_k,
}]);
assert_eq!(spredict2(&mut bg, &fb, &mut pb), 0.5);
assert_eq!(slearn2(&mut bg, &fb, &mut pb, true), 0.5);
assert_eq!(slearn2(&mut bg, &fb, &mut pb, true), 0.62245935);
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions src/block_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ macro_rules! assert_epsilon {
pub fn read_weights_from_buf<L>(
weights: &mut Vec<L>,
input_bufreader: &mut dyn io::Read,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
if weights.is_empty() {
return Err("Loading weights to unallocated weighs buffer".to_string())?;
Expand Down Expand Up @@ -75,7 +75,7 @@ pub fn skip_weights_from_buf<L>(
pub fn write_weights_to_buf<L>(
weights: &Vec<L>,
output_bufwriter: &mut dyn io::Write,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
if weights.is_empty() {
assert!(false);
Expand Down
24 changes: 11 additions & 13 deletions src/block_lr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,10 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockLR<L> {
return;
};

let BlockCache::LR {
lr,
combo_indexes,
} = next_cache else {
log::warn!("Unable to downcast cache to BlockLRCache, executing forward pass without cache");
let BlockCache::LR { lr, combo_indexes } = next_cache else {
log::warn!(
"Unable to downcast cache to BlockLRCache, executing forward pass without cache"
);
self.forward(further_blocks, fb, pb);
return;
};
Expand Down Expand Up @@ -222,14 +221,13 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockLR<L> {
caches: &mut [BlockCache],
) {
let Some((next_cache, further_caches)) = caches.split_first_mut() else {
log::warn!("Expected BlockLRCache caches, but non available, skipping cache preparation");
log::warn!(
"Expected BlockLRCache caches, but non available, skipping cache preparation"
);
return;
};

let BlockCache::LR {
lr,
combo_indexes
} = next_cache else {
let BlockCache::LR { lr, combo_indexes } = next_cache else {
log::warn!("Unable to downcast cache to BlockLRCache, skipping cache preparation");
return;
};
Expand Down Expand Up @@ -263,15 +261,15 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockLR<L> {
fn read_weights_from_buf(
&mut self,
input_bufreader: &mut dyn io::Read,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false)
}

fn write_weights_to_buf(
&self,
output_bufwriter: &mut dyn io::Write,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)
}
Expand All @@ -280,7 +278,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockLR<L> {
&self,
input_bufreader: &mut dyn io::Read,
forward: &mut Box<dyn BlockTrait>,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
let forward = forward
.as_any()
Expand Down
6 changes: 3 additions & 3 deletions src/block_neural.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockNeuronLayer<L> {
fn write_weights_to_buf(
&self,
output_bufwriter: &mut dyn io::Write,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)?;
block_helpers::write_weights_to_buf(&self.weights_optimizer, output_bufwriter, false)?;
Expand All @@ -440,7 +440,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockNeuronLayer<L> {
fn read_weights_from_buf(
&mut self,
input_bufreader: &mut dyn io::Read,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false)?;
block_helpers::read_weights_from_buf(&mut self.weights_optimizer, input_bufreader, false)?;
Expand All @@ -466,7 +466,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockNeuronLayer<L> {
&self,
input_bufreader: &mut dyn io::Read,
forward: &mut Box<dyn BlockTrait>,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
let forward = forward
.as_any()
Expand Down
138 changes: 138 additions & 0 deletions src/buffer_handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
use flate2::read::MultiGzDecoder;
use std::fs::File;
use std::io;
use std::io::BufRead;
use std::path::Path;
use zstd::stream::read::Decoder as ZstdDecoder;

/// Opens `input_filename` and returns a buffered reader over its contents,
/// picking the decoder from the file extension.
///
/// Recognised extensions:
/// * `gz`  - wrapped in a `MultiGzDecoder`
/// * `zst` - wrapped in a `ZstdDecoder`
/// * `vw`  - read as-is, without any decoder
///
/// # Panics
///
/// Panics if the file cannot be opened, if the path has no extension that is
/// valid UTF-8, if the extension is not one of the three listed above, or if
/// the zstd decoder cannot be constructed.
pub fn create_buffered_input(input_filename: &str) -> Box<dyn BufRead> {
    // The file is opened before the extension is inspected, so a file that
    // cannot be opened is reported ahead of an unsupported format.
    let input = File::open(input_filename).expect("Could not open the input file.");

    let input_format = Path::new(input_filename)
        .extension()
        .and_then(|ext| ext.to_str())
        .expect("Failed to get the file extension.");

    match input_format {
        "gz" => Box::new(io::BufReader::new(MultiGzDecoder::new(input))),
        "zst" => {
            // Decoder construction is fallible; name the failing step instead
            // of panicking through a bare `unwrap()`.
            let zstd_decoder =
                ZstdDecoder::new(input).expect("Could not initialize the zstd decoder.");
            Box::new(io::BufReader::new(zstd_decoder))
        }
        "vw" => Box::new(io::BufReader::new(input)),
        _ => panic!("Please specify a valid input format (.vw, .zst, .gz)"),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::{self, Read, Write};
    use tempfile::Builder as TempFileBuilder;
    use tempfile::NamedTempFile;
    use zstd::stream::Encoder as ZstdEncoder;

    /// Creates a named temporary file whose name ends in `.{extension}` and
    /// writes `contents` to it unchanged.
    fn create_temp_file_with_contents(
        extension: &str,
        contents: &[u8],
    ) -> io::Result<NamedTempFile> {
        let suffix = format!(".{}", extension);
        let file = TempFileBuilder::new().suffix(&suffix).tempfile()?;
        file.as_file().write_all(contents)?;
        Ok(file)
    }

    /// Creates a `.gz` temporary file holding the gzip encoding of `contents`.
    fn create_gzipped_temp_file(contents: &[u8]) -> io::Result<NamedTempFile> {
        let file = TempFileBuilder::new().suffix(".gz").tempfile()?;
        let encoder = GzEncoder::new(Vec::new(), Compression::default());
        let mut buffered = io::BufWriter::new(encoder);
        buffered.write_all(contents)?;
        // Unwrap the BufWriter (flushing it) and then finish the gzip stream.
        let compressed = buffered.into_inner()?.finish()?;
        file.as_file().write_all(&compressed)?;
        Ok(file)
    }

    /// Creates a `.zst` temporary file holding the zstd encoding of `contents`.
    fn create_zstd_temp_file(contents: &[u8]) -> io::Result<NamedTempFile> {
        let file = TempFileBuilder::new().suffix(".zst").tempfile()?;
        let mut encoder = ZstdEncoder::new(Vec::new(), 1)?;
        encoder.write_all(contents)?;
        let compressed = encoder.finish()?;
        file.as_file().write_all(&compressed)?;
        Ok(file)
    }

    /// Runs `file` through `create_buffered_input` and collects every byte
    /// the returned reader yields.
    fn read_decoded(file: &NamedTempFile) -> Vec<u8> {
        let mut reader = create_buffered_input(file.path().to_str().unwrap());
        let mut decoded = Vec::new();
        reader
            .read_to_end(&mut decoded)
            .expect("Failed to read from the reader");
        decoded
    }

    // A ".vw" file must come back byte-for-byte.
    #[test]
    fn test_uncompressed_file() {
        let contents = b"Sample text for uncompressed file.";
        let temp_file =
            create_temp_file_with_contents("vw", contents).expect("Failed to create temp file");

        assert_eq!(
            read_decoded(&temp_file),
            contents,
            "Contents did not match for uncompressed file."
        );
    }

    // A ".gz" file must decode back to the original bytes.
    #[test]
    fn test_gz_compressed_file() {
        let contents = b"Sample text for gzipped file.";
        let temp_file =
            create_gzipped_temp_file(contents).expect("Failed to create gzipped temp file");

        assert_eq!(
            read_decoded(&temp_file),
            contents,
            "Contents did not match for gzipped file."
        );
    }

    // A ".zst" file must decode back to the original bytes.
    #[test]
    fn test_zstd_compressed_file() {
        let contents = b"Sample text for zstd compressed file.";
        let temp_file = create_zstd_temp_file(contents).expect("Failed to create zstd temp file");

        assert_eq!(
            read_decoded(&temp_file),
            contents,
            "Contents did not match for zstd compressed file."
        );
    }

    // Any other extension must be rejected with the documented panic message.
    #[test]
    #[should_panic(expected = "Please specify a valid input format (.vw, .zst, .gz)")]
    fn test_unsupported_file_format() {
        let temp_file = create_temp_file_with_contents("txt", b"Some content")
            .expect("Failed to create temp file");
        let _reader = create_buffered_input(temp_file.path().to_str().unwrap());
    }
}
1 change: 0 additions & 1 deletion src/feature_transform_implementations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use crate::feature_transform_executor::{
use crate::feature_transform_parser;
use crate::vwmap::{NamespaceDescriptor, NamespaceFormat, NamespaceType};


// -------------------------------------------------------------------
// TransformerBinner - A basic binner
// It can take any function f32 -> f32 as a binning function. The output is then rounded to an integer
Expand Down
Loading

0 comments on commit cd270c5

Please sign in to comment.