Skip to content

Commit

Permalink
Change return types
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurenceWarne committed Mar 20, 2024
1 parent effe0b3 commit 55846fe
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 72 deletions.
92 changes: 46 additions & 46 deletions core/src/main/scala/spire/math/Interval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ sealed abstract class Interval[A] extends Serializable { lhs =>
case Below(upper, uf) =>
List(Above(upper, upperFlagToLower(reverseUpperFlag(uf))))
case Point(p) =>
List(Interval.below(p), Interval.above(p))
List(Interval.below[A](p), Interval.above[A](p))
case Bounded(lower, upper, flags) =>
val lx = lowerFlagToUpper(reverseLowerFlag(lowerFlag(flags)))
val ux = upperFlagToLower(reverseUpperFlag(upperFlag(flags)))
Expand All @@ -251,7 +251,7 @@ sealed abstract class Interval[A] extends Serializable { lhs =>
}

def split(t: A)(implicit o: Order[A]): (Interval[A], Interval[A]) =
(this.intersect(Interval.below(t)), this.intersect(Interval.above(t)))
(this.intersect(Interval.below[A](t)), this.intersect(Interval.above[A](t)))

def splitAtZero(implicit o: Order[A], ev: AdditiveMonoid[A]): (Interval[A], Interval[A]) =
split(ev.zero)
Expand Down Expand Up @@ -295,7 +295,7 @@ sealed abstract class Interval[A] extends Serializable { lhs =>
else if (upper > x) Bounded(m.zero, upper, upperFlag(fs))
else Bounded(m.zero, x, lowerFlagToUpper(fs) & upperFlag(fs))
case _ => // Above or Below
Interval.atOrAbove(m.zero)
Interval.atOrAbove[A](m.zero)
}
} else if (hasBelow(m.zero)) {
-this
Expand Down Expand Up @@ -565,7 +565,7 @@ sealed abstract class Interval[A] extends Serializable { lhs =>
if (k < 0) {
throw new IllegalArgumentException(s"negative exponent: $k")
} else if (k == 0) {
Interval.point(r.one)
Interval.point[A](r.one)
} else if (k == 1) {
this
} else if ((k & 1) == 0) {
Expand Down Expand Up @@ -632,7 +632,7 @@ sealed abstract class Interval[A] extends Serializable { lhs =>
* result = { p(x) | x ∈ interval }
*/
def translate(p: Polynomial[A])(implicit o: Order[A], ev: Field[A]): Interval[A] = {
val terms2 = p.terms.map { case Term(c, e) => Term(Interval.point(c), e) }
val terms2 = p.terms.map { case Term(c, e) => Term(Interval.point[A](c): Interval[A], e) }
val p2 = Polynomial(terms2)
p2(this)
}
Expand Down Expand Up @@ -798,7 +798,7 @@ sealed abstract class Interval[A] extends Serializable { lhs =>
def overlap(rhs: Interval[A])(implicit o: Order[A]): Overlap[A] = Overlap(lhs, rhs)
}

case class All[A]() extends Interval[A] {
case class All[A] private[spire] () extends Interval[A] {
def lowerBound: Unbound[A] = Unbound()
def upperBound: Unbound[A] = Unbound()
}
Expand All @@ -808,33 +808,23 @@ case class Above[A] private[spire] (lower: A, flags: Int) extends Interval[A] {
def upperBound: Unbound[A] = Unbound()
}

object Above {
def above[A: Order](a: A): Above[A] = Above(a, 1)
def atOrAbove[A: Order](a: A): Above[A] = Above(a, 0)
}

case class Below[A] private[spire] (upper: A, flags: Int) extends Interval[A] {
def lowerBound: Unbound[A] = Unbound()
def upperBound: ValueBound[A] = if (isOpenUpper(flags)) Open(upper) else Closed(upper)
}

object Below {
def below[A: Order](a: A): Below[A] = Below(a, 2)
def atOrBelow[A: Order](a: A): Below[A] = Below(a, 0)
}

// Bounded, non-empty interval with lower < upper
case class Bounded[A] private[spire] (lower: A, upper: A, flags: Int) extends Interval[A] {
def lowerBound: ValueBound[A] = if (isOpenLower(flags)) Open(lower) else Closed(lower)
def upperBound: ValueBound[A] = if (isOpenUpper(flags)) Open(upper) else Closed(upper)
}

case class Point[A](value: A) extends Interval[A] {
case class Point[A] private[spire] (value: A) extends Interval[A] {
def lowerBound: Closed[A] = Closed(value)
def upperBound: Closed[A] = Closed(value)
}

case class Empty[A]() extends Interval[A] {
case class Empty[A] private[spire] () extends Interval[A] {
def lowerBound: EmptyBound[A] = EmptyBound()
def upperBound: EmptyBound[A] = EmptyBound()
}
Expand All @@ -850,13 +840,23 @@ object Interval {
else
Interval.empty[A]

def empty[A: Order]: Interval[A] = Empty[A]()
// Old methods with a return type of Interval, kept for binary compatibility
private[Interval] def empty[A: Order, B]: Interval[A] = empty[A]
private[Interval] def point[A: Order, B](a: A): Interval[A] = point[A](a)
private[Interval] def zero[A: Order, B](implicit r: Semiring[A]): Interval[A] = zero[A]
private[Interval] def all[A: Order, B]: Interval[A] = all[A]
private[Interval] def above[A: Order, B](a: A): Interval[A] = above[A](a)
private[Interval] def below[A: Order, B](a: A): Interval[A] = below[A](a)
private[Interval] def atOrAbove[A: Order, B](a: A): Interval[A] = atOrAbove[A](a)
private[Interval] def atOrBelow[A: Order, B](a: A): Interval[A] = atOrBelow[A](a)

def empty[A: Order]: Empty[A] = Empty[A]()

def point[A: Order](a: A): Interval[A] = Point(a)
def point[A: Order](a: A): Point[A] = Point(a)

def zero[A: Order](implicit r: Semiring[A]): Interval[A] = Point(r.zero)
def zero[A: Order](implicit r: Semiring[A]): Point[A] = Point(r.zero)

def all[A: Order]: Interval[A] = All[A]()
def all[A: Order]: All[A] = All[A]()

def apply[A: Order](lower: A, upper: A): Interval[A] = closed(lower, upper)

Expand All @@ -877,9 +877,9 @@ object Interval {
*/
def errorBounds(d: Double): Interval[Rational] =
if (d == Double.PositiveInfinity) {
Interval.above(Double.MaxValue)
Interval.above[Rational](Double.MaxValue)
} else if (d == Double.NegativeInfinity) {
Interval.below(Double.MinValue)
Interval.below[Rational](Double.MinValue)
} else if (isNaN(d)) {
Interval.empty[Rational]
} else {
Expand Down Expand Up @@ -912,32 +912,32 @@ object Interval {
*/
private[spire] def fromOrderedBounds[A: Order](lower: Bound[A], upper: Bound[A]): Interval[A] =
(lower, upper) match {
case (EmptyBound(), EmptyBound()) => empty
case (EmptyBound(), EmptyBound()) => empty[A]
case (Closed(x), Closed(y)) => Bounded(x, y, closedLowerFlags | closedUpperFlags)
case (Open(x), Open(y)) => Bounded(x, y, openLowerFlags | openUpperFlags)
case (Unbound(), Open(y)) => below(y)
case (Open(x), Unbound()) => above(x)
case (Unbound(), Closed(y)) => atOrBelow(y)
case (Closed(x), Unbound()) => atOrAbove(x)
case (Unbound(), Open(y)) => below[A](y)
case (Open(x), Unbound()) => above[A](x)
case (Unbound(), Closed(y)) => atOrBelow[A](y)
case (Closed(x), Unbound()) => atOrAbove[A](x)
case (Closed(x), Open(y)) => Bounded(x, y, closedLowerFlags | openUpperFlags)
case (Open(x), Closed(y)) => Bounded(x, y, openLowerFlags | closedUpperFlags)
case (Unbound(), Unbound()) => all
case (Unbound(), Unbound()) => all[A]
case (EmptyBound(), _) | (_, EmptyBound()) =>
throw new IllegalArgumentException("invalid empty bound")
}

def fromBounds[A: Order](lower: Bound[A], upper: Bound[A]): Interval[A] =
(lower, upper) match {
case (EmptyBound(), EmptyBound()) => empty
case (EmptyBound(), EmptyBound()) => empty[A]
case (Closed(x), Closed(y)) => closed(x, y)
case (Open(x), Open(y)) => open(x, y)
case (Unbound(), Open(y)) => below(y)
case (Open(x), Unbound()) => above(x)
case (Unbound(), Closed(y)) => atOrBelow(y)
case (Closed(x), Unbound()) => atOrAbove(x)
case (Unbound(), Open(y)) => below[A](y)
case (Open(x), Unbound()) => above[A](x)
case (Unbound(), Closed(y)) => atOrBelow[A](y)
case (Closed(x), Unbound()) => atOrAbove[A](x)
case (Closed(x), Open(y)) => openUpper(x, y)
case (Open(x), Closed(y)) => openLower(x, y)
case (Unbound(), Unbound()) => all
case (Unbound(), Unbound()) => all[A]
case (EmptyBound(), _) | (_, EmptyBound()) =>
throw new IllegalArgumentException("invalid empty bound")
}
Expand All @@ -954,10 +954,10 @@ object Interval {
if (lower < upper) Bounded(lower, upper, 1) else Interval.empty[A]
def openUpper[A: Order](lower: A, upper: A): Interval[A] =
if (lower < upper) Bounded(lower, upper, 2) else Interval.empty[A]
def above[A: Order](a: A): Interval[A] = Above.above(a)
def below[A: Order](a: A): Interval[A] = Below.below(a)
def atOrAbove[A: Order](a: A): Interval[A] = Above.atOrAbove(a)
def atOrBelow[A: Order](a: A): Interval[A] = Below.atOrBelow(a)
def above[A: Order](a: A): Above[A] = Above(a, 1)
def below[A: Order](a: A): Below[A] = Below(a, 2)
def atOrAbove[A: Order](a: A): Above[A] = Above(a, 0)
def atOrBelow[A: Order](a: A): Below[A] = Below(a, 0)

private val NullRe = "^ *\\( *Ø *\\) *$".r
private val SingleRe = "^ *\\[ *([^,]+) *\\] *$".r
Expand All @@ -966,14 +966,14 @@ object Interval {
def apply(s: String): Interval[Rational] =
s match {
case NullRe() => Interval.empty[Rational]
case SingleRe(x) => Interval.point(Rational(x))
case SingleRe(x) => Interval.point[Rational](Rational(x))
case PairRe(left, x, y, right) =>
(left, x, y, right) match {
case ("(", "-∞", "", ")") => Interval.all[Rational]
case ("(", "-∞", y, ")") => Interval.below(Rational(y))
case ("(", "-∞", y, "]") => Interval.atOrBelow(Rational(y))
case ("(", x, "", ")") => Interval.above(Rational(x))
case ("[", x, "", ")") => Interval.atOrAbove(Rational(x))
case ("(", "-∞", y, ")") => Interval.below[Rational](Rational(y))
case ("(", "-∞", y, "]") => Interval.atOrBelow[Rational](Rational(y))
case ("(", x, "", ")") => Interval.above[Rational](Rational(x))
case ("[", x, "", ")") => Interval.atOrAbove[Rational](Rational(x))
case ("[", x, y, "]") => Interval.closed(Rational(x), Rational(y))
case ("(", x, y, ")") => Interval.open(Rational(x), Rational(y))
case ("[", x, y, ")") => Interval.openUpper(Rational(x), Rational(y))
Expand All @@ -990,7 +990,7 @@ object Interval {

implicit def semiring[A](implicit ev: Ring[A], o: Order[A]): Semiring[Interval[A]] =
new Semiring[Interval[A]] {
def zero: Interval[A] = Interval.point(ev.zero)
def zero: Interval[A] = Interval.point[A](ev.zero)
def plus(x: Interval[A], y: Interval[A]): Interval[A] = x + y
def times(x: Interval[A], y: Interval[A]): Interval[A] = x * y
override def pow(x: Interval[A], k: Int): Interval[A] = x.pow(k)
Expand Down
20 changes: 6 additions & 14 deletions docs/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -666,20 +666,12 @@ type classes (ranging from `AdditiveSemigroup[A]` for `+` to
Intervals may be unbounded on either side, and bounds can be open or
closed. (An interval includes closed boundaries, but not open
boundaries). Here are some string representations of various
intervals, and how to create their equivalents in `spire`:

* `[3, 6]` the set of values between 3 and 6 (including both) - `Interval.fromBounds(Closed(3), Closed(6))`
* `(2, 4)` the set of values between 2 and 4 (excluding both) - `Interval.fromBounds(Open(2), Open(4))`
* `[1, 2)` half-open set, including 1 but not 2 - `Interval.fromBounds(Closed(1), Open(2))`
* `(-∞, 5)` the set of values less than 5 - `Below.below(5)`
* `[1, ∞]` the set of values greater than or equal to 1 - `Above.atOrAbove(1)`
* `Ø` the empty set - `Empty()`
* `(-∞, ∞)` the set of all values - `All()`

Note that `Interval` may be used in place of `Below`/`Above`/
`Empty`/`All`, though these methods share a return value of the
`Interval` supertype from which lower/upper bound values which are
known can't be extracted.
intervals:

* `[3, 6]` the set of values between 3 and 6 (including both).
* `(2, 4)` the set of values between 2 and 4 (excluding both).
* `[1, 2)` half-open set, including 1 but not 2.
* `(-∞, 5)` the set of values less than 5.

Intervals model continuous spaces, even if the type A is discrete. So
for instance when `(3, 4)` is an `Interval[Int]` it is not considered
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class IntervalGeometricPartialOrderSuite extends munit.FunSuite {
test("[2, 3] cannot be compared to empty") { assert(closed(2, 3).partialCompare(open(2, 2)).isNaN) }
test("Minimal and maximal elements of {[1], [2, 3], [2, 4]}") {
val intervals = Seq(point(1), closed(2, 3), closed(2, 4))
assertEquals(intervals.pmin.toSet, Set(point(1)))
assertEquals(intervals.pmax.toSet, Set(closed(2, 3), closed(2, 4)))
assertEquals(intervals.pmin.toSet, Set[Interval[Int]](point(1)))
assertEquals(intervals.pmax.toSet, Set[Interval[Int]](closed(2, 3), closed(2, 4)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ class IntervalScalaCheckSuite extends munit.ScalaCheckSuite {

import spire.algebra.{Order, PartialOrder}
forAll { (x: Rational, y: Rational) =>
val a = Interval.point(x)
val b = Interval.point(y)
val a: Interval[Rational] = Interval.point(x)
val b: Interval[Rational] = Interval.point(y)
val order = PartialOrder[Interval[Rational]].tryCompare(a, b).get == Order[Rational].compare(x, y)
val min = a.pmin(b) match {
case Some(Point(vmin)) => vmin == x.min(y)
Expand All @@ -183,8 +183,8 @@ class IntervalScalaCheckSuite extends munit.ScalaCheckSuite {
forAll { (a: Rational, w: Positive[Rational]) =>
val b = a + w.num
// a < b
val i = Interval.atOrBelow(a)
val j = Interval.atOrAbove(b)
val i: Interval[Rational] = Interval.atOrBelow(a)
val j: Interval[Rational] = Interval.atOrAbove(b)
(i < j) &&
!(i >= j) &&
(j > i) &&
Expand All @@ -197,8 +197,8 @@ class IntervalScalaCheckSuite extends munit.ScalaCheckSuite {
forAll { (a: Rational, w: NonNegative[Rational]) =>
val b = a - w.num
// a >= b
val i = Interval.atOrBelow(a)
val j = Interval.atOrAbove(b)
val i: Interval[Rational] = Interval.atOrBelow(a)
val j: Interval[Rational] = Interval.atOrAbove(b)
i.partialCompare(j).isNaN &&
j.partialCompare(i).isNaN
}
Expand All @@ -207,8 +207,8 @@ class IntervalScalaCheckSuite extends munit.ScalaCheckSuite {
property("(-inf, inf) does not compare with [a, b]") {
import spire.optional.intervalGeometricPartialOrder._
forAll { (a: Rational, b: Rational) =>
val i = Interval.all[Rational]
val j = Interval.closed(a, b)
val i: Interval[Rational] = Interval.all[Rational]
val j: Interval[Rational] = Interval.closed(a, b)
i.partialCompare(j).isNaN &&
j.partialCompare(i).isNaN
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class IntervalSubsetPartialOrderSuite extends munit.FunSuite {

test("Minimal and maximal elements of {[1, 3], [3], [2], [1]} by subset partial order") {
val intervals = Seq(closed(1, 3), point(3), point(2), point(1))
assertEquals(intervals.pmin.toSet, Set(point(1), point(2), point(3)))
assertEquals(intervals.pmax.toSet, Set(closed(1, 3)))
assertEquals(intervals.pmin.toSet, Set(point(1), point(2), point(3)): Set[Interval[Int]])
assertEquals(intervals.pmax.toSet, Set(closed(1, 3)): Set[Interval[Int]])
}
}

0 comments on commit 55846fe

Please sign in to comment.