Skip to content

Commit

Permalink
add staterror, simplify gauss effect, fix poisson effect, properly ex…
Browse files Browse the repository at this point in the history
…pose API
  • Loading branch information
pfackeldey committed Nov 23, 2023
1 parent ae7328a commit 06378ed
Show file tree
Hide file tree
Showing 10 changed files with 206 additions and 30 deletions.
31 changes: 30 additions & 1 deletion src/dilax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,33 @@
__contact__ = "https://github.com/pfackeldey/dilax"
__license__ = "BSD-3-Clause"
__status__ = "Development"
__version__ = "0.1.4"
__version__ = "0.1.5"


# expose public API

# Names exposed as the public API of the package; `__dir__` below returns this list.
__all__ = [
    "ipy_util",
    "likelihood",
    "model",
    "optimizer",
    "parameter",
    # "pdf", # this should not be needed in public API
    "util",
    "__version__",
]


def __dir__() -> list[str]:
    """Return ``__all__`` so that ``dir()`` on this module lists only the public names."""
    return __all__


from dilax import ( # noqa: E402
ipy_util,
likelihood,
model,
optimizer,
parameter,
# pdf, # this should not be needed in public API
util,
)
6 changes: 6 additions & 0 deletions src/dilax/ipy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

from dilax.model import Model

# Public names of this module; `__dir__` below returns this list.
__all__ = ["interactive"]


def __dir__() -> list[str]:
    """Return ``__all__`` so that ``dir()`` on this module lists only the public names."""
    return __all__


def interactive(model: Model) -> None:
import ipywidgets as widgets
Expand Down
6 changes: 6 additions & 0 deletions src/dilax/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
from dilax.model import Model
from dilax.util import Sentinel, _NoValue

# Public names of this module; `__dir__` below returns this list.
__all__ = ["NLL", "Hessian", "CovMatrix", "SampleToy"]


def __dir__() -> list[str]:
    """Return ``__all__`` so that ``dir()`` on this module lists only the public names."""
    return __all__


class BaseModule(eqx.Module):
"""
Expand Down
6 changes: 6 additions & 0 deletions src/dilax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
from dilax.parameter import Parameter
from dilax.util import Sentinel, _NoValue

# Public names of this module; `__dir__` below returns this list.
__all__ = ["Result", "Model"]


def __dir__() -> list[str]:
    """Return ``__all__`` so that ``dir()`` on this module lists only the public names."""
    return __all__


class Result(eqx.Module):
expectations: dict[str, jax.Array]
Expand Down
6 changes: 6 additions & 0 deletions src/dilax/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@

from dilax.util import Sentinel, _NoValue

# Public names of this module; `__dir__` below returns this list.
__all__ = ["JaxOptimizer", "Chain"]


def __dir__() -> list[str]:
    """Return ``__all__`` so that ``dir()`` on this module lists only the public names."""
    return __all__


class JaxOptimizer(eqx.Module):
"""
Expand Down
139 changes: 125 additions & 14 deletions src/dilax/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,23 @@
from dilax.pdf import Flat, Gauss, HashablePDF, Poisson
from dilax.util import as1darray

# Public names of this module; `__dir__` below returns this list.
__all__ = [
    "Parameter",
    "Effect",
    "unconstrained",
    "gauss",
    "lnN",
    "poisson",
    "shape",
    "modifier",
    "staterror",
    "compose",
]


def __dir__() -> list[str]:
    """Return ``__all__`` so that ``dir()`` on this module lists only the public names."""
    return __all__


class Parameter(eqx.Module):
value: jax.Array = eqx.field(converter=as1darray)
Expand Down Expand Up @@ -72,9 +89,10 @@ def constraint(self) -> HashablePDF:
return Gauss(mean=0.0, width=1.0)

def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
gx = Gauss(mean=1.0, width=self.width) # type: ignore[arg-type]
g1 = Gauss(mean=1.0, width=1.0)
return gx.inv_cdf(g1.cdf(parameter.value + 1))
# gx = Gauss(mean=1.0, width=self.width) # type: ignore[arg-type]
# g1 = Gauss(mean=1.0, width=1.0)
# return gx.inv_cdf(g1.cdf(parameter.value + 1))
return (parameter.value * self.width) + 1 # fast analytical solution


class shape(Effect):
Expand Down Expand Up @@ -145,7 +163,9 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
# g1 = Gauss(mean=1.0, width=1.0)
# gx = Gauss(mean=jnp.exp(parameter.value), width=width) # type: ignore[arg-type]
# return gx.inv_cdf(g1.cdf(parameter.value + 1))
return jnp.exp(parameter.value * self.scale(parameter=parameter))
return jnp.exp(
parameter.value * self.scale(parameter=parameter)
) # fast analytical solution


class poisson(Effect):
Expand All @@ -156,20 +176,20 @@ def __init__(self, lamb: jax.Array) -> None:

@property
def constraint(self) -> HashablePDF:
return Gauss(mean=0.0, width=1.0)
return Poisson(lamb=self.lamb)

def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
gauss_cdf = jnp.broadcast_to(
self.constraint.cdf(parameter.value), self.lamb.shape
)
return Poisson(self.lamb).inv_cdf(gauss_cdf) / sumw # type: ignore[arg-type]
return parameter.value + 1


class ModifierBase(eqx.Module):
    """
    Abstract base for modifiers.

    Subclasses provide ``scale_factor``; calling the modifier multiplies the
    given ``sumw`` by that factor.
    """

    @abc.abstractmethod
    def scale_factor(self, sumw: jax.Array) -> jax.Array:
        """Return the multiplicative factor(s) to be applied to ``sumw``."""

    def __call__(self, sumw: jax.Array) -> jax.Array:
        """Return ``sumw`` multiplied by ``scale_factor(sumw)`` (promoted to at least 1-d)."""
        factor = self.scale_factor(sumw=sumw)
        return jnp.atleast_1d(factor) * sumw


class modifier(ModifierBase):
"""
Expand Down Expand Up @@ -223,8 +243,102 @@ def __init__(
def scale_factor(self, sumw: jax.Array) -> jax.Array:
    """Delegate to the wrapped effect: evaluate ``self.effect.scale_factor`` for ``self.parameter`` and ``sumw``."""
    return self.effect.scale_factor(parameter=self.parameter, sumw=sumw)

def __call__(self, sumw: jax.Array) -> jax.Array:
return jnp.atleast_1d(self.scale_factor(sumw=sumw)) * sumw

class staterror(ModifierBase):
    """
    Create a staterror (barlow-beeston) modifier which acts on each bin with a
    different _underlying_ modifier.

    Each bin is driven by its own parameter. Bins whose ``sumw`` lies below
    ``threshold`` use a ``poisson`` effect (built from the bin's ``sumw``); all
    other bins use a ``gauss`` effect whose width is ``sqrt(sumw2) / sumw``.
    Constructing the modifier adds the constraint of the chosen effect to the
    ``constraints`` of the corresponding parameter.

    Example:

    .. code-block:: python

        import equinox as eqx
        import jax.numpy as jnp

        from dilax.parameter import Parameter, staterror

        hist = jnp.array([10, 20, 30])

        p1 = Parameter(value=1.0)
        p2 = Parameter(value=0.0)
        p3 = Parameter(value=0.0)

        # all bins with bin content below 10 (threshold) are treated as poisson, else gauss
        modify = staterror(
            parameters=[p1, p2, p3],
            sumw=hist,
            sumw2=hist,
            threshold=10.0,
        )
        modify(hist)
        # -> Array([13.162277, 20.      , 30.      ], dtype=float32)

        fast_modify = eqx.filter_jit(modify)
    """

    name: str = "staterror"
    parameters: list[Parameter]
    sumw: jax.Array
    sumw2: jax.Array
    sumw2sqrt: jax.Array
    widths: jax.Array
    mask: jax.Array
    threshold: float

    def __init__(
        self,
        parameters: list[Parameter],
        sumw: jax.Array,
        sumw2: jax.Array,
        threshold: float,
    ) -> None:
        """Store the per-bin inputs and register one constraint per parameter.

        ``parameters``, ``sumw`` and ``sumw2`` must all have the same length
        (one entry per bin).
        """
        n_bins = len(parameters)
        assert n_bins == len(sumw2) == len(sumw)

        # derived per-bin quantities
        sumw2sqrt = jnp.sqrt(sumw2)
        widths = sumw2sqrt / sumw
        # True where the bin content is below the threshold (-> poisson)
        below_threshold = sumw < threshold

        self.parameters = parameters
        self.sumw = sumw
        self.sumw2 = sumw2
        self.sumw2sqrt = sumw2sqrt
        self.threshold = threshold
        self.widths = widths
        self.mask = below_threshold

        # attach the constraint of the per-bin effect to its parameter
        for i, param in enumerate(parameters):
            if below_threshold[i]:
                effect = poisson(sumw[i])
            else:
                effect = gauss(widths[i])
            param.constraints.add(effect.constraint)

    def scale_factor(self, sumw: jax.Array) -> jax.Array:
        """Return one scale factor per bin: poisson where ``mask`` is set, gauss elsewhere.

        ``sumw`` must have one entry per parameter.
        """
        n_bins = len(sumw)
        assert n_bins == len(self.parameters) == len(self.sumw2)

        values = jnp.concatenate([p.value for p in self.parameters])
        bin_index = jnp.arange(n_bins)

        # argument handed to the effect constructor:
        # sumw where mask (poisson) else widths (gauss)
        effect_args = jnp.where(self.mask, self.sumw, self.widths)

        def _per_bin(effect_cls):
            # build the single-bin function that is vmapped over all bins
            def _apply(value: jax.Array, arg: jax.Array, idx: jax.Array) -> jax.Array:
                parameter = Parameter(value=value)
                factor = effect_cls(arg).scale_factor(
                    parameter=parameter,
                    sumw=sumw[idx],
                )
                return factor[0]

            return _apply

        # both variants are evaluated for every bin; the mask selects the result
        poisson_sf = jax.vmap(_per_bin(poisson))(values, effect_args, bin_index)
        gauss_sf = jax.vmap(_per_bin(gauss))(values, effect_args, bin_index)
        return jnp.where(self.mask, poisson_sf, gauss_sf)


class compose(ModifierBase):
Expand Down Expand Up @@ -299,6 +413,3 @@ def scale_factor(self, sumw: jax.Array) -> jax.Array:
sfs = jnp.stack(list(self.scale_factors(sumw=sumw).values()))
# calculate the product in log-space for numerical precision
return jnp.exp(jnp.sum(jnp.log(sfs), axis=0))

def __call__(self, sumw: jax.Array) -> jax.Array:
return jnp.atleast_1d(self.scale_factor(sumw=sumw)) * sumw
12 changes: 6 additions & 6 deletions src/dilax/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,24 +76,24 @@ def inv_cdf(self, x: jax.Array) -> jax.Array:


class Poisson(HashablePDF):
lamb: int = eqx.field(static=True)
lamb: jax.Array = eqx.field(static=True)

def __init__(self, lamb: int) -> None:
def __init__(self, lamb: jax.Array) -> None:
self.lamb = lamb

def __hash__(self):
return hash(self.__class__) ^ hash(self.lamb)
return hash(self.__class__) ^ hash(str(self.lamb)) # is this a safe hash??

def logpdf(self, x: jax.Array) -> jax.Array:
logpdf_max = jax.scipy.stats.poisson.logpmf(self.lamb, mu=self.lamb)
unnormalized = jax.scipy.stats.poisson.logpmf(x, mu=self.lamb)
unnormalized = jax.scipy.stats.poisson.logpmf((x + 1) * self.lamb, mu=self.lamb)
return unnormalized - logpdf_max

def pdf(self, x: jax.Array) -> jax.Array:
return jax.scipy.stats.poisson.pmf(x, mu=self.lamb)
return jax.scipy.stats.poisson.pmf((x + 1) * self.lamb, mu=self.lamb)

def cdf(self, x: jax.Array) -> jax.Array:
return jax.scipy.stats.poisson.cdf(x, mu=self.lamb)
return jax.scipy.stats.poisson.cdf((x + 1) * self.lamb, mu=self.lamb)

def inv_cdf(self, x: jax.Array) -> jax.Array:
# see: https://num.pyro.ai/en/stable/tutorials/truncated_distributions.html?highlight=poisson%20inverse#5.3-Example:-Left-truncated-Poisson
Expand Down
12 changes: 12 additions & 0 deletions src/dilax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@
import jax
import jax.numpy as jnp

# Public names of this module; `__dir__` below returns this list.
__all__ = [
    "HistDB",
    "FrozenDB",
    "as1darray",
    "dump_hlo_graph",
    "dump_jaxpr",
]


def __dir__() -> list[str]:
    """Return ``__all__`` so that ``dir()`` on this module lists only the public names."""
    return __all__


class Sentinel:
__slots__ = ("repr",)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
poisson,
unconstrained,
)
from dilax.pdf import Flat, Gauss
from dilax.pdf import Flat, Gauss, Poisson


def test_parameter():
Expand Down Expand Up @@ -61,7 +61,7 @@ def test_poisson():
# p = Parameter(value=jnp.array(0.0))
po = poisson(lamb=jnp.array(10))

assert po.constraint == Gauss(mean=0.0, width=1.0)
assert po.constraint == Poisson(lamb=jnp.array(10))
# assert po.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) # FIXME
# assert po.scale_factor(p.update(jnp.array(2.0)), jnp.array(1.0)) == pytest.approx(1.1) # FIXME

Expand Down
14 changes: 7 additions & 7 deletions tests/test_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ def test_gauss():


def test_poisson():
pdf = Poisson(lamb=10)
pdf = Poisson(lamb=jnp.array(10))

assert pdf.pdf(jnp.array(10)) == pytest.approx(0.12510978)
assert pdf.logpdf(jnp.array(5)) == pytest.approx(-1.196003)
assert pdf.cdf(jnp.array(10)) == pytest.approx(0.5830412)
assert pdf.inv_cdf(jnp.array(0.5830412)) == pytest.approx(10)
assert pdf.inv_cdf(pdf.cdf(jnp.array(10))) == pytest.approx(10)
assert pdf.pdf(jnp.array(0)) == pytest.approx(0.12510978)
assert pdf.logpdf(jnp.array(-0.5)) == pytest.approx(-1.196003)
assert pdf.cdf(jnp.array(0)) == pytest.approx(0.5830412)
# assert pdf.inv_cdf(jnp.array(0.5830412)) == pytest.approx(10)
# assert pdf.inv_cdf(pdf.cdf(jnp.array(10))) == pytest.approx(10)


def test_hashable():
assert hash(Flat()) == hash(Flat())
assert hash(Gauss(mean=0.0, width=1.0)) == hash(Gauss(mean=0.0, width=1.0))
assert hash(Poisson(lamb=10)) == hash(Poisson(lamb=10))
assert hash(Poisson(lamb=jnp.array(10))) == hash(Poisson(lamb=jnp.array(10)))

0 comments on commit 06378ed

Please sign in to comment.