Skip to content

Commit

Permalink
Generalize general_tensor_mul to arbitrary dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
t7phy committed Dec 18, 2024
1 parent 20c263a commit d13bb77
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
6 changes: 6 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 20 additions & 6 deletions pineappl/src/evolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use float_cmp::approx_eq;
use itertools::izip;
use itertools::Itertools;
use ndarray::linalg;
use ndarray::{s, Array1, Array2, Array3, ArrayD, ArrayView1, ArrayView4, Axis, Ix1, Ix2};
use ndarray::{
s, Array1, Array2, Array3, ArrayD, ArrayView1, ArrayView4, ArrayViewD, ArrayViewMutD, Axis,
Ix1, Ix2,
};
use std::iter;

/// This structure captures the information needed to create an evolution kernel operator (EKO) for
Expand Down Expand Up @@ -462,7 +465,7 @@ pub(crate) fn evolve_slice_with_many(
.map(|ops| (fk_table, ops))
})
{
general_tensor_mul(*factor, &array, &ops, fk_table);
general_tensor_mul(*factor, array.view(), &ops, &mut fk_table.view_mut());
}
}
}
Expand Down Expand Up @@ -496,11 +499,12 @@ pub(crate) fn evolve_slice_with_many(

fn general_tensor_mul(
factor: f64,
array: &ArrayD<f64>,
array: ArrayViewD<f64>,
ops: &[&Array2<f64>],
fk_table: &mut ArrayD<f64>,
fk_table: &mut ArrayViewMutD<f64>,
) {
match array.shape().len() {
0 => unreachable!(),
1 => {
let array = array.view().into_dimensionality::<Ix1>().unwrap();
let mut fk_table = fk_table.view_mut().into_dimensionality::<Ix1>().unwrap();
Expand All @@ -516,7 +520,17 @@ fn general_tensor_mul(
// fk_table += factor * ops[0] * tmp
linalg::general_mat_mul(factor, ops[0], &tmp, 1.0, &mut fk_table);
}
// TODO: generalize this to n dimensions
_ => unimplemented!(),
_ => {
let (ops_0, ops_dm1) = ops.split_first().unwrap();

for (mut fk_table_i, ops_0_i) in fk_table
.axis_iter_mut(Axis(0))
.zip(ops_0.axis_iter(Axis(0)))
{
for (array_j, ops_0_ij) in array.axis_iter(Axis(0)).zip(ops_0_i.iter()) {
general_tensor_mul(factor * ops_0_ij, array_j, &ops_dm1, &mut fk_table_i);
}
}
}
}
}

0 comments on commit d13bb77

Please sign in to comment.