Skip to content

Commit

Permalink
fix 0 bins for shape effect; properly rename barlow beeston full mode
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Dec 14, 2023
1 parent 46e7356 commit 2bb2255
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
7 changes: 4 additions & 3 deletions src/dilax/effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def __init__(
self.up = up # +1 sigma
self.down = down # -1 sigma

@eqx.filter_jit
def vshift(self, sf: jax.Array, sumw: jax.Array) -> jax.Array:
factor = sf
dx_sum = self.up + self.down - 2 * sumw
Expand All @@ -121,8 +120,10 @@ def constraint(self) -> HashablePDF:

def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
sf = parameter.value
# clip, no negative values are allowed
return jnp.maximum((sumw + self.vshift(sf=sf, sumw=sumw)) / sumw, 0.0)
shift = self.vshift(sf=sf, sumw=sumw)
# handle zeros, see: https://github.com/google/jax/issues/5039
x = jnp.where(sumw == 0.0, 1.0, sumw)
return jnp.where(sumw == 0.0, 1.0, (x + shift) / x)


class lnN(Effect):
Expand Down
31 changes: 15 additions & 16 deletions src/dilax/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ class staterror(ModifierBase):
"""
Create a staterror (barlow-beeston) modifier which acts on each bin with a different _underlying_ modifier.
*Caution*: The instantiation of a `staterror` is not compatible with JAX-transformations (e.g. `jax.jit`)!
Example:
.. code-block:: python
Expand All @@ -188,7 +190,7 @@ class staterror(ModifierBase):
# all bins with bin content below 10 (threshold) are treated as poisson, else gauss
modify = dlx.staterror(
parameters=[p1, p2, p3],
parameters={1: p1, 2: p2, 3: p3},
sumw=hist,
sumw2=hist,
threshold=10.0,
Expand Down Expand Up @@ -272,15 +274,7 @@ def _mod(
_poisson_mod = partial(_mod, effect=poisson)
_gauss_mod = partial(_mod, effect=gauss)

# if all true in mask use poisson
if jnp.all(self.mask):
return jax.vmap(_poisson_mod)(values, _widths, idxs)

# if all false in mask use gauss
if jnp.all(~self.mask):
return jax.vmap(_gauss_mod)(values, _widths, idxs)

# if mixed use jnp.where
# apply
return jnp.where(
self.mask,
jax.vmap(_poisson_mod)(values, _widths, idxs),
Expand All @@ -290,15 +284,17 @@ def _mod(

class autostaterrors(eqx.Module):
class Mode(eqx.Enumeration):
poisson = "Poisson per process and bin"
barlow_beeston_full = (
"Barlow-Beeston (full) approach: Poisson per process and bin"
)
poisson_gauss = "Poisson (Gauss) per process and bin if sumw < (>) threshold"
barlow_beeston_lite = "Barlow-Beeston (lite) approach"

sumw: dict[str, jax.Array]
sumw2: dict[str, jax.Array]
masks: dict[str, jax.Array]
threshold: float
mode: str = eqx.field(static=True)
mode: str
key_template: str = eqx.field(static=True)

def __init__(
Expand Down Expand Up @@ -340,7 +336,7 @@ def prepare(
Helper to automatically create parameters used by `staterror`
for the initialisation of a `dlx.Model`.
_Caution_: This function is not compatible with JAX-transformations (e.g. `jax.jit`)!
*Caution*: This function is not compatible with JAX-transformations (e.g. `jax.jit`)!
Example:
Expand All @@ -364,7 +360,7 @@ def prepare(
sumw=sumw,
sumw2=sumw2,
threshold=10.0,
mode=dlx.autostaterrors.Mode.poisson,
mode=dlx.autostaterrors.Mode.barlow_beeston_full,
)
parameters, staterrors = auto.prepare()
Expand All @@ -387,6 +383,7 @@ def prepare(

parameters: dict[str, dict[str, Parameter]] = {}
staterrors: dict[str, dict[str, eqx.Partial]] = {}

for process, _sumw in self.sumw.items():
key = self.key_template.format(process=process)
process_parameters = parameters[key] = {}
Expand All @@ -405,11 +402,13 @@ def prepare(
"sumw2": self.sumw2[process],
"threshold": self.threshold,
}
if self.mode == self.Mode.poisson:
if self.mode == self.Mode.barlow_beeston_full:
kwargs["threshold"] = jnp.inf # inf -> always poisson
elif self.mode == self.Mode.barlow_beeston_lite:
kwargs["sumw"] = jnp.where(
mask, _sumw, sum(jax.tree_util.tree_leaves(self.sumw))
mask,
_sumw,
sum(jax.tree_util.tree_leaves(self.sumw)),
)
kwargs["sumw2"] = jnp.where(
mask,
Expand Down

0 comments on commit 2bb2255

Please sign in to comment.