diff --git a/Cargo.lock b/Cargo.lock index f0525ac8..9e8791dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1447,6 +1447,12 @@ dependencies = [ "digest", ] +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + [[package]] name = "spin" version = "0.9.8" diff --git a/pineappl/src/evolution.rs b/pineappl/src/evolution.rs index 4c19e61c..d11ce136 100644 --- a/pineappl/src/evolution.rs +++ b/pineappl/src/evolution.rs @@ -11,7 +11,10 @@ use float_cmp::approx_eq; use itertools::izip; use itertools::Itertools; use ndarray::linalg; -use ndarray::{s, Array1, Array2, Array3, ArrayD, ArrayView1, ArrayView4, Axis, Ix1, Ix2}; +use ndarray::{ + s, Array1, Array2, Array3, ArrayD, ArrayView1, ArrayView4, ArrayViewD, ArrayViewMutD, Axis, + Ix1, Ix2, +}; use std::iter; /// This structure captures the information needed to create an evolution kernel operator (EKO) for @@ -462,7 +465,7 @@ pub(crate) fn evolve_slice_with_many( .map(|ops| (fk_table, ops)) }) { - general_tensor_mul(*factor, &array, &ops, fk_table); + general_tensor_mul(*factor, array.view(), &ops, &mut fk_table.view_mut()); } } } @@ -496,11 +499,12 @@ pub(crate) fn evolve_slice_with_many( fn general_tensor_mul( factor: f64, - array: &ArrayD, + array: ArrayViewD, ops: &[&Array2], - fk_table: &mut ArrayD, + fk_table: &mut ArrayViewMutD, ) { match array.shape().len() { + 0 => unreachable!(), 1 => { let array = array.view().into_dimensionality::().unwrap(); let mut fk_table = fk_table.view_mut().into_dimensionality::().unwrap(); @@ -516,7 +520,17 @@ fn general_tensor_mul( // fk_table += factor * ops[0] * tmp linalg::general_mat_mul(factor, ops[0], &tmp, 1.0, &mut fk_table); } - // TODO: generalize this to n dimensions - _ => unimplemented!(), + _ => { + let (ops_0, ops_dm1) = ops.split_first().unwrap(); + + for (mut fk_table_i, ops_0_i) in fk_table + .axis_iter_mut(Axis(0)) + .zip(ops_0.axis_iter(Axis(0))) + { + for (array_j, ops_0_ij) in array.axis_iter(Axis(0)).zip(ops_0_i.iter()) { + general_tensor_mul(factor * ops_0_ij, array_j, &ops_dm1, &mut fk_table_i); + } + } + } } }