diff --git a/src/dilax/effect.py b/src/dilax/effect.py index 2095b3a..e234199 100644 --- a/src/dilax/effect.py +++ b/src/dilax/effect.py @@ -95,7 +95,6 @@ def __init__( self.up = up # +1 sigma self.down = down # -1 sigma - @eqx.filter_jit def vshift(self, sf: jax.Array, sumw: jax.Array) -> jax.Array: factor = sf dx_sum = self.up + self.down - 2 * sumw @@ -121,8 +120,10 @@ def constraint(self) -> HashablePDF: def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: sf = parameter.value - # clip, no negative values are allowed - return jnp.maximum((sumw + self.vshift(sf=sf, sumw=sumw)) / sumw, 0.0) + shift = self.vshift(sf=sf, sumw=sumw) + # handle zeros, see: https://github.com/google/jax/issues/5039 + x = jnp.where(sumw == 0.0, 1.0, sumw) + return jnp.where(sumw == 0.0, 1.0, (x + shift) / x) class lnN(Effect): diff --git a/src/dilax/modifier.py b/src/dilax/modifier.py index 81c2f06..c99ba4b 100644 --- a/src/dilax/modifier.py +++ b/src/dilax/modifier.py @@ -173,6 +173,8 @@ class staterror(ModifierBase): """ Create a staterror (barlow-beeston) modifier which acts on each bin with a different _underlying_ modifier. + *Caution*: The instantiation of a `staterror` is not compatible with JAX-transformations (e.g. `jax.jit`)! + Example: .. code-block:: python @@ -188,7 +190,7 @@ class staterror(ModifierBase): # all bins with bin content below 10 (threshold) are treated as poisson, else gauss modify = dlx.staterror( - parameters=[p1, p2, p3], + parameters={1: p1, 2: p2, 3: p3}, sumw=hist, sumw2=hist, threshold=10.0, @@ -272,15 +274,7 @@ def _mod( _poisson_mod = partial(_mod, effect=poisson) _gauss_mod = partial(_mod, effect=gauss) - # if all true in mask use poisson - if jnp.all(self.mask): - return jax.vmap(_poisson_mod)(values, _widths, idxs) - - # if all false in mask use gauss - if jnp.all(~self.mask): - return jax.vmap(_gauss_mod)(values, _widths, idxs) - - # if mixed use jnp.where + # apply return jnp.where( self.mask, jax.vmap(_poisson_mod)(values, _widths, idxs), @@ -290,7 +284,9 @@ def _mod( class autostaterrors(eqx.Module): class Mode(eqx.Enumeration): - poisson = "Poisson per process and bin" + barlow_beeston_full = ( + "Barlow-Beeston (full) approach: Poisson per process and bin" + ) poisson_gauss = "Poisson (Gauss) per process and bin if sumw < (>) threshold" barlow_beeston_lite = "Barlow-Beeston (lite) approach" @@ -298,7 +294,7 @@ class Mode(eqx.Enumeration): sumw2: dict[str, jax.Array] masks: dict[str, jax.Array] threshold: float - mode: str = eqx.field(static=True) + mode: str key_template: str = eqx.field(static=True) def __init__( @@ -340,7 +336,7 @@ def prepare( Helper to automatically create parameters used by `staterror` for the initialisation of a `dlx.Model`. - _Caution_: This function is not compatible with JAX-transformations (e.g. `jax.jit`)! + *Caution*: This function is not compatible with JAX-transformations (e.g. `jax.jit`)! Example: @@ -364,7 +360,7 @@ def prepare( sumw=sumw, sumw2=sumw2, threshold=10.0, - mode=dlx.autostaterrors.Mode.poisson, + mode=dlx.autostaterrors.Mode.barlow_beeston_full, ) parameters, staterrors = auto.prepare() @@ -387,6 +383,7 @@ def prepare( parameters: dict[str, dict[str, Parameter]] = {} staterrors: dict[str, dict[str, eqx.Partial]] = {} + for process, _sumw in self.sumw.items(): key = self.key_template.format(process=process) process_parameters = parameters[key] = {} @@ -405,11 +402,13 @@ def prepare( "sumw2": self.sumw2[process], "threshold": self.threshold, } - if self.mode == self.Mode.poisson: + if self.mode == self.Mode.barlow_beeston_full: kwargs["threshold"] = jnp.inf # inf -> always poisson elif self.mode == self.Mode.barlow_beeston_lite: kwargs["sumw"] = jnp.where( - mask, _sumw, sum(jax.tree_util.tree_leaves(self.sumw)) + mask, + _sumw, + sum(jax.tree_util.tree_leaves(self.sumw)), ) kwargs["sumw2"] = jnp.where( mask,