Skip to content

Commit

Permalink
add sentinels
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Sep 29, 2023
1 parent 4225673 commit bfe8cb0
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ select = [
extend-ignore = [
"PLR", # Design related pylint codes
"E501", # Line too long
"B006", # converts default args to 'None'
# "B006", # converts default args to 'None'
]

src = ["src"]
Expand Down
27 changes: 22 additions & 5 deletions src/dilax/likelihood.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations

from typing import cast

import equinox as eqx
import jax
import jax.numpy as jnp

from dilax.model import Model
from dilax.util import Sentinel

_MISSING = Sentinel("<MISSING>")


class BaseModule(eqx.Module):
Expand Down Expand Up @@ -62,7 +67,9 @@ def evaluate(self) -> EvaluationResult:
def logpdf(self, *args, **kwargs) -> jax.Array:
return jax.scipy.stats.poisson.logpmf(*args, **kwargs)

def __call__(self, values: dict[str, jax.Array] = {}) -> jax.Array:
def __call__(self, values: dict[str, jax.Array] | Sentinel = _MISSING) -> jax.Array:
if values is _MISSING:
values = {}
model = self.model.update(values=values)
res = model.evaluate()
nll = (
Expand Down Expand Up @@ -119,9 +126,12 @@ def __init__(self, model: Model, observation: jax.Array) -> None:
super().__init__(model=model, observation=observation)
self.NLL = NLL(model=model, observation=observation)

def __call__(self, values: dict[str, jax.Array] = {}) -> jax.Array:
def __call__(self, values: dict[str, jax.Array] | Sentinel = _MISSING) -> jax.Array:
if values is _MISSING:
values = {}
if not values:
values = self.model.parameter_values
values = cast(dict[str, jax.Array], values)
hessian = jax.hessian(self.NLL, argnums=0)(values)
hessian, _ = jax.tree_util.tree_flatten(hessian)
hessian = jnp.array(hessian)
Expand Down Expand Up @@ -168,7 +178,9 @@ def evaluate(self) -> EvaluationResult:
```
"""

def __call__(self, values: dict[str, jax.Array] = {}) -> jax.Array:
def __call__(self, values: dict[str, jax.Array] | Sentinel = _MISSING) -> jax.Array:
if values is _MISSING:
values = {}
hessian = super().__call__(values=values)
return jnp.linalg.inv(-hessian)

Expand Down Expand Up @@ -227,10 +239,15 @@ def __init__(self, model: Model, observation: jax.Array) -> None:
self.CovMatrix = CovMatrix(model=model, observation=observation)

def __call__(
self, values: dict[str, jax.Array] = {}, key: jax.Array | None = None
self,
values: dict[str, jax.Array] | Sentinel = _MISSING,
key: jax.Array | Sentinel = _MISSING,
) -> dict[str, jax.Array]:
if key is None:
if values is _MISSING:
values = {}
if key is _MISSING:
key = jax.random.PRNGKey(1234)
key = cast(jax.Array, key)
if not values:
values = self.model.parameter_values
cov = self.CovMatrix(values=values)
Expand Down
20 changes: 16 additions & 4 deletions src/dilax/model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from __future__ import annotations

import abc
from typing import cast

import equinox as eqx
import jax
import jax.numpy as jnp

from dilax.parameter import Parameter
from dilax.util import FrozenDB, HistDB
from dilax.util import FrozenDB, HistDB, Sentinel

_MISSING = Sentinel("<MISSING>")


class Result(eqx.Module):
Expand All @@ -18,8 +21,8 @@ class Result(eqx.Module):

expectations: dict[str, jax.Array]

def __init__(self, expectations: dict[str, jax.Array] = {}) -> None:
self.expectations = expectations
def __init__(self) -> None:
self.expectations = {}

def add(self, process: str, expectation: jax.Array) -> Result:
self.expectations[process] = expectation
Expand Down Expand Up @@ -104,11 +107,20 @@ def parameter_constraints(self) -> jax.Array:
return jnp.sum(jnp.array(c))

def update(
self, processes: dict | HistDB = {}, values: dict[str, jax.Array] = {}
self,
processes: dict | HistDB | Sentinel = _MISSING,
values: dict[str, jax.Array] | Sentinel = _MISSING,
) -> Model:
if values is _MISSING:
values = {}
if processes is _MISSING:
processes = {}
if not isinstance(processes, HistDB):
processes = HistDB(processes)

processes = cast(HistDB, processes)
values = cast(dict[str, jax.Array], values)

def _patch_processes(processes: HistDB) -> HistDB:
assert isinstance(processes, HistDB)
new_processes = dict(self.processes.items())
Expand Down
6 changes: 3 additions & 3 deletions src/dilax/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def constraint(self) -> HashablePDF:
return Gauss(mean=0.0, width=1.0)

def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
gx = Gauss(mean=1.0, width=self.width)
gx = Gauss(mean=1.0, width=self.width) # type: ignore[arg-type]
g1 = Gauss(mean=1.0, width=1.0)
return gx.inv_cdf(g1.cdf(parameter.value + 1))

Expand Down Expand Up @@ -143,7 +143,7 @@ def constraint(self) -> HashablePDF:
def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
width = self.scale(parameter=parameter)
g1 = Gauss(mean=1.0, width=1.0)
gx = Gauss(mean=1.0, width=width)
gx = Gauss(mean=1.0, width=width) # type: ignore[arg-type]
return g1.inv_cdf(gx.cdf(jnp.exp(parameter.value)))


Expand All @@ -161,7 +161,7 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
gauss_cdf = jnp.broadcast_to(
self.constraint.cdf(parameter.value), self.lamb.shape
)
return Poisson(self.lamb).inv_cdf(gauss_cdf)
return Poisson(self.lamb).inv_cdf(gauss_cdf) # type: ignore[arg-type]


class ModifierBase(eqx.Module):
Expand Down
10 changes: 5 additions & 5 deletions src/dilax/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def inv_cdf(self, x: jax.Array) -> jax.Array:


class Gauss(HashablePDF):
mean: float | jax.Array = eqx.field(static=True)
width: float | jax.Array = eqx.field(static=True)
mean: float = eqx.field(static=True)
width: float = eqx.field(static=True)

def __init__(self, mean: float | jax.Array, width: float | jax.Array) -> None:
def __init__(self, mean: float, width: float) -> None:
self.mean = mean
self.width = width

Expand All @@ -76,9 +76,9 @@ def inv_cdf(self, x: jax.Array) -> jax.Array:


class Poisson(HashablePDF):
lamb: float | jax.Array = eqx.field(static=True)
lamb: int = eqx.field(static=True)

def __init__(self, lamb: float | jax.Array) -> None:
def __init__(self, lamb: int) -> None:
self.lamb = lamb

def __hash__(self):
Expand Down
47 changes: 42 additions & 5 deletions src/dilax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,42 @@
import collections
import pprint
from collections.abc import Hashable, Iterable, Mapping
from typing import Any, Callable, TypeVar
from typing import TYPE_CHECKING, Any, Callable, ClassVar, TypeVar, cast

import jax
import jax.numpy as jnp


class Sentinel:
__slots__ = ("repr",)

_instances: ClassVar[dict[tuple[type[Sentinel], str], Sentinel]] = {}

if TYPE_CHECKING:
repr: str

def __new__(cls, repr: str) -> Sentinel:
key = (cls, repr)
if key in cls._instances:
sentinel = cls._instances[key]
else:
sentinel = super().__new__(cls)
sentinel.repr = repr
cls._instances[key] = sentinel
return sentinel

def __repr__(self) -> str:
return self.repr

def __reduce__(self) -> tuple[type[Sentinel], tuple[str]]:
return (self.__class__, (self.repr,))

__str__ = __repr__


_MISSING = Sentinel("<MISSING>")


class FrozenKeysView(collections.abc.KeysView):
"""FrozenKeysView that does not print values when repr'ing."""

Expand Down Expand Up @@ -126,6 +156,9 @@ def foo(hists):

__slots__ = ("_dict",)

if TYPE_CHECKING:
_dict: dict[frozenset, Any]

@staticmethod
def keyify(keyish: Any) -> frozenset:
if not isinstance(keyish, (tuple, list, set, frozenset)):
Expand All @@ -136,14 +169,18 @@ def keyify(keyish: Any) -> frozenset:
return keyish

def __init__(
self, xs: dict | FrozenDB = {}, __unsafe_skip_copy__: bool = False
self,
xs: Mapping | Sentinel = _MISSING,
__unsafe_skip_copy__: bool = False,
) -> None:
# make sure the dict is as
xs = dict(xs)
if xs is _MISSING:
xs = {}
data = dict(cast(Mapping, xs))
if __unsafe_skip_copy__:
self._dict = xs
self._dict = data
else:
self._dict = _prepare_freeze(xs)
self._dict = _prepare_freeze(data)

def __getitem__(self, key) -> Any:
key = self.keyify(key)
Expand Down

0 comments on commit bfe8cb0

Please sign in to comment.