diff --git a/Cargo.toml b/Cargo.toml index a088cdf..4cab2fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ shellexpand = "3.1.0" accurate = "0.4.1" serde = "1.0.214" serde-pickle = "1.1.1" +bincode = "1.3.3" [dev-dependencies] approx = "0.5.1" diff --git a/python/laddu/amplitudes/__init__.pyi b/python/laddu/amplitudes/__init__.pyi index 42376a0..f57fca0 100644 --- a/python/laddu/amplitudes/__init__.pyi +++ b/python/laddu/amplitudes/__init__.pyi @@ -13,15 +13,19 @@ class AmplitudeID: def real(self) -> Expression: ... def imag(self) -> Expression: ... def norm_sqr(self) -> Expression: ... - def __add__(self, other: AmplitudeID | Expression) -> Expression: ... + def __add__(self, other: AmplitudeID | Expression | int) -> Expression: ... + def __radd__(self, other: AmplitudeID | Expression | int) -> Expression: ... def __mul__(self, other: AmplitudeID | Expression) -> Expression: ... + def __rmul__(self, other: AmplitudeID | Expression) -> Expression: ... class Expression: def real(self) -> Expression: ... def imag(self) -> Expression: ... def norm_sqr(self) -> Expression: ... - def __add__(self, other: AmplitudeID | Expression) -> Expression: ... + def __add__(self, other: AmplitudeID | Expression | int) -> Expression: ... + def __radd__(self, other: AmplitudeID | Expression | int) -> Expression: ... def __mul__(self, other: AmplitudeID | Expression) -> Expression: ... + def __rmul__(self, other: AmplitudeID | Expression) -> Expression: ... class Amplitude: ... diff --git a/python/laddu/likelihoods/__init__.pyi b/python/laddu/likelihoods/__init__.pyi index 0926c24..d0d0f79 100644 --- a/python/laddu/likelihoods/__init__.pyi +++ b/python/laddu/likelihoods/__init__.pyi @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Literal +from typing import Any, Literal import numpy as np import numpy.typing as npt @@ -8,12 +8,16 @@ from laddu.amplitudes import Expression, Manager from laddu.data import Dataset class LikelihoodID: - def __add__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ... + def __add__(self, other: LikelihoodID | LikelihoodExpression | int) -> LikelihoodExpression: ... + def __radd__(self, other: LikelihoodID | LikelihoodExpression | int) -> LikelihoodExpression: ... def __mul__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ... + def __rmul__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ... class LikelihoodExpression: - def __add__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ... + def __add__(self, other: LikelihoodID | LikelihoodExpression | int) -> LikelihoodExpression: ... + def __radd__(self, other: LikelihoodID | LikelihoodExpression | int) -> LikelihoodExpression: ... def __mul__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ... + def __rmul__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ... class LikelihoodTerm: ... @@ -78,9 +82,13 @@ class Status: n_f_evals: int n_g_evals: int - def save_as(self, path: str): ... + def __init__(self) -> None: ... + def save_as(self, path: str) -> None: ... @staticmethod def load(path: str) -> Status: ... + def as_dict(self) -> dict[str, Any]: ... + def __getstate__(self) -> object: ... + def __setstate__(self, state: object) -> None: ... class Bound: lower: float diff --git a/python/laddu/utils/vectors/__init__.pyi b/python/laddu/utils/vectors/__init__.pyi index 4c93fd1..ec65178 100644 --- a/python/laddu/utils/vectors/__init__.pyi +++ b/python/laddu/utils/vectors/__init__.pyi @@ -12,7 +12,8 @@ class Vector3: py: float pz: float def __init__(self, px: float, py: float, pz: float): ... - def __add__(self, other: Vector3) -> Vector3: ... + def __add__(self, other: Vector3 | int) -> Vector3: ... + def __radd__(self, other: Vector3 | int) -> Vector3: ... def dot(self, other: Vector3) -> float: ... def cross(self, other: Vector3) -> Vector3: ... def to_numpy(self) -> npt.NDArray[np.float64]: ... diff --git a/src/data.rs b/src/data.rs index 2228a19..e95ca35 100644 --- a/src/data.rs +++ b/src/data.rs @@ -532,7 +532,7 @@ mod tests { #[test] fn test_event_p4_sum() { let event = test_event(); - let sum = event.get_p4_sum(&[2, 3]); + let sum = event.get_p4_sum([2, 3]); assert_relative_eq!(sum[0], event.p4s[2].e() + event.p4s[3].e()); assert_relative_eq!(sum[1], event.p4s[2].px() + event.p4s[3].px()); assert_relative_eq!(sum[2], event.p4s[2].py() + event.p4s[3].py()); diff --git a/src/python.rs b/src/python.rs index 7a6a15d..43d128c 100644 --- a/src/python.rs +++ b/src/python.rs @@ -20,6 +20,7 @@ pub(crate) mod laddu { use crate::utils::variables::Variable; use crate::utils::vectors::{FourMomentum, FourVector, ThreeMomentum, ThreeVector}; use crate::Float; + use bincode::{deserialize, serialize}; use ganesh::algorithms::lbfgsb::{LBFGSBFTerminator, LBFGSBGTerminator}; use ganesh::algorithms::nelder_mead::{ NelderMeadFTerminator, NelderMeadXTerminator, SimplexExpansionMethod, @@ -28,6 +29,7 @@ pub(crate) mod laddu { use num::Complex; use numpy::{PyArray1, PyArray2}; use pyo3::exceptions::{PyIndexError, PyTypeError, PyValueError}; + use pyo3::types::PyBytes; use pyo3::types::{PyDict, PyList}; #[pyfunction] @@ -51,11 +53,35 @@ pub(crate) mod laddu { fn new(px: Float, py: Float, pz: Float) -> Self { Self(nalgebra::Vector3::new(px, py, pz)) } - fn __add__(&self, other: Self) -> Self { - Self(self.0 + other.0) + fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult { + if let Ok(other_vec) = other.extract::>() { + Ok(Vector3(self.0 + other_vec.0)) + } else if let Ok(other_int) = other.extract::() { + if other_int == 0 { + Ok(self.clone()) + } else { + Err(PyTypeError::new_err( + "Addition with an integer for this type is only defined for 0", + )) + } + } else { + Err(PyTypeError::new_err("Unsupported operand type for +")) + } } - fn __radd__(&self, other: Self) -> Self { - other.__add__(self.clone()) + fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult { + if let Ok(other_vec) = other.extract::>() { + Ok(Vector3(other_vec.0 + self.0)) + } else if let Ok(other_int) = other.extract::() { + if other_int == 0 { + Ok(self.clone()) + } else { + Err(PyTypeError::new_err( + "Addition with an integer for this type is only defined for 0", + )) + } + } else { + Err(PyTypeError::new_err("Unsupported operand type for +")) + } } /// The dot product /// @@ -268,11 +294,35 @@ pub(crate) mod laddu { fn new(e: Float, px: Float, py: Float, pz: Float) -> Self { Self(nalgebra::Vector4::new(e, px, py, pz)) } - fn __add__(&self, other: Self) -> Self { - Self(self.0 + other.0) + fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult { + if let Ok(other_vec) = other.extract::>() { + Ok(Vector4(self.0 + other_vec.0)) + } else if let Ok(other_int) = other.extract::() { + if other_int == 0 { + Ok(self.clone()) + } else { + Err(PyTypeError::new_err( + "Addition with an integer for this type is only defined for 0", + )) + } + } else { + Err(PyTypeError::new_err("Unsupported operand type for +")) + } } - fn __radd__(&self, other: Self) -> Self { - other.__add__(self.clone()) + fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult { + if let Ok(other_vec) = other.extract::>() { + Ok(Vector4(other_vec.0 + self.0)) + } else if let Ok(other_int) = other.extract::() { + if other_int == 0 { + Ok(self.clone()) + } else { + Err(PyTypeError::new_err( + "Addition with an integer for this type is only defined for 0", + )) + } + } else { + Err(PyTypeError::new_err("Unsupported operand type for +")) + } } /// The magnitude of the 4-vector /// @@ -1442,6 +1492,16 @@ pub(crate) mod laddu { Ok(Expression(self.0.clone() + other_aid.0.clone())) } else if let Ok(other_expr) = other.extract::() { Ok(Expression(self.0.clone() + other_expr.0.clone())) + } else if let Ok(other_int) = other.extract::() { + if other_int == 0 { + Ok(Expression(rust::amplitudes::Expression::Amp( + self.0.clone(), + ))) + } else { + Err(PyTypeError::new_err( + "Addition with an integer for this type is only defined for 0", + )) + } } else { Err(PyTypeError::new_err("Unsupported operand type for +")) } @@ -1451,6 +1511,16 @@ pub(crate) mod laddu { Ok(Expression(other_aid.0.clone() + self.0.clone())) } else if let Ok(other_expr) = other.extract::() { Ok(Expression(other_expr.0.clone() + self.0.clone())) + } else if let Ok(other_int) = other.extract::() { + if other_int == 0 { + Ok(Expression(rust::amplitudes::Expression::Amp( + self.0.clone(), + ))) + } else { + Err(PyTypeError::new_err( + "Addition with an integer for this type is only defined for 0", + )) + } } else { Err(PyTypeError::new_err("Unsupported operand type for +")) } @@ -1464,6 +1534,15 @@ pub(crate) mod laddu { Err(PyTypeError::new_err("Unsupported operand type for *")) } } + fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult { + if let Ok(other_aid) = other.extract::>() { + Ok(Expression(other_aid.0.clone() * self.0.clone())) + } else if let Ok(other_expr) = other.extract::() { + Ok(Expression(other_expr.0.clone() * self.0.clone())) + } else { + Err(PyTypeError::new_err("Unsupported operand type for *")) + } + } fn __str__(&self) -> String { format!("{}", self.0) } @@ -1511,6 +1590,14 @@ pub(crate) mod laddu { Ok(Expression(self.0.clone() + other_aid.0.clone())) } else if let Ok(other_expr) = other.extract::() { Ok(Expression(self.0.clone() + other_expr.0.clone())) + } else if let Ok(other_int) = other.extract::() { + if other_int == 0 { + Ok(Expression(self.0.clone())) + } else { + Err(PyTypeError::new_err( + "Addition with an integer for this type is only defined for 0", + )) + } } else { Err(PyTypeError::new_err("Unsupported operand type for +")) } @@ -1520,6 +1607,14 @@ pub(crate) mod laddu { Ok(Expression(other_aid.0.clone() + self.0.clone())) } else if let Ok(other_expr) = other.extract::() { Ok(Expression(other_expr.0.clone() + self.0.clone())) + } else if let Ok(other_int) = other.extract::() { + if other_int == 0 { + Ok(Expression(self.0.clone())) + } else { + Err(PyTypeError::new_err( + "Addition with an integer for this type is only defined for 0", + )) + } } else { Err(PyTypeError::new_err("Unsupported operand type for +")) } @@ -1533,6 +1628,15 @@ pub(crate) mod laddu { Err(PyTypeError::new_err("Unsupported operand type for *")) } } + fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult { + if let Ok(other_aid) = other.extract::>() { + Ok(Expression(other_aid.0.clone() * self.0.clone())) + } else if let Ok(other_expr) = other.extract::() { + Ok(Expression(other_expr.0.clone() * self.0.clone())) + } else { + Err(PyTypeError::new_err("Unsupported operand type for *")) + } + } fn __str__(&self) -> String { format!("{}", self.0) } @@ -2285,6 +2389,16 @@ pub(crate) mod laddu { Ok(LikelihoodExpression(self.0.clone() + other_aid.0.clone())) } else if let Ok(other_expr) = other.extract::() { Ok(LikelihoodExpression(self.0.clone() + other_expr.0.clone())) + } else if let Ok(int) = other.extract::() { + if int == 0 { + Ok(LikelihoodExpression( + rust::likelihoods::LikelihoodExpression::Term(self.0.clone()), + )) + } else { + Err(PyTypeError::new_err( + "Addition with an integer for this type is only defined for 0", + )) + } } else { Err(PyTypeError::new_err("Unsupported operand type for +")) } @@ -2294,6 +2408,16 @@ pub(crate) mod laddu { Ok(LikelihoodExpression(other_aid.0.clone() + self.0.clone())) } else if let Ok(other_expr) = other.extract::() { Ok(LikelihoodExpression(other_expr.0.clone() + self.0.clone())) + } else if let Ok(int) = other.extract::() { + if int == 0 { + Ok(LikelihoodExpression( + rust::likelihoods::LikelihoodExpression::Term(self.0.clone()), + )) + } else { + Err(PyTypeError::new_err( + "Addition with an integer for this type is only defined for 0", + )) + } } else { Err(PyTypeError::new_err("Unsupported operand type for +")) } @@ -2307,6 +2431,15 @@ pub(crate) mod laddu { Err(PyTypeError::new_err("Unsupported operand type for *")) } } + fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult { + if let Ok(other_aid) = other.extract::>() { + Ok(LikelihoodExpression(other_aid.0.clone() * self.0.clone())) + } else if let Ok(other_expr) = other.extract::() { + Ok(LikelihoodExpression(other_expr.0.clone() * self.0.clone())) + } else { + Err(PyTypeError::new_err("Unsupported operand type for *")) + } + } fn __str__(&self) -> String { format!("{}", self.0) } @@ -2322,6 +2455,14 @@ pub(crate) mod laddu { Ok(LikelihoodExpression(self.0.clone() + other_aid.0.clone())) } else if let Ok(other_expr) = other.extract::() { Ok(LikelihoodExpression(self.0.clone() + other_expr.0.clone())) + } else if let Ok(int) = other.extract::() { + if int == 0 { + Ok(LikelihoodExpression(self.0.clone())) + } else { + Err(PyTypeError::new_err( + "Addition with an integer for this type is only defined for 0", + )) + } } else { Err(PyTypeError::new_err("Unsupported operand type for +")) } @@ -2331,6 +2472,14 @@ pub(crate) mod laddu { Ok(LikelihoodExpression(other_aid.0.clone() + self.0.clone())) } else if let Ok(other_expr) = other.extract::() { Ok(LikelihoodExpression(other_expr.0.clone() + self.0.clone())) + } else if let Ok(int) = other.extract::() { + if int == 0 { + Ok(LikelihoodExpression(self.0.clone())) + } else { + Err(PyTypeError::new_err( + "Addition with an integer for this type is only defined for 0", + )) + } } else { Err(PyTypeError::new_err("Unsupported operand type for +")) } @@ -2344,6 +2493,15 @@ pub(crate) mod laddu { Err(PyTypeError::new_err("Unsupported operand type for *")) } } + fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult { + if let Ok(other_aid) = other.extract::>() { + Ok(LikelihoodExpression(self.0.clone() * other_aid.0.clone())) + } else if let Ok(other_expr) = other.extract::() { + Ok(LikelihoodExpression(self.0.clone() * other_expr.0.clone())) + } else { + Err(PyTypeError::new_err("Unsupported operand type for *")) + } + } fn __str__(&self) -> String { format!("{}", self.0) } @@ -2765,6 +2923,41 @@ pub(crate) mod laddu { fn load(path: &str) -> PyResult { Ok(Status(ganesh::Status::load(path)?)) } + #[new] + fn new() -> Self { + Status(ganesh::Status::default()) + } + fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult> { + Ok(PyBytes::new_bound( + py, + serialize(&self.0).unwrap().as_slice(), + )) + } + fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> { + *self = Status(deserialize(state.as_bytes()).unwrap()); + Ok(()) + } + /// Converts a Status into a Python dictionary + /// + /// Returns + /// ------- + /// dict + /// + fn as_dict<'py>(&self, py: Python<'py>) -> PyResult> { + let dict = PyDict::new_bound(py); + dict.set_item("x", self.x(py))?; + dict.set_item("err", self.err(py))?; + dict.set_item("x0", self.x0(py))?; + dict.set_item("fx", self.fx())?; + dict.set_item("cov", self.cov(py))?; + dict.set_item("hess", self.hess(py))?; + dict.set_item("message", self.message())?; + dict.set_item("converged", self.converged())?; + dict.set_item("bounds", self.bounds())?; + dict.set_item("n_f_evals", self.n_f_evals())?; + dict.set_item("n_g_evals", self.n_g_evals())?; + Ok(dict) + } } /// A class representing a lower and upper bound on a free parameter @@ -2772,7 +2965,7 @@ pub(crate) mod laddu { #[pyclass] #[derive(Clone)] #[pyo3(name = "Bound")] - struct ParameterBound(ganesh::Bound); + pub(crate) struct ParameterBound(pub(crate) ganesh::Bound); #[pymethods] impl ParameterBound { /// The lower bound @@ -3466,3 +3659,8 @@ impl FromPyObject<'_> for crate::python::laddu::PyObserver { Ok(crate::python::laddu::PyObserver(ob.clone().into())) } } +impl ToPyObject for crate::python::laddu::ParameterBound { + fn to_object(&self, py: Python<'_>) -> PyObject { + PyTuple::new_bound(py, vec![self.0.lower(), self.0.upper()]).into() + } +} diff --git a/src/utils/variables.rs b/src/utils/variables.rs index 3e3a6a3..7898cd3 100644 --- a/src/utils/variables.rs +++ b/src/utils/variables.rs @@ -520,9 +520,9 @@ mod tests { assert_relative_eq!(t.value(&event), tp.value(&event), epsilon = 1e-7); assert_relative_eq!(u.value(&event), -14.4041989, epsilon = 1e-7); assert_relative_eq!(u.value(&event), up.value(&event), epsilon = 1e-7); - let m2_beam = test_event().get_p4_sum(&[0]).m2(); - let m2_recoil = test_event().get_p4_sum(&[1]).m2(); - let m2_res = test_event().get_p4_sum(&[2, 3]).m2(); + let m2_beam = test_event().get_p4_sum([0]).m2(); + let m2_recoil = test_event().get_p4_sum([1]).m2(); + let m2_res = test_event().get_p4_sum([2, 3]).m2(); assert_relative_eq!( s.value(&event) + t.value(&event) + u.value(&event) - m2_beam - m2_recoil - m2_res, 1.00,