Skip to content

Commit

Permalink
feat: export Status and Bound structs from ganesh as PyO3 objec…
Browse files Browse the repository at this point in the history
…ts and update `minimize` method accordingly

Currently, `minimize` has a `DebugObserver`, and a cleaner opt-in solution will have to be made in the future before PR
  • Loading branch information
denehoffman committed Oct 20, 2024
1 parent c4e74e4 commit 207f026
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ rand = "0.8.5"
rayon = { version = "1.10.0", optional = true }
pyo3 = { version = "0.22.5", optional = true, features = ["num-complex"] }
numpy = { version = "0.22.0", optional = true, features = ["nalgebra"] }
ganesh = "0.11.2"
ganesh = "0.11.3"
thiserror = "1.0.64"
shellexpand = "3.1.0"

Expand Down
16 changes: 15 additions & 1 deletion python/laddu/amplitudes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
from laddu.amplitudes import breit_wigner, common, kmatrix, ylm, zlm
from laddu.laddu import NLL, Amplitude, AmplitudeID, Evaluator, Expression, Manager, ParameterLike, constant, parameter
from laddu.laddu import (
NLL,
Amplitude,
AmplitudeID,
Bound,
Evaluator,
Expression,
Manager,
ParameterLike,
Status,
constant,
parameter,
)

__all__ = [
"AmplitudeID",
Expand All @@ -16,4 +28,6 @@
"zlm",
"breit_wigner",
"kmatrix",
"Status",
"Bound",
]
31 changes: 27 additions & 4 deletions python/laddu/amplitudes/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class Evaluator:
def deactivate(self, name: str | list[str]) -> None: ...
def deactivate_all(self) -> None: ...
def isolate(self, name: str | list[str]) -> None: ...
def evaluate(self, expression: Expression, parameters: list[float]) -> npt.NDArray[np.complex128]: ...
def evaluate(
self, expression: Expression, parameters: list[float] | npt.NDArray[np.float64]
) -> npt.NDArray[np.complex128]: ...

class NLL:
parameters: list[str]
Expand All @@ -47,11 +49,30 @@ class NLL:
def deactivate(self, name: str | list[str]) -> None: ...
def deactivate_all(self) -> None: ...
def isolate(self, name: str | list[str]) -> None: ...
def evaluate(self, expression: Expression, parameters: list[float]) -> float: ...
def project(self, expression: Expression, parameters: list[float]) -> npt.NDArray[np.float64]: ...
def evaluate(self, expression: Expression, parameters: list[float] | npt.NDArray[np.float64]) -> float: ...
def project(
self, expression: Expression, parameters: list[float] | npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]: ...
def minimize(
self, expression: Expression, p0: list[float], bounds: list[tuple[float | None, float | None]] | None = None
) -> npt.NDArray[np.float64]: ...
) -> Status: ...

class Status:
x: npt.NDArray[np.float64]
err: npt.NDArray[np.float64] | None
x0: npt.NDArray[np.float64]
fx: float
cov: npt.NDArray[np.float64] | None
hess: npt.NDArray[np.float64] | None
message: str
converged: bool
bounds: list[Bound] | None
n_f_evals: int
n_g_evals: int

class Bound:
lower: float
upper: float

__all__ = [
"AmplitudeID",
Expand All @@ -68,4 +89,6 @@ __all__ = [
"zlm",
"breit_wigner",
"kmatrix",
"Status",
"Bound",
]
7 changes: 5 additions & 2 deletions src/amplitudes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,9 +580,12 @@ impl NLL {
p0: &[Float],
bounds: Option<Vec<(Float, Float)>>,
) -> Status<Float> {
let mut m = Minimizer::new(LBFGSB::default(), self.parameters().len()).with_bounds(bounds);
let mut m = Minimizer::new(LBFGSB::default(), self.parameters().len())
.with_bounds(bounds)
.with_observer(DebugObserver);
let mut expression = expression.clone();
m.minimize(self, p0, &mut expression).unwrap();
m.minimize(self, p0, &mut expression)
.unwrap_or_else(|never| match never {});
m.status
}
}
106 changes: 100 additions & 6 deletions src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mod laddu {
use crate::utils::vectors::{FourMomentum, FourVector, ThreeMomentum, ThreeVector};
use crate::Float;
use num::Complex;
use numpy::PyArray1;
use numpy::{PyArray1, PyArray2};
use pyo3::exceptions::{PyIndexError, PyTypeError};
use pyo3::types::PyList;

Expand Down Expand Up @@ -652,13 +652,12 @@ mod laddu {
PyArray1::from_slice_bound(py, &self.0.project(&expression.0, &parameters))
}
#[pyo3(signature = (expression, p0, bounds = None))]
fn minimize<'py>(
fn minimize(
&self,
py: Python<'py>,
expression: &Expression,
p0: Vec<Float>,
bounds: Option<Vec<(Option<Float>, Option<Float>)>>,
) -> Bound<'py, PyArray1<Float>> {
) -> Status {
let bounds = bounds.map(|bounds_vec| {
bounds_vec
.iter()
Expand All @@ -671,8 +670,103 @@ mod laddu {
.collect()
});
let status = self.0.minimize(&expression.0, &p0, bounds);
println!("{}", status);
PyArray1::from_slice_bound(py, status.x.as_slice())
Status(status)
}
}

#[pyclass]
#[derive(Clone)]
struct Status(ganesh::Status<Float>);
#[pymethods]
impl Status {
#[getter]
fn x<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
PyArray1::from_slice_bound(py, self.0.x.as_slice())
}
#[getter]
fn err<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray1<Float>>> {
self.0
.err
.clone()
.map(|err| PyArray1::from_slice_bound(py, err.as_slice()))
}
#[getter]
fn x0<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
PyArray1::from_slice_bound(py, self.0.x0.as_slice())
}
#[getter]
fn fx(&self) -> Float {
self.0.fx
}
#[getter]
fn cov<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<Float>>> {
self.0.cov.clone().map(|cov| {
PyArray2::from_vec2_bound(
py,
&cov.row_iter()
.map(|row| row.iter().cloned().collect())
.collect::<Vec<Vec<Float>>>(),
)
.unwrap()
})
}
#[getter]
fn hess<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<Float>>> {
self.0.hess.clone().map(|hess| {
PyArray2::from_vec2_bound(
py,
&hess
.row_iter()
.map(|row| row.iter().cloned().collect())
.collect::<Vec<Vec<Float>>>(),
)
.unwrap()
})
}
#[getter]
fn message(&self) -> String {
self.0.message.clone()
}
#[getter]
fn converged(&self) -> bool {
self.0.converged
}
#[getter]
fn bounds(&self) -> Option<Vec<ParameterBound>> {
self.0
.bounds
.clone()
.map(|bounds| bounds.iter().map(|bound| ParameterBound(*bound)).collect())
}
#[getter]
fn n_f_evals(&self) -> usize {
self.0.n_f_evals
}
#[getter]
fn n_g_evals(&self) -> usize {
self.0.n_g_evals
}
fn __str__(&self) -> String {
self.0.to_string()
}
fn __repr__(&self) -> String {
format!("{:?}", self.0)
}
}

#[pyclass]
#[derive(Clone)]
#[pyo3(name = "Bound")]
struct ParameterBound(ganesh::Bound<Float>);
#[pymethods]
impl ParameterBound {
#[getter]
fn lower(&self) -> Float {
self.0.lower()
}
#[getter]
fn upper(&self) -> Float {
self.0.upper()
}
}

Expand Down

0 comments on commit 207f026

Please sign in to comment.