From f1574a32370568f100bbb952745173f49b7b3d05 Mon Sep 17 00:00:00 2001 From: Radonirinaunimi Date: Fri, 18 Oct 2024 22:58:12 +0200 Subject: [PATCH] Make Generalized convolution works --- pineappl_py/src/grid.rs | 31 ++++++++++++++++++++++++++++--- pineappl_py/tests/test_grid.py | 8 ++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/pineappl_py/src/grid.rs b/pineappl_py/src/grid.rs index 0e2aaa6c..f63d5eaf 100644 --- a/pineappl_py/src/grid.rs +++ b/pineappl_py/src/grid.rs @@ -314,13 +314,38 @@ impl PyGrid { xi: Option>, py: Python<'py>, ) -> Bound<'py, PyArray1> { - // Closure for alphas function let mut alphas = |q2: f64| { let result: f64 = alphas.call1(py, (q2,)).unwrap().extract(py).unwrap(); result }; - todo!() + let mut xfx_funcs: Vec<_> = xfxs + .iter() + .map(|xfx| { + move |id: i32, x: f64, q2: f64| { + xfx.call1(py, (id, x, q2)).unwrap().extract(py).unwrap() + } + }) + .collect(); + + let mut lumi_cache = ConvolutionCache::new( + pdg_convs.into_iter().map(|pdg| pdg.conv.clone()).collect(), + xfx_funcs + .iter_mut() + .map(|fx| fx as &mut dyn FnMut(i32, f64, f64) -> f64) + .collect(), + &mut alphas, + ); + + self.grid + .convolve( + &mut lumi_cache, + &order_mask.unwrap_or_default(), + &bin_indices.unwrap_or_default(), + &channel_mask.unwrap_or_default(), + &xi.unwrap_or_else(|| vec![(1.0, 1.0, 0.0)]), + ) + .into_pyarray_bound(py) } /// Collect information for convolution with an evolution operator. @@ -684,7 +709,7 @@ impl PyGrid { /// bin_indices : numpy.ndarray[int] /// list of indices of bins to removed pub fn delete_bins(&mut self, bin_indices: Vec) { - self.grid.delete_bins(&bin_indices) + self.grid.delete_bins(&bin_indices); } } diff --git a/pineappl_py/tests/test_grid.py b/pineappl_py/tests/test_grid.py index fc4ba013..875a2ea9 100644 --- a/pineappl_py/tests/test_grid.py +++ b/pineappl_py/tests/test_grid.py @@ -172,6 +172,14 @@ def test_convolve_with_two(self): ), [2**3 * v, 0.0], ) + np.testing.assert_allclose( + g.convolve( + pdg_convs=[CONVOBJECT, CONVOBJECT], + xfxs=[lambda pid, x, q2: 1.0, lambda pid, x, q2: 1.0], + alphas=lambda q2: 2.0, + ), + [2**3 * v, 0.0], + ) # Test using the generalized convolution np.testing.assert_allclose( g.convolve_with_two( pdg_conv1=CONVOBJECT,