Skip to content

Commit

Permalink
feat: add no-op implementations for adding 0 to add-able types
Browse files Browse the repository at this point in the history
This allows "sum" to work.
  • Loading branch information
denehoffman committed Nov 19, 2024
1 parent e382a0d commit a8f6323
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 18 deletions.
10 changes: 5 additions & 5 deletions python/laddu/amplitudes/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ class AmplitudeID:
def real(self) -> Expression: ...
def imag(self) -> Expression: ...
def norm_sqr(self) -> Expression: ...
def __add__(self, other: AmplitudeID | Expression) -> Expression: ...
def __radd__(self, other: AmplitudeID | Expression) -> Expression: ...
def __add__(self, other: AmplitudeID | Expression | int) -> Expression: ...
def __radd__(self, other: AmplitudeID | Expression | int) -> Expression: ...
def __mul__(self, other: AmplitudeID | Expression) -> Expression: ...
def __rmul__(self, other: AmplitudeID | Expression) -> Expression: ...

class Expression:
def real(self) -> Expression: ...
def imag(self) -> Expression: ...
def norm_sqr(self) -> Expression: ...
def __add__(self, other: AmplitudeID | Expression) -> Expression: ...
def __radd__(self, other: AmplitudeID | Expression) -> Expression: ...
def __add__(self, other: AmplitudeID | Expression | int) -> Expression: ...
def __radd__(self, other: AmplitudeID | Expression | int) -> Expression: ...
def __mul__(self, other: AmplitudeID | Expression) -> Expression: ...
def __rmull__(self, other: AmplitudeID | Expression) -> Expression: ...
def __rmul__(self, other: AmplitudeID | Expression) -> Expression: ...

class Amplitude: ...

Expand Down
8 changes: 4 additions & 4 deletions python/laddu/likelihoods/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ from laddu.amplitudes import Expression, Manager
from laddu.data import Dataset

class LikelihoodID:
def __add__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...
def __radd__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...
def __add__(self, other: LikelihoodID | LikelihoodExpression | int) -> LikelihoodExpression: ...
def __radd__(self, other: LikelihoodID | LikelihoodExpression | int) -> LikelihoodExpression: ...
def __mul__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...
def __rmul__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...

class LikelihoodExpression:
def __add__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...
def __radd__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...
def __add__(self, other: LikelihoodID | LikelihoodExpression | int) -> LikelihoodExpression: ...
def __radd__(self, other: LikelihoodID | LikelihoodExpression | int) -> LikelihoodExpression: ...
def __mul__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...
def __rmul__(self, other: LikelihoodID | LikelihoodExpression) -> LikelihoodExpression: ...

Expand Down
3 changes: 2 additions & 1 deletion python/laddu/utils/vectors/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ class Vector3:
py: float
pz: float
def __init__(self, px: float, py: float, pz: float): ...
def __add__(self, other: Vector3) -> Vector3: ...
def __add__(self, other: Vector3 | int) -> Vector3: ...
def __radd__(self, other: Vector3 | int) -> Vector3: ...
def dot(self, other: Vector3) -> float: ...
def cross(self, other: Vector3) -> Vector3: ...
def to_numpy(self) -> npt.NDArray[np.float64]: ...
Expand Down
172 changes: 164 additions & 8 deletions src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,35 @@ pub(crate) mod laddu {
fn new(px: Float, py: Float, pz: Float) -> Self {
Self(nalgebra::Vector3::new(px, py, pz))
}
fn __add__(&self, other: Self) -> Self {
Self(self.0 + other.0)
fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(other_vec) = other.extract::<PyRef<Vector3>>() {
Ok(Vector3(self.0 + other_vec.0))
} else if let Ok(other_int) = other.extract::<usize>() {
if other_int == 0 {
Ok(self.clone())
} else {
Err(PyTypeError::new_err(
"Addition with an integer for this type is only defined for 0",
))
}
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
}
fn __radd__(&self, other: Self) -> Self {
other.__add__(self.clone())
fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(other_vec) = other.extract::<PyRef<Vector3>>() {
Ok(Vector3(other_vec.0 + self.0))
} else if let Ok(other_int) = other.extract::<usize>() {
if other_int == 0 {
Ok(self.clone())
} else {
Err(PyTypeError::new_err(
"Addition with an integer for this type is only defined for 0",
))
}
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
}
/// The dot product
///
Expand Down Expand Up @@ -270,11 +294,35 @@ pub(crate) mod laddu {
fn new(e: Float, px: Float, py: Float, pz: Float) -> Self {
Self(nalgebra::Vector4::new(e, px, py, pz))
}
fn __add__(&self, other: Self) -> Self {
Self(self.0 + other.0)
fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(other_vec) = other.extract::<PyRef<Vector4>>() {
Ok(Vector4(self.0 + other_vec.0))
} else if let Ok(other_int) = other.extract::<usize>() {
if other_int == 0 {
Ok(self.clone())
} else {
Err(PyTypeError::new_err(
"Addition with an integer for this type is only defined for 0",
))
}
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
}
fn __radd__(&self, other: Self) -> Self {
other.__add__(self.clone())
fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(other_vec) = other.extract::<PyRef<Vector4>>() {
Ok(Vector4(other_vec.0 + self.0))
} else if let Ok(other_int) = other.extract::<usize>() {
if other_int == 0 {
Ok(self.clone())
} else {
Err(PyTypeError::new_err(
"Addition with an integer for this type is only defined for 0",
))
}
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
}
/// The magnitude of the 4-vector
///
Expand Down Expand Up @@ -1444,6 +1492,16 @@ pub(crate) mod laddu {
Ok(Expression(self.0.clone() + other_aid.0.clone()))
} else if let Ok(other_expr) = other.extract::<Expression>() {
Ok(Expression(self.0.clone() + other_expr.0.clone()))
} else if let Ok(other_int) = other.extract::<usize>() {
if other_int == 0 {
Ok(Expression(rust::amplitudes::Expression::Amp(
self.0.clone(),
)))
} else {
Err(PyTypeError::new_err(
"Addition with an integer for this type is only defined for 0",
))
}
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
Expand All @@ -1453,6 +1511,16 @@ pub(crate) mod laddu {
Ok(Expression(other_aid.0.clone() + self.0.clone()))
} else if let Ok(other_expr) = other.extract::<Expression>() {
Ok(Expression(other_expr.0.clone() + self.0.clone()))
} else if let Ok(other_int) = other.extract::<usize>() {
if other_int == 0 {
Ok(Expression(rust::amplitudes::Expression::Amp(
self.0.clone(),
)))
} else {
Err(PyTypeError::new_err(
"Addition with an integer for this type is only defined for 0",
))
}
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
Expand All @@ -1466,6 +1534,15 @@ pub(crate) mod laddu {
Err(PyTypeError::new_err("Unsupported operand type for *"))
}
}
fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<Expression> {
if let Ok(other_aid) = other.extract::<PyRef<AmplitudeID>>() {
Ok(Expression(other_aid.0.clone() * self.0.clone()))
} else if let Ok(other_expr) = other.extract::<Expression>() {
Ok(Expression(other_expr.0.clone() * self.0.clone()))
} else {
Err(PyTypeError::new_err("Unsupported operand type for *"))
}
}
fn __str__(&self) -> String {
format!("{}", self.0)
}
Expand Down Expand Up @@ -1513,6 +1590,14 @@ pub(crate) mod laddu {
Ok(Expression(self.0.clone() + other_aid.0.clone()))
} else if let Ok(other_expr) = other.extract::<Expression>() {
Ok(Expression(self.0.clone() + other_expr.0.clone()))
} else if let Ok(other_int) = other.extract::<usize>() {
if other_int == 0 {
Ok(Expression(self.0.clone()))
} else {
Err(PyTypeError::new_err(
"Addition with an integer for this type is only defined for 0",
))
}
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
Expand All @@ -1522,6 +1607,14 @@ pub(crate) mod laddu {
Ok(Expression(other_aid.0.clone() + self.0.clone()))
} else if let Ok(other_expr) = other.extract::<Expression>() {
Ok(Expression(other_expr.0.clone() + self.0.clone()))
} else if let Ok(other_int) = other.extract::<usize>() {
if other_int == 0 {
Ok(Expression(self.0.clone()))
} else {
Err(PyTypeError::new_err(
"Addition with an integer for this type is only defined for 0",
))
}
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
Expand All @@ -1535,6 +1628,15 @@ pub(crate) mod laddu {
Err(PyTypeError::new_err("Unsupported operand type for *"))
}
}
fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<Expression> {
if let Ok(other_aid) = other.extract::<PyRef<AmplitudeID>>() {
Ok(Expression(other_aid.0.clone() * self.0.clone()))
} else if let Ok(other_expr) = other.extract::<Expression>() {
Ok(Expression(other_expr.0.clone() * self.0.clone()))
} else {
Err(PyTypeError::new_err("Unsupported operand type for *"))
}
}
fn __str__(&self) -> String {
format!("{}", self.0)
}
Expand Down Expand Up @@ -2287,6 +2389,16 @@ pub(crate) mod laddu {
Ok(LikelihoodExpression(self.0.clone() + other_aid.0.clone()))
} else if let Ok(other_expr) = other.extract::<LikelihoodExpression>() {
Ok(LikelihoodExpression(self.0.clone() + other_expr.0.clone()))
} else if let Ok(int) = other.extract::<usize>() {
if int == 0 {
Ok(LikelihoodExpression(
rust::likelihoods::LikelihoodExpression::Term(self.0.clone()),
))
} else {
Err(PyTypeError::new_err(
"Addition with an integer for this type is only defined for 0",
))
}
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
Expand All @@ -2296,6 +2408,16 @@ pub(crate) mod laddu {
Ok(LikelihoodExpression(other_aid.0.clone() + self.0.clone()))
} else if let Ok(other_expr) = other.extract::<LikelihoodExpression>() {
Ok(LikelihoodExpression(other_expr.0.clone() + self.0.clone()))
} else if let Ok(int) = other.extract::<usize>() {
if int == 0 {
Ok(LikelihoodExpression(
rust::likelihoods::LikelihoodExpression::Term(self.0.clone()),
))
} else {
Err(PyTypeError::new_err(
"Addition with an integer for this type is only defined for 0",
))
}
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
Expand All @@ -2309,6 +2431,15 @@ pub(crate) mod laddu {
Err(PyTypeError::new_err("Unsupported operand type for *"))
}
}
fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<LikelihoodExpression> {
if let Ok(other_aid) = other.extract::<PyRef<LikelihoodID>>() {
Ok(LikelihoodExpression(other_aid.0.clone() * self.0.clone()))
} else if let Ok(other_expr) = other.extract::<LikelihoodExpression>() {
Ok(LikelihoodExpression(other_expr.0.clone() * self.0.clone()))
} else {
Err(PyTypeError::new_err("Unsupported operand type for *"))
}
}
fn __str__(&self) -> String {
format!("{}", self.0)
}
Expand All @@ -2324,6 +2455,14 @@ pub(crate) mod laddu {
Ok(LikelihoodExpression(self.0.clone() + other_aid.0.clone()))
} else if let Ok(other_expr) = other.extract::<LikelihoodExpression>() {
Ok(LikelihoodExpression(self.0.clone() + other_expr.0.clone()))
} else if let Ok(int) = other.extract::<usize>() {
if int == 0 {
Ok(LikelihoodExpression(self.0.clone()))
} else {
Err(PyTypeError::new_err(
"Addition with an integer for this type is only defined for 0",
))
}
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
Expand All @@ -2333,6 +2472,14 @@ pub(crate) mod laddu {
Ok(LikelihoodExpression(other_aid.0.clone() + self.0.clone()))
} else if let Ok(other_expr) = other.extract::<LikelihoodExpression>() {
Ok(LikelihoodExpression(other_expr.0.clone() + self.0.clone()))
} else if let Ok(int) = other.extract::<usize>() {
if int == 0 {
Ok(LikelihoodExpression(self.0.clone()))
} else {
Err(PyTypeError::new_err(
"Addition with an integer for this type is only defined for 0",
))
}
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
Expand All @@ -2346,6 +2493,15 @@ pub(crate) mod laddu {
Err(PyTypeError::new_err("Unsupported operand type for *"))
}
}
fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<LikelihoodExpression> {
if let Ok(other_aid) = other.extract::<PyRef<LikelihoodID>>() {
Ok(LikelihoodExpression(self.0.clone() * other_aid.0.clone()))
} else if let Ok(other_expr) = other.extract::<LikelihoodExpression>() {
Ok(LikelihoodExpression(self.0.clone() * other_expr.0.clone()))
} else {
Err(PyTypeError::new_err("Unsupported operand type for *"))
}
}
fn __str__(&self) -> String {
format!("{}", self.0)
}
Expand Down

0 comments on commit a8f6323

Please sign in to comment.