diff --git a/pineappl_py/src/import_subgrid.rs b/pineappl_py/src/import_subgrid.rs index 6cef84e3..d3721a81 100644 --- a/pineappl_py/src/import_subgrid.rs +++ b/pineappl_py/src/import_subgrid.rs @@ -1,7 +1,7 @@ //! PyPackedSubgrid* interface. use super::subgrid::PySubgridEnum; -use numpy::PyReadonlyArray3; +use numpy::{PyReadonlyArray2, PyReadonlyArray3}; use pineappl::import_subgrid::ImportSubgridV1; use pineappl::packed_array::PackedArray; use pyo3::prelude::*; @@ -18,6 +18,8 @@ pub struct PyImportSubgridV1 { impl PyImportSubgridV1 { /// Constructor. /// + /// # Panics + /// /// Parameters /// ---------- /// array : numpy.ndarray(float) @@ -26,25 +28,58 @@ impl PyImportSubgridV1 { /// scales grid /// x1_grid : list(float) /// first momentum fraction grid - /// x2_grid : list(float) + /// x2_grid : Optional(list(float)) /// second momentum fraction grid #[new] - pub fn new( - array: PyReadonlyArray3, + #[must_use] + pub fn new<'py>( + array: PyObject, scales: Vec, x1_grid: Vec, - x2_grid: Vec, + x2_grid: Option>, + py: Python<'py>, ) -> Self { - let node_values: Vec> = vec![scales, x1_grid, x2_grid]; + // let node_values: Vec> = vec![scales, x1_grid, x2_grid]; + // let mut sparse_array: PackedArray = + // PackedArray::new(node_values.iter().map(Vec::len).collect()); + + // for ((iscale, ix1, ix2), value) in array + // .as_array() + // .indexed_iter() + // .filter(|((_, _, _), value)| **value != 0.0) + // { + // sparse_array[[iscale, ix1, ix2]] = *value; + // } + + // Self { + // import_subgrid: ImportSubgridV1::new(sparse_array, node_values), + // } + let mut node_values: Vec> = vec![scales, x1_grid]; + + if let Some(x2) = x2_grid { + node_values.push(x2); + } let mut sparse_array: PackedArray = PackedArray::new(node_values.iter().map(Vec::len).collect()); - for ((iscale, ix1, ix2), value) in array - .as_array() - .indexed_iter() - .filter(|((_, _, _), value)| **value != 0.0) - { - sparse_array[[iscale, ix1, ix2]] = *value; + if sparse_array.shape().to_vec().len() == 3 { + let array_3d: PyReadonlyArray3 = array.extract(py).unwrap(); + for ((iscale, ix1, ix2), value) in array_3d + .as_array() + .indexed_iter() + .filter(|((_, _, _), value)| **value != 0.0) + { + sparse_array[[iscale, ix1, ix2]] = *value; + } + } else { + let array_2d: PyReadonlyArray2 = array.extract(py).unwrap(); + for ((iscale, ix1), value) in array_2d + .as_array() + .indexed_iter() + .filter(|((_, _), value)| **value != 0.0) + { + sparse_array[[iscale, ix1]] = *value; + } } Self { diff --git a/pineappl_py/tests/test_fk_table.py b/pineappl_py/tests/test_fk_table.py index fdde8d4f..ac61e166 100644 --- a/pineappl_py/tests/test_fk_table.py +++ b/pineappl_py/tests/test_fk_table.py @@ -86,10 +86,9 @@ def test_convolve(self): xs = np.linspace(0.5, 1.0, 5) vs = xs.copy() subgrid = ImportSubgridV1( - vs[np.newaxis, :, np.newaxis], + vs[np.newaxis, :], # DIS shape: (len(q2), len(x_grid)) np.array([90.0]), xs, - np.array([1.0]), ) g.set_subgrid(0, 0, 0, subgrid.into())