Skip to content

Commit

Permalink
Merge pull request #18 from denehoffman/development
Browse files Browse the repository at this point in the history
I/O and other small fixes
  • Loading branch information
denehoffman authored Nov 19, 2024
2 parents ac2bc48 + a8f6323 commit 79188c3
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 20 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ shellexpand = "3.1.0"
accurate = "0.4.1"
serde = "1.0.214"
serde-pickle = "1.1.1"
bincode = "1.3.3"

[dev-dependencies]
approx = "0.5.1"
Expand Down
8 changes: 6 additions & 2 deletions python/laddu/amplitudes/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@ class AmplitudeID:
def real(self) -> Expression: ...
def imag(self) -> Expression: ...
def norm_sqr(self) -> Expression: ...
def __add__(self, other: AmplitudeID | Expression) -> Expression: ...
def __add__(self, other: AmplitudeID | Expression | int) -> Expression: ...
def __radd__(self, other: AmplitudeID | Expression | int) -> Expression: ...
def __mul__(self, other: AmplitudeID | Expression) -> Expression: ...
def __rmul__(self, other: AmplitudeID | Expression) -> Expression: ...

class Expression:
def real(self) -> Expression: ...
def imag(self) -> Expression: ...
def norm_sqr(self) -> Expression: ...
def __add__(self, other: AmplitudeID | Expression) -> Expression: ...
def __add__(self, other: AmplitudeID | Expression | int) -> Expression: ...
def __radd__(self, other: AmplitudeID | Expression | int) -> Expression: ...
def __mul__(self, other: AmplitudeID | Expression) -> Expression: ...
def __rmul__(self, other: AmplitudeID | Expression) -> Expression: ...

class Amplitude: ...

Expand Down
16 changes: 12 additions & 4 deletions python/laddu/likelihoods/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Literal
from typing import Any, Literal

import numpy as np
import numpy.typing as npt
Expand All @@ -8,12 +8,16 @@ from laddu.amplitudes import Expression, Manager
from laddu.data import Dataset

class LikelihoodID:
def __add__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...
def __add__(self, other: LikelihoodID | LikelihoodExpression | int) -> LikelihoodExpression: ...
def __radd__(self, other: LikelihoodID | LikelihoodExpression | int) -> LikelihoodExpression: ...
def __mul__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...
def __rmul__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...

class LikelihoodExpression:
def __add__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...
def __add__(self, other: LikelihoodID | LikelihoodExpression | int) -> LikelihoodExpression: ...
def __radd__(self, other: LikelihoodID | LikelihoodExpression | int) -> LikelihoodExpression: ...
def __mul__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...
def __rmul__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...

class LikelihoodTerm: ...

Expand Down Expand Up @@ -78,9 +82,13 @@ class Status:
n_f_evals: int
n_g_evals: int

def save_as(self, path: str): ...
def __init__(self) -> None: ...
def save_as(self, path: str) -> None: ...
@staticmethod
def load(path: str) -> Status: ...
def as_dict(self) -> dict[str, Any]: ...
def __getstate__(self) -> object: ...
def __setstate__(self, state: object) -> None: ...

class Bound:
lower: float
Expand Down
3 changes: 2 additions & 1 deletion python/laddu/utils/vectors/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ class Vector3:
py: float
pz: float
def __init__(self, px: float, py: float, pz: float): ...
def __add__(self, other: Vector3) -> Vector3: ...
def __add__(self, other: Vector3 | int) -> Vector3: ...
def __radd__(self, other: Vector3 | int) -> Vector3: ...
def dot(self, other: Vector3) -> float: ...
def cross(self, other: Vector3) -> Vector3: ...
def to_numpy(self) -> npt.NDArray[np.float64]: ...
Expand Down
2 changes: 1 addition & 1 deletion src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ mod tests {
#[test]
fn test_event_p4_sum() {
let event = test_event();
let sum = event.get_p4_sum(&[2, 3]);
let sum = event.get_p4_sum([2, 3]);
assert_relative_eq!(sum[0], event.p4s[2].e() + event.p4s[3].e());
assert_relative_eq!(sum[1], event.p4s[2].px() + event.p4s[3].px());
assert_relative_eq!(sum[2], event.p4s[2].py() + event.p4s[3].py());
Expand Down
Loading

0 comments on commit 79188c3

Please sign in to comment.