Skip to content
This repository has been archived by the owner on Feb 8, 2022. It is now read-only.

Add Signed and TruncatedDivision typeclasses #247

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions core/src/main/scala/algebra/instances/bigInt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@ package algebra
package instances

import algebra.ring._
import cats.kernel.instances.BigIntOrder
import cats.kernel.{Hash, UnboundedEnumerable}

package object bigInt extends BigIntInstances

trait BigIntInstances extends cats.kernel.instances.BigIntInstances {
implicit val bigIntAlgebra: BigIntAlgebra =
private val instance: TruncatedDivision[BigInt] with CommutativeRing[BigInt] =
new BigIntAlgebra

implicit def bigIntAlgebra: CommutativeRing[BigInt] = instance

implicit def bigIntTruncatedDivision: TruncatedDivision[BigInt] = instance
}

class BigIntAlgebra extends CommutativeRing[BigInt] with Serializable {
class BigIntAlgebra extends CommutativeRing[BigInt] with TruncatedDivision.forCommutativeRing[BigInt] with Serializable {

override def compare(x: BigInt, y: BigInt): Int = x.compare(y)

val zero: BigInt = BigInt(0)
val one: BigInt = BigInt(1)
Expand All @@ -25,4 +33,8 @@ class BigIntAlgebra extends CommutativeRing[BigInt] with Serializable {

override def fromInt(n: Int): BigInt = BigInt(n)
override def fromBigInt(n: BigInt): BigInt = n

def tquot(x: BigInt, y: BigInt): BigInt = x / y
def tmod(x: BigInt, y: BigInt): BigInt = x % y
override def tquotmod(x: BigInt, y: BigInt): (BigInt, BigInt) = x /% y
}
141 changes: 141 additions & 0 deletions core/src/main/scala/algebra/ring/Signed.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package algebra
package ring

import scala.{specialized => sp}

/**
* A trait that expresses the existence of signs and absolute values on linearly ordered additive commutative monoids
* (i.e. types with addition and a zero).
*
* The following laws holds:
*
* (1) if `a <= b` then `a + c <= b + c` (linear order),
* (2) `signum(x) = -1` if `x < 0`, `signum(x) = 1` if `x > 0`, `signum(x) = 0` otherwise,
*
* Negative elements only appear when the scalar is taken from a additive abelian group. Then:
*
* (3) `abs(x) = -x` if `x < 0`, or `x` otherwise,
*
* Laws (1) and (2) lead to the triange inequality:
*
* (4) `abs(a + b) <= abs(a) + abs(b)`
*
* Signed should never be extended in implementations, rather the [[Signed.forAdditiveCommutativeMonoid]] and
* [[Signed.forAdditiveCommutativeGroup subtraits]].
*
* It's better to have the Eq/PartialOrder/Order/Signed hierarchy separate from the Ring hierarchy, so that
* we do not end up with duplicate implicits. At the same time, we cannot use self-types to express
* the constraint that Signed must be an [[AdditiveCommutativeMonoid]], due to interaction with specialization.
*/
trait Signed[@sp(Byte, Short, Int, Long, Float, Double) A] extends Any with Order[A] {

/**
* Returns Zero if `a` is 0, Positive if `a` is positive, and Negative is `a` is negative.
*/
def sign(a: A): Signed.Sign = Signed.Sign(signum(a))

/**
* Returns 0 if `a` is 0, 1 if `a` is positive, and -1 is `a` is negative.
*/
def signum(a: A): Int

/**
* An idempotent function that ensures an object has a non-negative sign.
*/
def abs(a: A): A

def isSignZero(a: A): Boolean = signum(a) == 0
def isSignPositive(a: A): Boolean = signum(a) > 0
def isSignNegative(a: A): Boolean = signum(a) < 0

def isSignNonZero(a: A): Boolean = signum(a) != 0
def isSignNonPositive(a: A): Boolean = signum(a) <= 0
def isSignNonNegative(a: A): Boolean = signum(a) >= 0
}

trait SignedFunctions[S[T] <: Signed[T]] extends cats.kernel.OrderFunctions[S] {
def sign[@sp(Int, Long, Float, Double) A](a: A)(implicit ev: S[A]): Signed.Sign =
ev.sign(a)
def signum[@sp(Int, Long, Float, Double) A](a: A)(implicit ev: S[A]): Int =
ev.signum(a)
def abs[@sp(Int, Long, Float, Double) A](a: A)(implicit ev: S[A]): A =
ev.abs(a)
def isSignZero[@sp(Int, Long, Float, Double) A](a: A)(implicit ev: S[A]): Boolean =
ev.isSignZero(a)
def isSignPositive[@sp(Int, Long, Float, Double) A](a: A)(implicit ev: S[A]): Boolean =
ev.isSignPositive(a)
def isSignNegative[@sp(Int, Long, Float, Double) A](a: A)(implicit ev: S[A]): Boolean =
ev.isSignNegative(a)
def isSignNonZero[@sp(Int, Long, Float, Double) A](a: A)(implicit ev: S[A]): Boolean =
ev.isSignNonZero(a)
def isSignNonPositive[@sp(Int, Long, Float, Double) A](a: A)(implicit ev: S[A]): Boolean =
ev.isSignNonPositive(a)
def isSignNonNegative[@sp(Int, Long, Float, Double) A](a: A)(implicit ev: S[A]): Boolean =
ev.isSignNonNegative(a)
}

object Signed extends SignedFunctions[Signed] {

/** Signed implementation for additive commutative monoids */
trait forAdditiveCommutativeMonoid[A] extends Any with Signed[A] with AdditiveCommutativeMonoid[A] {
def signum(a: A): Int = {
val c = compare(a, zero)
if (c < 0) -1
else if (c > 0) 1
else 0
}
}

/** Signed implementation for additive commutative groups */
trait forAdditiveCommutativeGroup[A] extends Any with forAdditiveCommutativeMonoid[A] with AdditiveCommutativeGroup[A] {
def abs(a: A): A = if (compare(a, zero) < 0) negate(a) else a
}

def apply[A](implicit s: Signed[A]): Signed[A] = s

/**
* A simple ADT representing the `Sign` of an object.
*/
sealed abstract class Sign(val toInt: Int) {
def unary_- : Sign = this match {
case Positive => Negative
case Negative => Positive
case Zero => Zero
}

def *(that: Sign): Sign = Sign(this.toInt * that.toInt)

def **(that: Int): Sign = this match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment that this is exponentiation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure!

case Positive => Positive
case Zero if that == 0 => Positive
case Zero => Zero
case Negative if (that % 2) == 0 => Positive
case Negative => Negative
}
}

case object Zero extends Sign(0)
case object Positive extends Sign(1)
case object Negative extends Sign(-1)

object Sign {
implicit def sign2int(s: Sign): Int = s.toInt

def apply(i: Int): Sign =
if (i == 0) Zero else if (i > 0) Positive else Negative

private val instance: CommutativeMonoid[Sign] with MultiplicativeCommutativeMonoid[Sign] with Eq[Sign] =
new CommutativeMonoid[Sign] with MultiplicativeCommutativeMonoid[Sign] with Eq[Sign] {
def eqv(x: Sign, y: Sign): Boolean = x == y
def empty: Sign = Positive
def combine(x: Sign, y: Sign): Sign = x*y
def one: Sign = Positive
def times(x: Sign, y: Sign): Sign = x*y
}

implicit final def signMultiplicativeMonoid: MultiplicativeCommutativeMonoid[Sign] = instance
implicit final def signMonoid: CommutativeMonoid[Sign] = instance
implicit final def signEq: Eq[Sign] = instance
}

}
87 changes: 87 additions & 0 deletions core/src/main/scala/algebra/ring/TruncatedDivision.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package algebra
package ring

import scala.{specialized => sp}

/**
* Division and modulus for computer scientists
* taken from https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/divmodnote-letter.pdf
*
* For two numbers x (dividend) and y (divisor) on an ordered ring with y != 0,
* there exists a pair of numbers q (quotient) and r (remainder)
* such that these laws are satisfied:
*
* (1) q is an integer
* (2) x = y * q + r (division rule)
* (3) |r| < |y|,
* (4t) r = 0 or sign(r) = sign(x),
* (4f) r = 0 or sign(r) = sign(y).
*
* where sign is the sign function, and the absolute value
* function |x| is defined as |x| = x if x >=0, and |x| = -x otherwise.
*
* We define functions tmod and tquot such that:
* q = tquot(x, y) and r = tmod(x, y) obey rule (4t),
* (which truncates effectively towards zero)
* and functions fmod and fquot such that:
* q = fquot(x, y) and r = fmod(x, y) obey rule (4f)
* (which floors the quotient and effectively rounds towards negative infinity).
*
* Law (4t) corresponds to ISO C99 and Haskell's quot/rem.
* Law (4f) is described by Knuth and used by Haskell,
* and fmod corresponds to the REM function of the IEEE floating-point standard.
*/
trait TruncatedDivision[@sp(Byte, Short, Int, Long, Float, Double) A] extends Any with Signed[A] {
def tquot(x: A, y: A): A
def tmod(x: A, y: A): A
def tquotmod(x: A, y: A): (A, A) = (tquot(x, y), tmod(x, y))

def fquot(x: A, y: A): A
def fmod(x: A, y: A): A
def fquotmod(x: A, y: A): (A, A) = (fquot(x, y), fmod(x, y))
}

trait TruncatedDivisionFunctions[S[T] <: TruncatedDivision[T]] extends SignedFunctions[S] {
def tquot[@sp(Int, Long, Float, Double) A](x: A, y: A)(implicit ev: TruncatedDivision[A]): A =
ev.tquot(x, y)
def tmod[@sp(Int, Long, Float, Double) A](x: A, y: A)(implicit ev: TruncatedDivision[A]): A =
ev.tmod(x, y)
def tquotmod[@sp(Int, Long, Float, Double) A](x: A, y: A)(implicit ev: TruncatedDivision[A]): (A, A) =
ev.tquotmod(x, y)
def fquot[@sp(Int, Long, Float, Double) A](x: A, y: A)(implicit ev: TruncatedDivision[A]): A =
ev.fquot(x, y)
def fmod[@sp(Int, Long, Float, Double) A](x: A, y: A)(implicit ev: TruncatedDivision[A]): A =
ev.fmod(x, y)
def fquotmod[@sp(Int, Long, Float, Double) A](x: A, y: A)(implicit ev: TruncatedDivision[A]): (A, A) =
ev.fquotmod(x, y)
}

object TruncatedDivision extends TruncatedDivisionFunctions[TruncatedDivision] {
trait forCommutativeRing[@sp(Byte, Short, Int, Long, Float, Double) A]
extends Any
with TruncatedDivision[A]
with Signed.forAdditiveCommutativeGroup[A]
with CommutativeRing[A] { self =>

def fmod(x: A, y: A): A = {
val tm = tmod(x, y)
if (signum(tm) == -signum(y)) plus(tm, y) else tm
}

def fquot(x: A, y: A): A = {
val (tq, tm) = tquotmod(x, y)
if (signum(tm) == -signum(y)) minus(tq, one) else tq
}

override def fquotmod(x: A, y: A): (A, A) = {
val (tq, tm) = tquotmod(x, y)
val signsDiffer = signum(tm) == -signum(y)
val fq = if (signsDiffer) minus(tq, one) else tq
val fm = if (signsDiffer) plus(tm, y) else tm
(fq, fm)
}

}

def apply[A](implicit ev: TruncatedDivision[A]): TruncatedDivision[A] = ev
}
71 changes: 71 additions & 0 deletions laws/shared/src/main/scala/algebra/laws/OrderLaws.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import org.scalacheck.Prop._

import cats.kernel.instances.all._

import algebra.ring.{CommutativeRing, Signed, TruncatedDivision}

object OrderLaws {
def apply[A: Eq: Arbitrary: Cogen]: OrderLaws[A] =
new OrderLaws[A] {
Expand Down Expand Up @@ -112,6 +114,75 @@ trait OrderLaws[A] extends Laws {
}
)

def signed(implicit A: Signed[A]) = new OrderProperties(
name = "signed",
parent = Some(order),
"abs non-negative" -> forAll((x: A) => A.sign(A.abs(x)) != Signed.Negative),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we require that abs(x) >= x?

Seems natural to me, but I don't see that is implied.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure!

"signum returns -1/0/1" -> forAll((x: A) => A.signum(A.abs(x)) <= 1),
"signum is sign.toInt" -> forAll((x: A) => A.signum(x) == A.sign(x).toInt)
)

def truncatedDivision(implicit ring: CommutativeRing[A], A: TruncatedDivision[A]) = new DefaultRuleSet(
name = "truncatedDivision",
parent = Some(signed),
"division rule (tquotmod)" -> forAll { (x: A, y: A) =>
A.isSignNonZero(y) ==> {
val (q, r) = A.tquotmod(x, y)
x ?== ring.plus(ring.times(y, q), r)
}
},
"division rule (fquotmod)" -> forAll { (x: A, y: A) =>
A.isSignNonZero(y) ==> {
val (q, r) = A.fquotmod(x, y)
x ?== ring.plus(ring.times(y, q), r)
}
},
"|r| < |y| (tmod)" -> forAll { (x: A, y: A) =>
A.isSignNonZero(y) ==> {
val r = A.tmod(x, y)
A.lt(A.abs(r), A.abs(y))
}
},
"|r| < |y| (fmod)" -> forAll { (x: A, y: A) =>
A.isSignNonZero(y) ==> {
val r = A.fmod(x, y)
A.lt(A.abs(r), A.abs(y))
}
},
"r = 0 or sign(r) = sign(x) (tmod)" -> forAll { (x: A, y: A) =>
A.isSignNonZero(y) ==> {
val r = A.tmod(x, y)
A.isSignZero(r) || (A.sign(r) ?== A.sign(x))
}
},
"r = 0 or sign(r) = sign(y) (fmod)" -> forAll { (x: A, y: A) =>
A.isSignNonZero(y) ==> {
val r = A.fmod(x, y)
A.isSignZero(r) || (A.sign(r) ?== A.sign(y))
}
},
"tquot" -> forAll { (x: A, y: A) =>
A.isSignNonZero(y) ==> {
A.tquotmod(x, y)._1 ?== A.tquot(x, y)
}
},
"tmod" -> forAll { (x: A, y: A) =>
A.isSignNonZero(y) ==> {
A.tquotmod(x, y)._2 ?== A.tmod(x, y)
}
},
"fquot" -> forAll { (x: A, y: A) =>
A.isSignNonZero(y) ==> {
A.fquotmod(x, y)._1 ?== A.fquot(x, y)
}
},
"fmod" -> forAll { (x: A, y: A) =>
A.isSignNonZero(y) ==> {
A.fquotmod(x, y)._2 ?== A.fmod(x, y)
}
}
)

class OrderProperties(
name: String,
parent: Option[RuleSet],
Expand Down
2 changes: 2 additions & 0 deletions laws/shared/src/test/scala/algebra/laws/LawTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ class LawTests extends munit.DisciplineSuite {
checkAll("Long", RingLaws[Long].commutativeRing)
checkAll("Long", LatticeLaws[Long].boundedDistributiveLattice)

// catsKernelStdOrderForBigInt
checkAll("BigInt", OrderLaws[BigInt].truncatedDivision)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add implementations for the specialized things?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as for the intermediate ring structures-- the instances are only approximations due to the type limited range.

checkAll("BigInt", RingLaws[BigInt].commutativeRing)

checkAll("FPApprox[Float]", RingLaws[FPApprox[Float]].field)
Expand Down