Skip to content

Commit

Permalink
optimal-compute-core: add trigonometry
Browse files Browse the repository at this point in the history
and improve related macros.
  • Loading branch information
justinlovinger committed Oct 20, 2024
1 parent d279eea commit 4484b9d
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 83 deletions.
60 changes: 60 additions & 0 deletions optimal-compute-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ pub trait Computation {
type Dim;
type Item;

// `math`

fn add<Rhs>(self, rhs: Rhs) -> math::Add<Self, Rhs>
where
Self: Sized,
Expand Down Expand Up @@ -111,6 +113,58 @@ pub trait Computation {
math::Abs(self)
}

// `math::trig`

fn sin(self) -> math::Sin<Self>
where
Self: Sized,
math::Sin<Self>: Computation,
{
math::Sin(self)
}

fn cos(self) -> math::Cos<Self>
where
Self: Sized,
math::Cos<Self>: Computation,
{
math::Cos(self)
}

fn tan(self) -> math::Tan<Self>
where
Self: Sized,
math::Tan<Self>: Computation,
{
math::Tan(self)
}

fn asin(self) -> math::Asin<Self>
where
Self: Sized,
math::Asin<Self>: Computation,
{
math::Asin(self)
}

fn acos(self) -> math::Acos<Self>
where
Self: Sized,
math::Acos<Self>: Computation,
{
math::Acos(self)
}

fn atan(self) -> math::Atan<Self>
where
Self: Sized,
math::Atan<Self>: Computation,
{
math::Atan(self)
}

// `cmp`

fn eq<Rhs>(self, rhs: Rhs) -> cmp::Eq<Self, Rhs>
where
Self: Sized,
Expand Down Expand Up @@ -175,6 +229,8 @@ pub trait Computation {
cmp::Not(self)
}

// `enumerate`

fn enumerate<F>(self, f: F) -> enumerate::Enumerate<Self, F>
where
Self: Sized,
Expand All @@ -183,6 +239,8 @@ pub trait Computation {
enumerate::Enumerate { child: self, f }
}

// `sum`

fn sum(self) -> sum::Sum<Self>
where
Self: Sized,
Expand All @@ -191,6 +249,8 @@ pub trait Computation {
sum::Sum(self)
}

// `zip`

fn zip<Rhs>(self, rhs: Rhs) -> zip::Zip<Self, Rhs>
where
Self: Sized,
Expand Down
129 changes: 91 additions & 38 deletions optimal-compute-core/src/math.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use core::{fmt, ops};

use num_traits::Signed;
use paste::paste;

use crate::{impl_core_ops, impl_display_for_inline_binary, Computation, ComputationFn};

pub use self::same_or_zero::*;
pub use self::{same_or_zero::*, trig::*};

mod same_or_zero {
use crate::peano::{Suc, Zero};
Expand Down Expand Up @@ -36,21 +35,24 @@ macro_rules! impl_binary_op {
impl_binary_op!($op, ops);
};
( $op:ident, $package:ident ) => {
impl_binary_op!($op, $package, $op);
};
( $op:ident, $package:ident, $bound:ident ) => {
paste! {
#[derive(Clone, Copy, Debug)]
pub struct $op<A, B>(pub A, pub B)
where
Self: Computation;

impl<A, B> Computation for $op<A, B>
impl<A, B, ADim, AItem> Computation for $op<A, B>
where
A: Computation,
A: Computation<Dim = ADim, Item = AItem>,
B: Computation,
A::Dim: SameOrZero<B::Dim>,
A::Item: $package::$op<B::Item>,
ADim: SameOrZero<B::Dim>,
AItem: $package::$bound<B::Item>,
{
type Dim = <A::Dim as SameOrZero<B::Dim>>::Max;
type Item = <A::Item as $package::$op<B::Item>>::Output;
type Dim = ADim::Max;
type Item = AItem::Output;
}

impl<A, B> ComputationFn for $op<A, B>
Expand All @@ -74,20 +76,26 @@ macro_rules! impl_unary_op {
impl_unary_op!($op, ops);
};
( $op:ident, $package:ident ) => {
impl_unary_op!($op, $package, $op);
};
( $op:ident, $package:ident, $bound:ident ) => {
impl_unary_op!($op, $package, $bound, Item::Output);
};
( $op:ident, $package:ident, $bound:ident, Item $( :: $Output:ident )? ) => {
paste! {
#[derive(Clone, Copy, Debug)]
pub struct $op<A>(pub A)
where
Self: Computation;


impl<A> Computation for $op<A>
impl<A, Item> Computation for $op<A>
where
A: Computation,
A::Item: $package::$op,
A: Computation<Item = Item>,
Item: $package::$bound,
{
type Dim = A::Dim;
type Item = <A::Item as $package::$op>::Output;
type Item = Item $( ::$Output )?;
}

impl<A> ComputationFn for $op<A>
Expand All @@ -111,6 +119,7 @@ impl_binary_op!(Mul);
impl_binary_op!(Div);
impl_binary_op!(Pow, num_traits);
impl_unary_op!(Neg);
impl_unary_op!(Abs, num_traits, Signed, Item);

impl_display_for_inline_binary!(Add, "+");
impl_display_for_inline_binary!(Sub, "-");
Expand All @@ -128,32 +137,6 @@ where
}
}

#[derive(Clone, Copy, Debug)]
pub struct Abs<A>(pub A)
where
Self: Computation;

impl<A> Computation for Abs<A>
where
A: Computation,
A::Item: Signed,
{
type Dim = A::Dim;
type Item = A::Item;
}

impl<A> ComputationFn for Abs<A>
where
Self: Computation,
A: ComputationFn,
{
fn args(&self) -> crate::Args {
self.0.args()
}
}

impl_core_ops!(Abs<A>);

impl<A> fmt::Display for Abs<A>
where
Self: Computation,
Expand All @@ -164,6 +147,42 @@ where
}
}

mod trig {
use num_traits::real;

use super::*;

impl_unary_op!(Sin, real, Real, Item);
impl_unary_op!(Cos, real, Real, Item);
impl_unary_op!(Tan, real, Real, Item);
impl_unary_op!(Asin, real, Real, Item);
impl_unary_op!(Acos, real, Real, Item);
impl_unary_op!(Atan, real, Real, Item);

macro_rules! impl_display {
( $op:ident ) => {
paste::paste! {
impl<A> fmt::Display for $op<A>
where
Self: Computation,
A: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}.{}()", self.0, stringify!([<$op:lower>]))
}
}
}
};
}

impl_display!(Sin);
impl_display!(Cos);
impl_display!(Tan);
impl_display!(Asin);
impl_display!(Acos);
impl_display!(Atan);
}

#[cfg(test)]
mod tests {
use proptest::prelude::*;
Expand Down Expand Up @@ -217,4 +236,38 @@ mod tests {
fn abs_should_display(x: i32) {
prop_assert_eq!(val!(x).abs().to_string(), format!("{}.abs()", val!(x)));
}

mod trig {
use super::*;

#[proptest]
fn sin_should_display(x: f32) {
prop_assert_eq!(val!(x).sin().to_string(), format!("{}.sin()", val!(x)));
}

#[proptest]
fn cos_should_display(x: f32) {
prop_assert_eq!(val!(x).cos().to_string(), format!("{}.cos()", val!(x)));
}

#[proptest]
fn tan_should_display(x: f32) {
prop_assert_eq!(val!(x).tan().to_string(), format!("{}.tan()", val!(x)));
}

#[proptest]
fn asin_should_display(x: f32) {
prop_assert_eq!(val!(x).asin().to_string(), format!("{}.asin()", val!(x)));
}

#[proptest]
fn acos_should_display(x: f32) {
prop_assert_eq!(val!(x).acos().to_string(), format!("{}.acos()", val!(x)));
}

#[proptest]
fn atan_should_display(x: f32) {
prop_assert_eq!(val!(x).atan().to_string(), format!("{}.atan()", val!(x)));
}
}
}
Loading

0 comments on commit 4484b9d

Please sign in to comment.