
Commit

fix doc strings
pfackeldey committed Sep 30, 2023
1 parent 796b910 commit 47cf87d
Showing 5 changed files with 55 additions and 187 deletions.
148 changes: 2 additions & 146 deletions src/dilax/likelihood.py
@@ -26,40 +26,6 @@ def __init__(self, model: Model, observation: jax.Array) -> None:
class NLL(BaseModule):
"""
Negative log-likelihood (NLL).
Example:
```
from dilax.model import Model
from dilax.parameter import r, lnN
class MyModel(Model):
def evaluate(self) -> EvaluationResult:
expectations = {}
# signal
signal, mu_penalty = self.parameters["mu"](self.processes["signal"], type="r")
expectations["signal"] = signal
# background
background, sigma_penalty = self.parameters["sigma"](self.processes["background"], type="lnN", width=1.1)
expectations["background"] = background
return EvaluationResult(expectations=expectations, penalty=mu_penalty + sigma_penalty)
model = MyModel(
processes={"signal": jnp.array([10]), "background": jnp.array([50])},
parameters={"mu": Parameter(value=1.0, bounds=(0, 100)), "sigma": Parameter(value=0, bounds=(-100, 100))},
)
observation = jnp.array([60])
nll = NLL(model=model, observation=observation)
# evaluate the negative log likelihood
%timeit nll(values={"mu": jnp.array([1.1]), "sigma": jnp.array([0.8])})
>> 2.03 ms ± 150 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# evaluate the negative log likelihood *fast*
%timeit eqx.filter_jit(nll)(values={"mu": jnp.array([1.1]), "sigma": jnp.array([0.8])}).block_until_ready()
>> 274 µs ± 3.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
"""

def logpdf(self, *args, **kwargs) -> jax.Array:
@@ -81,41 +47,7 @@ def __call__(self, values: dict[str, jax.Array] | Sentinel = _NoValue) -> jax.Array:

class Hessian(BaseModule):
"""
Covariance Matrix.
Example:
```
from dilax.model import Model
from dilax.parameter import r, lnN
class MyModel(Model):
def evaluate(self) -> EvaluationResult:
expectations = {}
# signal
signal, mu_penalty = self.parameters["mu"](self.processes["signal"], type="r")
expectations["signal"] = signal
# background
background, sigma_penalty = self.parameters["sigma"](self.processes["background"], type="lnN", width=1.1)
expectations["background"] = background
return EvaluationResult(expectations=expectations, penalty=mu_penalty + sigma_penalty)
model = MyModel(
processes={"signal": jnp.array([10]), "background": jnp.array([50])},
parameters={"mu": Parameter(value=1.0, bounds=(0, 100)), "sigma": Parameter(value=0, bounds=(-100, 100))},
)
observation = jnp.array([60])
hessian = Hessian(model=model, observation=observation)
# evaluate the negative log likelihood
%timeit hessian(values={"mu": jnp.array([1.1]), "sigma": jnp.array([0.8])})
>> 18.8 ms ± 470 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# evaluate the negative log likelihood *fast*
%timeit eqx.filter_jit(hessian)(values={"mu": jnp.array([1.1]), "sigma": jnp.array([0.8])}).block_until_ready()
>> 325 µs ± 2.72 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
Hessian matrix.
"""

NLL: NLL
@@ -140,41 +72,7 @@ def __call__(self, values: dict[str, jax.Array] | Sentinel = _NoValue) -> jax.Array:

class CovMatrix(Hessian):
"""
Covariance Matrix.
Example:
```
from dilax.model import Model
from dilax.parameter import r, lnN
class MyModel(Model):
def evaluate(self) -> EvaluationResult:
expectations = {}
# signal
signal, mu_penalty = self.parameters["mu"](self.processes["signal"], type="r")
expectations["signal"] = signal
# background
background, sigma_penalty = self.parameters["sigma"](self.processes["background"], type="lnN", width=1.1)
expectations["background"] = background
return EvaluationResult(expectations=expectations, penalty=mu_penalty + sigma_penalty)
model = MyModel(
processes={"signal": jnp.array([10]), "background": jnp.array([50])},
parameters={"mu": Parameter(value=1.0, bounds=(0, 100)), "sigma": Parameter(value=0, bounds=(-100, 100))},
)
observation = jnp.array([60])
covmatrix = CovMatrix(model=model, observation=observation)
# evaluate the negative log likelihood
%timeit covmatrix(values={"mu": jnp.array([1.1]), "sigma": jnp.array([0.8])})
>> 19 ms ± 504 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# evaluate the negative log likelihood *fast*
%timeit eqx.filter_jit(covmatrix)(values={"mu": jnp.array([1.1]), "sigma": jnp.array([0.8])}).block_until_ready()
>> 327 µs ± 1.78 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
Covariance matrix.
"""

def __call__(self, values: dict[str, jax.Array] | Sentinel = _NoValue) -> jax.Array:
@@ -187,48 +85,6 @@ def __call__(self, values: dict[str, jax.Array] | Sentinel = _NoValue) -> jax.Array:
class SampleToy(BaseModule):
"""
Sample a toy from the model.
Example:
```
from dilax.model import Model
from dilax.parameter import r, lnN
class MyModel(Model):
def evaluate(self) -> EvaluationResult:
expectations = {}
# signal
signal, mu_penalty = self.parameters["mu"](self.processes["signal"], type="r")
expectations["signal"] = signal
# background
background, sigma_penalty = self.parameters["sigma"](self.processes["background"], type="lnN", width=1.1)
expectations["background"] = background
return EvaluationResult(expectations=expectations, penalty=mu_penalty + sigma_penalty)
model = MyModel(
processes={"signal": jnp.array([10]), "background": jnp.array([50])},
parameters={"mu": Parameter(value=1.0, bounds=(0, 100)), "sigma": Parameter(value=0, bounds=(-100, 100))},
)
observation = jnp.array([60])
sample_toy = SampleToy(model=model, observation=observation)
values = {"mu": jnp.array([1.1]), "sigma": jnp.array([0.8])}
# sample a single toy
toy = sample_toy(values=values, key=jax.random.PRNGKey(1234))
# sample a single toy *fast*
toy = eqx.filter_jit(sample_toy)(values=values, key=jax.random.PRNGKey(1234))
# sample 10 toys
keys = jax.random.split(jax.random.PRNGKey(1234), num=10)
toys = eqx.filter_vmap(in_axes=(None, 0))(eqx.filter_jit(sample_toy))(values, keys)
# new model from toy
toy = eqx.filter_jit(sample_toy)(values=values, key=jax.random.PRNGKey(1234))
new_model = model.update(processes=toy)
```
"""

CovMatrix: CovMatrix
56 changes: 33 additions & 23 deletions src/dilax/model.py
@@ -12,11 +12,6 @@


class Result(eqx.Module):
"""
Holds:
dict[str, jax.Array]: The expected number of events in each bin for each process.
"""

expectations: dict[str, jax.Array]

def __init__(self) -> None:
@@ -39,45 +34,60 @@ class Model(eqx.Module):
It is required to implement the `evaluate` method, which returns an `EvaluationResult` object.
Example:
```
# Simple model with two processes and two parameters
.. code-block:: python
import jax
import jax.numpy as jnp
import equinox as eqx
from dilax.model import Model, Result
from dilax.parameter import Parameter, lnN, modifier, unconstrained
from dilax.util import HistDB
# Define a simple model with two processes and two parameters
class MyModel(Model):
def evaluate(self) -> EvaluationResult:
expectations = {}
def __call__(self, processes: HistDB, parameters: dict[str, Parameter]) -> Result:
res = Result()
# signal
signal, mu_penalty = self.parameters["mu"](self.processes["signal"], type="r")
expectations["signal"] = signal
mu_mod = modifier(name="mu", parameter=parameters["mu"], effect=unconstrained())
res.add(process="signal", expectation=mu_mod(self.processes["signal"]))
# background
background, sigma_penalty = self.parameters["sigma"](self.processes["background"], type="lnN", width=1.1)
expectations["background"] = background
return EvaluationResult(expectations=expectations, penalty=mu_penalty + sigma_penalty)
bkg_mod = modifier(name="sigma", parameter=parameters["sigma"], effect=lnN(1.1))
res.add(process="background", expectation=bkg_mod(self.processes["background"]))
return res
model = MyModel(
processes={"signal": jnp.array([10]), "background": jnp.array([50])},
parameters={"mu": Parameter(value=1.0, bounds=(0, 100)), "sigma": Parameter(value=0, bounds=(-100, 100))},
)
# Setup model
processes = HistDB({"signal": jnp.array([10]), "background": jnp.array([50])})
parameters = {
"mu": Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
"sigma": Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)
# evaluate the expectation
model.evaluate().expectation()
>> Array([60.], dtype=float32, weak_type=True)
# -> Array([60.], dtype=float32)
%timeit model.evaluate().expectation()
>> 245 µs ± 1.17 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# -> 3.05 ms ± 29.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# evaluate the expectation *fast*
@eqx.filter_jit
def eval(model) -> jax.Array:
res = model.evaluate()
return res.expectation()
eqx.filter_jit(eval)(model)
>> Array([60.], dtype=float32, weak_type=True)
# -> Array([60.], dtype=float32)
%timeit eqx.filter_jit(eval)(model).block_until_ready()
>> 96.9 µs ± 778 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
# -> 114 µs ± 327 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
"""

processes: HistDB
10 changes: 6 additions & 4 deletions src/dilax/optimizer.py
@@ -14,12 +14,13 @@ class JaxOptimizer(eqx.Module):
This allows passing the optimizer as a parameter to a `jax.jit`-compiled function, and setting up the optimizer therein.
Example:
```
.. code-block:: python
optimizer = JaxOptimizer.make(name="GradientDescent", settings={"maxiter": 5})
# or, e.g.: optimizer = JaxOptimizer.make(name="LBFGS", settings={"maxiter": 10})
optimizer.fit(fun=nll, init_values=init_values)
```
"""

name: str
@@ -56,14 +57,15 @@ class Chain(eqx.Module):
in order to have a deterministic runtime behaviour.
Example:
```
.. code-block:: python
opt1 = JaxOptimizer.make(name="GradientDescent", settings={"maxiter": 5})
opt2 = JaxOptimizer.make(name="LBFGS", settings={"maxiter": 10})
chain = Chain(opt1, opt2)
# first 5 steps are minimized with GradientDescent, then 10 steps with LBFGS
chain.fit(fun=nll, init_values=init_values)
```
"""

optimizers: tuple[JaxOptimizer, ...]
14 changes: 7 additions & 7 deletions src/dilax/parameter.py
@@ -175,7 +175,9 @@ class modifier(ModifierBase):
Create a new modifier for a given parameter and penalty.
Example:
```
.. code-block:: python
from dilax.parameter import modifier, Parameter, unconstrained
mu = Parameter(value=1.1, bounds=(0, 100))
@@ -185,9 +187,7 @@ class modifier(ModifierBase):
# apply the modifier
modify(jnp.array([10, 20, 30]))
>> ('mu',
Array([11., 22., 33.], dtype=float32, weak_type=True),
Array([0.], dtype=float32))
# -> Array([11., 22., 33.], dtype=float32, weak_type=True),
# lnN effect
norm = Parameter(value=0.0, bounds=(-jnp.inf, jnp.inf))
Expand All @@ -204,7 +204,6 @@ class modifier(ModifierBase):
down = jnp.array([8, 19, 26])
modify = modifier(name="norm", parameter=norm, effect=shape(up, down))
modify(jnp.array([10, 20, 30]))
```
"""

name: str
@@ -232,7 +231,9 @@ class compose(ModifierBase):
It behaves like a single modifier, but it is composed of multiple modifiers; it can be arbitrarily nested.
Example:
```
.. code-block:: python
from dilax.parameter import modifier, compose, Parameter, FreeFloating, LogNormal
mu = Parameter(value=1.1, bounds=(0, 100))
Expand All @@ -255,7 +256,6 @@ class compose(ModifierBase):
# jit
eqx.filter_jit(composition)(jnp.array([10, 20, 30]))
```
"""

modifiers: tuple[modifier, ...]
14 changes: 7 additions & 7 deletions src/dilax/util.py
@@ -90,7 +90,9 @@ class FrozenDB(Mapping[K, V]):
"""An immutable database-like custom dict.
Example:
```
.. code-block:: python
hists = HistDB(
{
# QCD
@@ -105,7 +107,7 @@ class FrozenDB(Mapping[K, V]):
)
print(hists)
# >> HistDB({
# -> HistDB({
# ('QCD', 'nominal'): Array([1, 1, 1, 1, 1], dtype=int32),
# ('QCD', 'Up', 'JES'): Array([1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32),
# ('QCD', 'Down', 'JES'): Array([0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32),
@@ -115,14 +117,14 @@
# })
print(hists["QCD"])
# >> HistDB({
# -> HistDB({
# 'nominal': Array([1, 1, 1, 1, 1], dtype=int32),
# ('Up', 'JES'): Array([1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32),
# ('Down', 'JES'): Array([0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32),
# })
print(hists["JES"])
# >> HistDB({
# -> HistDB({
# ('QCD', 'Up'): Array([1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32),
# ('QCD', 'Down'): Array([0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32),
# ('DY', 'Up'): Array([2.5, 2.5, 2.5, 2.5, 2.5], dtype=float32),
@@ -134,9 +136,7 @@ def foo(hists):
return (hists["QCD", "nominal"] + 1.2) ** 2
print(jax.jit(foo)(hists))
# >> Array([4.84, 4.84, 4.84, 4.84, 4.84], dtype=float32, weak_type=True)
```
# -> Array([4.84, 4.84, 4.84, 4.84, 4.84], dtype=float32, weak_type=True)
"""

__slots__ = ("_dict",)
