Skip to content

Commit

Permalink
feat: add intermediate Model and serialization updates
Browse files Browse the repository at this point in the history
Big commit which has BREAKING CHANGES! This adds an intermediate struct called `Model` which has `serde` implementations. `Model`s can be saved and loaded and completely contain everything needed to make an `Evaluator` or `NLL` except for the `Dataset`.

To do this, `Amplitude`s need to be (de)serializable. This is a bit tricky, but the `typetag` crate does the trick and it isn't too much of a burden to throw `#[typetag::serde]` on `Amplitude` impls.
  • Loading branch information
denehoffman committed Dec 6, 2024
1 parent 3c27f69 commit 561753b
Show file tree
Hide file tree
Showing 28 changed files with 682 additions and 304 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ name = "laddu"
crate-type = ["cdylib", "rlib"]

[dependencies]
indexmap = "2.6.0"
indexmap = { version = "2.6.0", features = ["serde"] }
num = "0.4.3"
nalgebra = "0.33.2"
arrow = "53.3.0"
Expand All @@ -40,6 +40,8 @@ thiserror = "2.0.3"
shellexpand = "3.1.0"
accurate = "0.4.1"
serde = "1.0.215"
serde_with = "3.11.0"
typetag = "0.2.18"
serde-pickle = "1.2.0"
bincode = "1.3.3"

Expand Down
23 changes: 16 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,15 @@ is the relativistic width correction, $`q(m_a, m_b, m_c)`$ is the breakup moment
Although this particular amplitude is already included in `laddu`, let's assume it isn't and imagine how we would write it from scratch:

```rust
use laddu::prelude::*;
use laddu::{
ParameterLike, Event, Cache, Resources, Mass,
ParameterID, Parameters, Float, LadduError, PI, AmplitudeID, Complex,
};
use laddu::traits::*;
use laddu::utils::functions::{blatt_weisskopf, breakup_momentum};
use laddu::{Deserialize, Serialize};

#[derive(Clone)]
#[derive(Clone, Serialize, Deserialize)]
pub struct MyBreitWigner {
name: String,
mass: ParameterLike,
Expand Down Expand Up @@ -124,6 +129,7 @@ impl MyBreitWigner {
}
}

#[typetag::serde]
impl Amplitude for MyBreitWigner {
fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
self.pid_mass = resources.register_parameter(&self.mass);
Expand Down Expand Up @@ -151,6 +157,7 @@ impl Amplitude for MyBreitWigner {
### Calculating a Likelihood
We could then write some code to use this amplitude. For demonstration purposes, let's just calculate an extended unbinned negative log-likelihood, assuming we have some data and Monte Carlo in the proper [parquet format](#data-format):
```rust
use laddu::{Scalar, Mass, Manager, NLL, parameter, open};
let ds_data = open("test_data/data.parquet").unwrap();
let ds_mc = open("test_data/mc.parquet").unwrap();

Expand All @@ -168,11 +175,12 @@ let bw = manager.register(MyBreitWigner::new(
&resonance_mass,
)).unwrap();
let mag = manager.register(Scalar::new("mag", parameter("magnitude"))).unwrap();
let model = (mag * bw).norm_sqr();
let expr = (mag * bw).norm_sqr();
let model = manager.model(&expr);

let nll = NLL::new(&manager, &ds_data, &ds_mc);
let nll = NLL::new(&model, &ds_data, &ds_mc);
println!("Parameters names and order: {:?}", nll.parameters());
let result = nll.evaluate(&[1.27, 0.120, 100.0], &model);
let result = nll.evaluate(&[1.27, 0.120, 100.0]);
println!("The extended negative log-likelihood is {}", result);
```
In practice, amplitudes can also be added together, their real and imaginary parts can be taken, and evaluators should mostly take the real part of whatever complex value comes out of the model.
Expand Down Expand Up @@ -203,9 +211,10 @@ def main():
pos_im = (s0p * z00p.imag() + d2p * z22p.imag()).norm_sqr()
neg_re = (s0n * z00n.real()).norm_sqr()
neg_im = (s0n * z00n.imag()).norm_sqr()
model = pos_re + pos_im + neg_re + neg_im
expr = pos_re + pos_im + neg_re + neg_im
model = manager.model(expr)

nll = ld.NLL(manager, model, ds_data, ds_mc)
nll = ld.NLL(model, ds_data, ds_mc)
status = nll.minimize([1.0] * len(nll.parameters))
print(status)
fit_weights = nll.project(status.x)
Expand Down
5 changes: 3 additions & 2 deletions benches/kmatrix_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,9 @@ fn kmatrix_nll_benchmark(c: &mut Criterion) {
let pos_im = (&s0p * z00p.imag() + &d2p * z22p.imag()).norm_sqr();
let neg_re = (&s0n * z00n.real()).norm_sqr();
let neg_im = (&s0n * z00n.imag()).norm_sqr();
let model = pos_re + pos_im + neg_re + neg_im;
let nll = NLL::new(&manager, &model, &ds_data, &ds_mc);
let expr = pos_re + pos_im + neg_re + neg_im;
let model = manager.model(&expr);
let nll = NLL::new(&model, &ds_data, &ds_mc);
let mut group = c.benchmark_group("K-Matrix NLL Performance");
for threads in 1..=num_cpus::get() {
let pool = ThreadPoolBuilder::new()
Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorials/unbinned_fit.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,13 @@ Next, we combine these together according to our model. For these amplitude poin
positive_real_sum = (f0_1500 * bw0 * z00p.real() + f2_1525 * bw2 * z22p.real()).norm_sqr()
positive_imag_sum = (f0_1500 * bw0 * z00p.imag() + f2_1525 * bw2 * z22p.imag()).norm_sqr()
model = positive_real_sum + positive_imag_sum
model = manager.model(positive_real_sum + positive_imag_sum)
Now that we have the model, we want to fit the free parameters, which in this case are the complex photocouplings and the widths of each Breit-Wigner. We can do this by creating an ``NLL`` object which uses the data and accepted Monte-Carlo datasets to calculate the negative log-likelihood described earlier.

.. code:: python
nll = ld.NLL(manager, model, data_ds, accmc_ds)
nll = ld.NLL(model, data_ds, accmc_ds)
print(nll.parameters)
# ['Re[f_0(1500)]', "Re[f_2'(1525)]", "Im[f_2'(1525)]", 'f_0 width', 'f_2 width']
Expand Down Expand Up @@ -265,7 +265,7 @@ To create an ``Evaluator`` object, we just need to load up the manager with the

.. code:: python
gen_eval = manager.load(model, genmc_ds)
gen_eval = model.load(genmc_ds)
tot_weights_acc = nll.project(status.x, mc_evaluator=gen_eval)
f0_weights_acc = nll.project_with(status.x, ["[f_0(1500)]", "BW_0", "Z00+"], mc_evaluator=gen_eval)
f2_weights_acc = nll.project_with(status.x, ["[f_2'(1525)]", "BW_2", "Z22+"], mc_evaluator=gen_eval)
Expand Down
3 changes: 2 additions & 1 deletion python/laddu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from laddu.amplitudes import Manager, constant, parameter
from laddu.amplitudes import Manager, constant, parameter, Model
from laddu.amplitudes.breit_wigner import BreitWigner
from laddu.amplitudes.common import ComplexScalar, PolarComplexScalar, Scalar
from laddu.amplitudes.ylm import Ylm
Expand Down Expand Up @@ -76,6 +76,7 @@ def open_amptools(
'Event',
'LikelihoodManager',
'Manager',
'Model',
'Mandelstam',
'Mass',
'Observer',
Expand Down
3 changes: 2 additions & 1 deletion python/laddu/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABCMeta, abstractmethod
from pathlib import Path

from laddu.amplitudes import Expression, Manager, constant, parameter
from laddu.amplitudes import Expression, Manager, constant, parameter, Model
from laddu.amplitudes.breit_wigner import BreitWigner
from laddu.amplitudes.common import ComplexScalar, PolarComplexScalar, Scalar
from laddu.amplitudes.ylm import Ylm
Expand Down Expand Up @@ -51,6 +51,7 @@ __all__ = [
'Expression',
'LikelihoodManager',
'Manager',
'Model',
'Mandelstam',
'Mass',
'Observer',
Expand Down
2 changes: 2 additions & 0 deletions python/laddu/amplitudes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Evaluator,
Expression,
Manager,
Model,
ParameterLike,
constant,
parameter,
Expand All @@ -15,6 +16,7 @@
'Expression',
'Amplitude',
'Manager',
'Model',
'Evaluator',
'ParameterLike',
'parameter',
Expand Down
15 changes: 12 additions & 3 deletions python/laddu/amplitudes/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,17 @@ class Manager:
parameters: list[str]
def __init__(self) -> None: ...
def register(self, amplitude: Amplitude) -> AmplitudeID: ...
def load(
self, expression: Expression | AmplitudeID, dataset: Dataset
) -> Evaluator: ...
def model(self, expression: Expression | AmplitudeID) -> Model: ...

class Model:
parameters: list[str]
def __init__(self) -> None: ...
def load(self, dataset: Dataset) -> Evaluator: ...
def save_as(self, path: str) -> None: ...
@staticmethod
def load_from(path: str) -> Model: ...
def __getstate__(self) -> object: ...
def __setstate__(self, state: object) -> None: ...

class Evaluator:
parameters: list[str]
Expand All @@ -56,6 +64,7 @@ __all__ = [
'Expression',
'Amplitude',
'Manager',
'Model',
'Evaluator',
'ParameterLike',
'parameter',
Expand Down
2 changes: 1 addition & 1 deletion python/laddu/likelihoods/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class Status:
def __init__(self) -> None: ...
def save_as(self, path: str) -> None: ...
@staticmethod
def load(path: str) -> Status: ...
def load_from(path: str) -> Status: ...
def as_dict(self) -> dict[str, Any]: ...
def __getstate__(self) -> object: ...
def __setstate__(self, state: object) -> None: ...
Expand Down
48 changes: 33 additions & 15 deletions python/tests/test_amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def test_constant_amplitude():
amp = Scalar('constant', constant(2.0))
aid = manager.register(amp)
dataset = make_test_dataset()
evaluator = manager.load(aid, dataset)
model = manager.model(aid)
evaluator = model.load(dataset)
result = evaluator.evaluate([])
assert result[0] == 2.0 + 0.0j

Expand All @@ -43,7 +44,8 @@ def test_parametric_amplitude():
amp = Scalar('parametric', parameter('test_param'))
aid = manager.register(amp)
dataset = make_test_dataset()
evaluator = manager.load(aid, dataset)
model = manager.model(aid)
evaluator = model.load(dataset)
result = evaluator.evaluate([3.0])
assert result[0] == 3.0 + 0.0j

Expand All @@ -58,52 +60,62 @@ def test_expression_operations():
aid3 = manager.register(amp3)
dataset = make_test_dataset()
expr_add = aid1 + aid2
eval_add = manager.load(expr_add, dataset)
model_add = manager.model(expr_add)
eval_add = model_add.load(dataset)
result_add = eval_add.evaluate([])
assert result_add[0] == 2.0 + 1.0j

expr_mul = aid1 * aid2
eval_mul = manager.load(expr_mul, dataset)
model_mul = manager.model(expr_mul)
eval_mul = model_mul.load(dataset)
result_mul = eval_mul.evaluate([])
assert result_mul[0] == 0.0 + 2.0j

expr_add2 = expr_add + expr_mul
eval_add2 = manager.load(expr_add2, dataset)
model_add2 = manager.model(expr_add2)
eval_add2 = model_add2.load(dataset)
result_add2 = eval_add2.evaluate([])
assert result_add2[0] == 2.0 + 3.0j

expr_mul2 = expr_add * expr_mul
eval_mul2 = manager.load(expr_mul2, dataset)
model_mul2 = manager.model(expr_mul2)
eval_mul2 = model_mul2.load(dataset)
result_mul2 = eval_mul2.evaluate([])
assert result_mul2[0] == -2.0 + 4.0j

expr_real = aid3.real()
eval_real = manager.load(expr_real, dataset)
model_real = manager.model(expr_real)
eval_real = model_real.load(dataset)
result_real = eval_real.evaluate([])
assert result_real[0] == 3.0 + 0.0j

expr_mul2_real = expr_mul2.real()
eval_mul2_real = manager.load(expr_mul2_real, dataset)
model_mul2_real = manager.model(expr_mul2_real)
eval_mul2_real = model_mul2_real.load(dataset)
result_mul2_real = eval_mul2_real.evaluate([])
assert result_mul2_real[0] == -2.0 + 0.0j

expr_imag = aid3.imag()
eval_imag = manager.load(expr_imag, dataset)
model_imag = manager.model(expr_imag)
eval_imag = model_imag.load(dataset)
result_imag = eval_imag.evaluate([])
assert result_imag[0] == 4.0 + 0.0j

expr_mul2_imag = expr_mul2.imag()
eval_mul2_imag = manager.load(expr_mul2_imag, dataset)
model_mul2_imag = manager.model(expr_mul2_imag)
eval_mul2_imag = model_mul2_imag.load(dataset)
result_mul2_imag = eval_mul2_imag.evaluate([])
assert result_mul2_imag[0] == 4.0 + 0.0j

expr_norm = aid1.norm_sqr()
eval_norm = manager.load(expr_norm, dataset)
model_norm = manager.model(expr_norm)
eval_norm = model_norm.load(dataset)
result_norm = eval_norm.evaluate([])
assert result_norm[0] == 4.0 + 0.0j

expr_mul2_norm = expr_mul2.norm_sqr()
eval_mul2_norm = manager.load(expr_mul2_norm, dataset)
model_mul2_norm = manager.model(expr_mul2_norm)
eval_mul2_norm = model_mul2_norm.load(dataset)
result_mul2_norm = eval_mul2_norm.evaluate([])
assert result_mul2_norm[0] == 20.0 + 0.0j

Expand All @@ -117,7 +129,8 @@ def test_amplitude_activation():
dataset = make_test_dataset()

expr = aid1 + aid2
evaluator = manager.load(expr, dataset)
model = manager.model(expr)
evaluator = model.load(dataset)
result = evaluator.evaluate([])
assert result[0] == 3.0 + 0.0j

Expand All @@ -140,7 +153,8 @@ def test_gradient():
aid = manager.register(amp)
dataset = make_test_dataset()
expr = aid.norm_sqr()
evaluator = manager.load(expr, dataset)
model = manager.model(expr)
evaluator = model.load(dataset)
params = [2.0]
gradient = evaluator.evaluate_gradient(params)
# For |f(x)|^2 where f(x) = x, the derivative should be 2x
Expand All @@ -151,10 +165,14 @@ def test_gradient():
def test_parameter_registration():
manager = Manager()
amp = Scalar('parametric', parameter('test_param'))
manager.register(amp)
aid = manager.register(amp)
parameters = manager.parameters
model = manager.model(aid)
model_parameters = model.parameters
assert len(parameters) == 1
assert parameters[0] == 'test_param'
assert len(model_parameters) == 1
assert model_parameters[0] == 'test_param'


def test_duplicate_amplitude_registration():
Expand Down
6 changes: 4 additions & 2 deletions python/tests/test_breit_wigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def test_bw_evaluation():
)
aid = manager.register(amp)
dataset = make_test_dataset()
evaluator = manager.load(aid, dataset)
model = manager.model(aid)
evaluator = model.load(dataset)
result = evaluator.evaluate([1.5, 0.3])
assert pytest.approx(result[0].real) == 1.4585691
assert pytest.approx(result[0].imag) == 1.4107341
Expand All @@ -39,7 +40,8 @@ def test_bw_gradient():
)
aid = manager.register(amp)
dataset = make_test_dataset()
evaluator = manager.load(aid, dataset)
model = manager.model(aid)
evaluator = model.load(dataset)
result = evaluator.evaluate_gradient([1.5, 0.3])
assert pytest.approx(result[0][0].real) == 1.3252039
assert pytest.approx(result[0][0].imag) == -11.6827505
Expand Down
Loading

0 comments on commit 561753b

Please sign in to comment.