Skip to content

Commit

Permalink
rust: Load errors
Browse files Browse the repository at this point in the history
  • Loading branch information
felixhekhorn committed Aug 21, 2024
1 parent 234a5e5 commit 1177a57
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 13 deletions.
20 changes: 11 additions & 9 deletions crates/dekoder/src/inventory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use lz4_flex::frame::FrameDecoder;
use ndarray_npy::NpzReader;
use std::collections::HashMap;
use std::ffi::{OsStr, OsString};
use std::ffi::OsString;
use std::fs::{read_dir, read_to_string, File};
use std::io::Cursor;
use std::path::PathBuf;
Expand All @@ -27,10 +27,7 @@ impl<K: Eq + for<'a> TryFrom<&'a Yaml, Error = EKOError>> Inventory<K> {
for entry in read_dir(&self.path)? {
// is header file?
let entry = entry?.path();
if !entry
.extension()
.is_some_and(|ext| ext.eq(OsStr::new(HEADER_EXT)))
{
if !entry.extension().is_some_and(|ext| ext.eq(HEADER_EXT)) {
continue;
}
// read
Expand Down Expand Up @@ -68,17 +65,22 @@ impl<K: Eq + for<'a> TryFrom<&'a Yaml, Error = EKOError>> Inventory<K> {
.iter()
.find(|it| (it.1).eq(k))
.ok_or(EKOError::KeyError("because it was not found".to_owned()))?;
// TODO determine if errors are available
let p = self.path.join(k.0).with_extension("npz.lz4");
// Read npz.lz4
let mut reader = FrameDecoder::new(File::open(&p)?);
let mut buffer = Vec::new();
std::io::copy(&mut reader, &mut buffer)?;
let mut npz = NpzReader::new(Cursor::new(buffer))
.map_err(|_| EKOError::OperatorLoadError(p.to_owned()))?;
let op = Some(
NpzReader::new(Cursor::new(buffer))
.map_err(|_| EKOError::OperatorLoadError(p.to_owned()))?
.by_name("operator.npy")
npz.by_name("operator.npy")
.map_err(|_| EKOError::OperatorLoadError(p.to_owned()))?,
);
Ok(Operator { op })
let err = Some(
npz.by_name("error.npy")
.map_err(|_| EKOError::OperatorLoadError(p.to_owned()))?,
);
Ok(Operator { op, err })
}
}
7 changes: 6 additions & 1 deletion crates/dekoder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,16 @@ pub type Result<T> = std::result::Result<T, EKOError>;
pub struct Operator {
/// The actual rank 4 tensor.
pub op: Option<Array4<f64>>,
/// The associated element-by-element error.
pub err: Option<Array4<f64>>,
}

impl Default for Operator {
/// Empty initializer.
fn default() -> Self {
Self { op: None }
Self {
op: None,
err: None,
}
}
}
9 changes: 6 additions & 3 deletions crates/dekoder/tests/test_load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ fn load_operator() {
scale: 10000.,
nf: 4,
};
let op = eko.load_operator(&ep).unwrap();
assert!(op.op.is_some());
assert!(op.op.unwrap().dim().0 > 0);
let operator = eko.load_operator(&ep).unwrap();
assert!(operator.op.is_some());
assert!(operator.err.is_some());
let op = operator.op.unwrap();
assert!(op.dim().0 > 0);
assert!(op.dim().0 == operator.err.unwrap().dim().0);
}

0 comments on commit 1177a57

Please sign in to comment.