diff --git a/src/block_ffm.rs b/src/block_ffm.rs index ae34e7e8..fab305ae 100644 --- a/src/block_ffm.rs +++ b/src/block_ffm.rs @@ -1937,7 +1937,7 @@ mod tests { contra_field_index: mi.ffm_k, }]); assert_eq!(spredict2(&mut bg, &fb, &mut pb), 0.5); - assert_eq!(slearn2(&mut bg, &fb, &mut pb, true), 0.5); + assert_eq!(slearn2(&mut bg, &fb, &mut pb, true), 0.62245935); } #[test] diff --git a/src/main.rs b/src/main.rs index 4555bbc1..a6fd27d2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -50,7 +50,7 @@ fn main() { fn create_buffered_input(input_filename: &str) -> Box { // Handler for different (or no) compression types - + let input = File::open(input_filename).expect("Could not open the input file."); let input_format = Path::new(&input_filename) @@ -322,3 +322,105 @@ fn main_fw_loop() -> Result<(), Box> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use flate2::write::GzEncoder; + use flate2::Compression; + use std::fs::File; + use std::io::{self, BufReader, Read, Write}; + use tempfile::Builder as TempFileBuilder; + use tempfile::NamedTempFile; + use zstd::stream::Encoder as ZstdEncoder; + + fn create_temp_file_with_contents( + extension: &str, + contents: &[u8], + ) -> io::Result { + let temp_file = TempFileBuilder::new() + .suffix(&format!(".{}", extension)) + .tempfile()?; + temp_file.as_file().write_all(contents)?; + Ok(temp_file) + } + + fn create_gzipped_temp_file(contents: &[u8]) -> io::Result { + let temp_file = TempFileBuilder::new().suffix(".gz").tempfile()?; + let gz = GzEncoder::new(Vec::new(), Compression::default()); + let mut gz_writer = io::BufWriter::new(gz); + gz_writer.write_all(contents)?; + let gz = gz_writer.into_inner()?.finish()?; + temp_file.as_file().write_all(&gz)?; + Ok(temp_file) + } + + fn create_zstd_temp_file(contents: &[u8]) -> io::Result { + let temp_file = TempFileBuilder::new().suffix(".zst").tempfile()?; + let mut zstd_encoder = ZstdEncoder::new(Vec::new(), 1)?; + zstd_encoder.write_all(contents)?; + let encoded_data = zstd_encoder.finish()?; + temp_file.as_file().write_all(&encoded_data)?; + Ok(temp_file) + } + + // Test for uncompressed file ("vw" extension) + #[test] + fn test_uncompressed_file() { + let contents = b"Sample text for uncompressed file."; + let temp_file = + create_temp_file_with_contents("vw", contents).expect("Failed to create temp file"); + let mut reader = create_buffered_input(temp_file.path().to_str().unwrap()); + + let mut buffer = Vec::new(); + reader + .read_to_end(&mut buffer) + .expect("Failed to read from the reader"); + assert_eq!( + buffer, contents, + "Contents did not match for uncompressed file." + ); + } + + // Test for gzipped files ("gz" extension) + #[test] + fn test_gz_compressed_file() { + let contents = b"Sample text for gzipped file."; + let temp_file = + create_gzipped_temp_file(contents).expect("Failed to create gzipped temp file"); + let mut reader = create_buffered_input(temp_file.path().to_str().unwrap()); + + let mut buffer = Vec::new(); + reader + .read_to_end(&mut buffer) + .expect("Failed to read from the reader"); + assert_eq!(buffer, contents, "Contents did not match for gzipped file."); + } + + // Test for zstd compressed files ("zst" extension) + #[test] + fn test_zstd_compressed_file() { + let contents = b"Sample text for zstd compressed file."; + let temp_file = create_zstd_temp_file(contents).expect("Failed to create zstd temp file"); + let mut reader = create_buffered_input(temp_file.path().to_str().unwrap()); + + let mut buffer = Vec::new(); + reader + .read_to_end(&mut buffer) + .expect("Failed to read from the reader"); + assert_eq!( + buffer, contents, + "Contents did not match for zstd compressed file." + ); + } + + // Test for unsupported file format + #[test] + #[should_panic(expected = "Please specify a valid input format (.vw, .zst, .gz)")] + fn test_unsupported_file_format() { + let contents = b"Some content"; + let temp_file = + create_temp_file_with_contents("txt", contents).expect("Failed to create temp file"); + let _reader = create_buffered_input(temp_file.path().to_str().unwrap()); + } +}