From 55846fed76bf6f1681ba77603ae668bf8bda362c Mon Sep 17 00:00:00 2001 From: Laurence Warne Date: Wed, 20 Mar 2024 13:04:29 +0000 Subject: [PATCH] Change return types --- core/src/main/scala/spire/math/Interval.scala | 92 +++++++++---------- docs/guide.md | 20 ++-- .../IntervalGeometricPartialOrderSuite.scala | 4 +- .../spire/math/IntervalScalaCheckSuite.scala | 16 ++-- .../IntervalSubsetPartialOrderSuite.scala | 4 +- 5 files changed, 64 insertions(+), 72 deletions(-) diff --git a/core/src/main/scala/spire/math/Interval.scala b/core/src/main/scala/spire/math/Interval.scala index 61dea4c65..fd9bc1575 100644 --- a/core/src/main/scala/spire/math/Interval.scala +++ b/core/src/main/scala/spire/math/Interval.scala @@ -235,7 +235,7 @@ sealed abstract class Interval[A] extends Serializable { lhs => case Below(upper, uf) => List(Above(upper, upperFlagToLower(reverseUpperFlag(uf)))) case Point(p) => - List(Interval.below(p), Interval.above(p)) + List(Interval.below[A](p), Interval.above[A](p)) case Bounded(lower, upper, flags) => val lx = lowerFlagToUpper(reverseLowerFlag(lowerFlag(flags))) val ux = upperFlagToLower(reverseUpperFlag(upperFlag(flags))) @@ -251,7 +251,7 @@ sealed abstract class Interval[A] extends Serializable { lhs => } def split(t: A)(implicit o: Order[A]): (Interval[A], Interval[A]) = - (this.intersect(Interval.below(t)), this.intersect(Interval.above(t))) + (this.intersect(Interval.below[A](t)), this.intersect(Interval.above[A](t))) def splitAtZero(implicit o: Order[A], ev: AdditiveMonoid[A]): (Interval[A], Interval[A]) = split(ev.zero) @@ -295,7 +295,7 @@ sealed abstract class Interval[A] extends Serializable { lhs => else if (upper > x) Bounded(m.zero, upper, upperFlag(fs)) else Bounded(m.zero, x, lowerFlagToUpper(fs) & upperFlag(fs)) case _ => // Above or Below - Interval.atOrAbove(m.zero) + Interval.atOrAbove[A](m.zero) } } else if (hasBelow(m.zero)) { -this @@ -565,7 +565,7 @@ sealed abstract class Interval[A] extends Serializable { lhs => if (k < 0) { throw new IllegalArgumentException(s"negative exponent: $k") } else if (k == 0) { - Interval.point(r.one) + Interval.point[A](r.one) } else if (k == 1) { this } else if ((k & 1) == 0) { @@ -632,7 +632,7 @@ sealed abstract class Interval[A] extends Serializable { lhs => * result = { p(x) | x ∈ interval } */ def translate(p: Polynomial[A])(implicit o: Order[A], ev: Field[A]): Interval[A] = { - val terms2 = p.terms.map { case Term(c, e) => Term(Interval.point(c), e) } + val terms2 = p.terms.map { case Term(c, e) => Term(Interval.point[A](c): Interval[A], e) } val p2 = Polynomial(terms2) p2(this) } @@ -798,7 +798,7 @@ sealed abstract class Interval[A] extends Serializable { lhs => def overlap(rhs: Interval[A])(implicit o: Order[A]): Overlap[A] = Overlap(lhs, rhs) } -case class All[A]() extends Interval[A] { +case class All[A] private[spire] () extends Interval[A] { def lowerBound: Unbound[A] = Unbound() def upperBound: Unbound[A] = Unbound() } @@ -808,33 +808,23 @@ case class Above[A] private[spire] (lower: A, flags: Int) extends Interval[A] { def upperBound: Unbound[A] = Unbound() } -object Above { - def above[A: Order](a: A): Above[A] = Above(a, 1) - def atOrAbove[A: Order](a: A): Above[A] = Above(a, 0) -} - case class Below[A] private[spire] (upper: A, flags: Int) extends Interval[A] { def lowerBound: Unbound[A] = Unbound() def upperBound: ValueBound[A] = if (isOpenUpper(flags)) Open(upper) else Closed(upper) } -object Below { - def below[A: Order](a: A): Below[A] = Below(a, 2) - def atOrBelow[A: Order](a: A): Below[A] = Below(a, 0) -} - // Bounded, non-empty interval with lower < upper case class Bounded[A] private[spire] (lower: A, upper: A, flags: Int) extends Interval[A] { def lowerBound: ValueBound[A] = if (isOpenLower(flags)) Open(lower) else Closed(lower) def upperBound: ValueBound[A] = if (isOpenUpper(flags)) Open(upper) else Closed(upper) } -case class Point[A](value: A) extends Interval[A] { +case class Point[A] private[spire] (value: A) extends Interval[A] { def lowerBound: Closed[A] = Closed(value) def upperBound: Closed[A] = Closed(value) } -case class Empty[A]() extends Interval[A] { +case class Empty[A] private[spire] () extends Interval[A] { def lowerBound: EmptyBound[A] = EmptyBound() def upperBound: EmptyBound[A] = EmptyBound() } @@ -850,13 +840,23 @@ object Interval { else Interval.empty[A] - def empty[A: Order]: Interval[A] = Empty[A]() + // Old methods with a return type of Interval, kept for binary compatibility + private[Interval] def empty[A: Order, B]: Interval[A] = empty[A] + private[Interval] def point[A: Order, B](a: A): Interval[A] = point[A](a) + private[Interval] def zero[A: Order, B](implicit r: Semiring[A]): Interval[A] = zero[A] + private[Interval] def all[A: Order, B]: Interval[A] = all[A] + private[Interval] def above[A: Order, B](a: A): Interval[A] = above[A](a) + private[Interval] def below[A: Order, B](a: A): Interval[A] = below[A](a) + private[Interval] def atOrAbove[A: Order, B](a: A): Interval[A] = atOrAbove[A](a) + private[Interval] def atOrBelow[A: Order, B](a: A): Interval[A] = atOrBelow[A](a) + + def empty[A: Order]: Empty[A] = Empty[A]() - def point[A: Order](a: A): Interval[A] = Point(a) + def point[A: Order](a: A): Point[A] = Point(a) - def zero[A: Order](implicit r: Semiring[A]): Interval[A] = Point(r.zero) + def zero[A: Order](implicit r: Semiring[A]): Point[A] = Point(r.zero) - def all[A: Order]: Interval[A] = All[A]() + def all[A: Order]: All[A] = All[A]() def apply[A: Order](lower: A, upper: A): Interval[A] = closed(lower, upper) @@ -877,9 +877,9 @@ object Interval { */ def errorBounds(d: Double): Interval[Rational] = if (d == Double.PositiveInfinity) { - Interval.above(Double.MaxValue) + Interval.above[Rational](Double.MaxValue) } else if (d == Double.NegativeInfinity) { - Interval.below(Double.MinValue) + Interval.below[Rational](Double.MinValue) } else if (isNaN(d)) { Interval.empty[Rational] } else { @@ -912,32 +912,32 @@ object Interval { */ private[spire] def fromOrderedBounds[A: Order](lower: Bound[A], upper: Bound[A]): Interval[A] = (lower, upper) match { - case (EmptyBound(), EmptyBound()) => empty + case (EmptyBound(), EmptyBound()) => empty[A] case (Closed(x), Closed(y)) => Bounded(x, y, closedLowerFlags | closedUpperFlags) case (Open(x), Open(y)) => Bounded(x, y, openLowerFlags | openUpperFlags) - case (Unbound(), Open(y)) => below(y) - case (Open(x), Unbound()) => above(x) - case (Unbound(), Closed(y)) => atOrBelow(y) - case (Closed(x), Unbound()) => atOrAbove(x) + case (Unbound(), Open(y)) => below[A](y) + case (Open(x), Unbound()) => above[A](x) + case (Unbound(), Closed(y)) => atOrBelow[A](y) + case (Closed(x), Unbound()) => atOrAbove[A](x) case (Closed(x), Open(y)) => Bounded(x, y, closedLowerFlags | openUpperFlags) case (Open(x), Closed(y)) => Bounded(x, y, openLowerFlags | closedUpperFlags) - case (Unbound(), Unbound()) => all + case (Unbound(), Unbound()) => all[A] case (EmptyBound(), _) | (_, EmptyBound()) => throw new IllegalArgumentException("invalid empty bound") } def fromBounds[A: Order](lower: Bound[A], upper: Bound[A]): Interval[A] = (lower, upper) match { - case (EmptyBound(), EmptyBound()) => empty + case (EmptyBound(), EmptyBound()) => empty[A] case (Closed(x), Closed(y)) => closed(x, y) case (Open(x), Open(y)) => open(x, y) - case (Unbound(), Open(y)) => below(y) - case (Open(x), Unbound()) => above(x) - case (Unbound(), Closed(y)) => atOrBelow(y) - case (Closed(x), Unbound()) => atOrAbove(x) + case (Unbound(), Open(y)) => below[A](y) + case (Open(x), Unbound()) => above[A](x) + case (Unbound(), Closed(y)) => atOrBelow[A](y) + case (Closed(x), Unbound()) => atOrAbove[A](x) case (Closed(x), Open(y)) => openUpper(x, y) case (Open(x), Closed(y)) => openLower(x, y) - case (Unbound(), Unbound()) => all + case (Unbound(), Unbound()) => all[A] case (EmptyBound(), _) | (_, EmptyBound()) => throw new IllegalArgumentException("invalid empty bound") } @@ -954,10 +954,10 @@ object Interval { if (lower < upper) Bounded(lower, upper, 1) else Interval.empty[A] def openUpper[A: Order](lower: A, upper: A): Interval[A] = if (lower < upper) Bounded(lower, upper, 2) else Interval.empty[A] - def above[A: Order](a: A): Interval[A] = Above.above(a) - def below[A: Order](a: A): Interval[A] = Below.below(a) - def atOrAbove[A: Order](a: A): Interval[A] = Above.atOrAbove(a) - def atOrBelow[A: Order](a: A): Interval[A] = Below.atOrBelow(a) + def above[A: Order](a: A): Above[A] = Above(a, 1) + def below[A: Order](a: A): Below[A] = Below(a, 2) + def atOrAbove[A: Order](a: A): Above[A] = Above(a, 0) + def atOrBelow[A: Order](a: A): Below[A] = Below(a, 0) private val NullRe = "^ *\\( *Ø *\\) *$".r private val SingleRe = "^ *\\[ *([^,]+) *\\] *$".r @@ -966,14 +966,14 @@ object Interval { def apply(s: String): Interval[Rational] = s match { case NullRe() => Interval.empty[Rational] - case SingleRe(x) => Interval.point(Rational(x)) + case SingleRe(x) => Interval.point[Rational](Rational(x)) case PairRe(left, x, y, right) => (left, x, y, right) match { case ("(", "-∞", "∞", ")") => Interval.all[Rational] - case ("(", "-∞", y, ")") => Interval.below(Rational(y)) - case ("(", "-∞", y, "]") => Interval.atOrBelow(Rational(y)) - case ("(", x, "∞", ")") => Interval.above(Rational(x)) - case ("[", x, "∞", ")") => Interval.atOrAbove(Rational(x)) + case ("(", "-∞", y, ")") => Interval.below[Rational](Rational(y)) + case ("(", "-∞", y, "]") => Interval.atOrBelow[Rational](Rational(y)) + case ("(", x, "∞", ")") => Interval.above[Rational](Rational(x)) + case ("[", x, "∞", ")") => Interval.atOrAbove[Rational](Rational(x)) case ("[", x, y, "]") => Interval.closed(Rational(x), Rational(y)) case ("(", x, y, ")") => Interval.open(Rational(x), Rational(y)) case ("[", x, y, ")") => Interval.openUpper(Rational(x), Rational(y)) @@ -990,7 +990,7 @@ object Interval { implicit def semiring[A](implicit ev: Ring[A], o: Order[A]): Semiring[Interval[A]] = new Semiring[Interval[A]] { - def zero: Interval[A] = Interval.point(ev.zero) + def zero: Interval[A] = Interval.point[A](ev.zero) def plus(x: Interval[A], y: Interval[A]): Interval[A] = x + y def times(x: Interval[A], y: Interval[A]): Interval[A] = x * y override def pow(x: Interval[A], k: Int): Interval[A] = x.pow(k) diff --git a/docs/guide.md b/docs/guide.md index 1ff01ae2e..2948a36cc 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -666,20 +666,12 @@ type classes (ranging from `AdditiveSemigroup[A]` for `+` to Intervals may be unbounded on either side, and bounds can be open or closed. (An interval includes closed boundaries, but not open boundaries). Here are some string representations of various -intervals, and how to create their equivalents in `spire`: - - * `[3, 6]` the set of values between 3 and 6 (including both) - `Interval.fromBounds(Closed(3), Closed(6))` - * `(2, 4)` the set of values between 2 and 4 (excluding both) - `Interval.fromBounds(Open(2), Open(4))` - * `[1, 2)` half-open set, including 1 but not 2 - `Interval.fromBounds(Closed(1), Open(2))` - * `(-∞, 5)` the set of values less than 5 - `Below.below(5)` - * `[1, ∞]` the set of values greater than or equal to 1 - `Above.atOrAbove(1)` - * `Ø` the empty set - `Empty()` - * `(-∞, ∞)` the set of all values - `All()` - -Note that `Interval` may be used in place of `Below`/`Above`/ -`Empty`/`All`, though these methods share a return value of the -`Interval` supertype from which lower/upper bound values which are -known can't be extracted. +intervals: + + * `[3, 6]` the set of values between 3 and 6 (including both). + * `(2, 4)` the set of values between 2 and 4 (excluding both). + * `[1, 2)` half-open set, including 1 but not 2. + * `(-∞, 5)` the set of values less than 5. Intervals model continuous spaces, even if the type A is discrete. So for instance when `(3, 4)` is an `Interval[Int]` it is not considered diff --git a/tests/shared/src/test/scala/spire/math/IntervalGeometricPartialOrderSuite.scala b/tests/shared/src/test/scala/spire/math/IntervalGeometricPartialOrderSuite.scala index 3e7b2de88..d6a31986a 100644 --- a/tests/shared/src/test/scala/spire/math/IntervalGeometricPartialOrderSuite.scala +++ b/tests/shared/src/test/scala/spire/math/IntervalGeometricPartialOrderSuite.scala @@ -38,7 +38,7 @@ class IntervalGeometricPartialOrderSuite extends munit.FunSuite { test("[2, 3] cannot be compared to empty") { assert(closed(2, 3).partialCompare(open(2, 2)).isNaN) } test("Minimal and maximal elements of {[1], [2, 3], [2, 4]}") { val intervals = Seq(point(1), closed(2, 3), closed(2, 4)) - assertEquals(intervals.pmin.toSet, Set(point(1))) - assertEquals(intervals.pmax.toSet, Set(closed(2, 3), closed(2, 4))) + assertEquals(intervals.pmin.toSet, Set[Interval[Int]](point(1))) + assertEquals(intervals.pmax.toSet, Set[Interval[Int]](closed(2, 3), closed(2, 4))) } } diff --git a/tests/shared/src/test/scala/spire/math/IntervalScalaCheckSuite.scala b/tests/shared/src/test/scala/spire/math/IntervalScalaCheckSuite.scala index 0ac78d2de..f5664bcd9 100644 --- a/tests/shared/src/test/scala/spire/math/IntervalScalaCheckSuite.scala +++ b/tests/shared/src/test/scala/spire/math/IntervalScalaCheckSuite.scala @@ -162,8 +162,8 @@ class IntervalScalaCheckSuite extends munit.ScalaCheckSuite { import spire.algebra.{Order, PartialOrder} forAll { (x: Rational, y: Rational) => - val a = Interval.point(x) - val b = Interval.point(y) + val a: Interval[Rational] = Interval.point(x) + val b: Interval[Rational] = Interval.point(y) val order = PartialOrder[Interval[Rational]].tryCompare(a, b).get == Order[Rational].compare(x, y) val min = a.pmin(b) match { case Some(Point(vmin)) => vmin == x.min(y) @@ -183,8 +183,8 @@ class IntervalScalaCheckSuite extends munit.ScalaCheckSuite { forAll { (a: Rational, w: Positive[Rational]) => val b = a + w.num // a < b - val i = Interval.atOrBelow(a) - val j = Interval.atOrAbove(b) + val i: Interval[Rational] = Interval.atOrBelow(a) + val j: Interval[Rational] = Interval.atOrAbove(b) (i < j) && !(i >= j) && (j > i) && @@ -197,8 +197,8 @@ class IntervalScalaCheckSuite extends munit.ScalaCheckSuite { forAll { (a: Rational, w: NonNegative[Rational]) => val b = a - w.num // a >= b - val i = Interval.atOrBelow(a) - val j = Interval.atOrAbove(b) + val i: Interval[Rational] = Interval.atOrBelow(a) + val j: Interval[Rational] = Interval.atOrAbove(b) i.partialCompare(j).isNaN && j.partialCompare(i).isNaN } @@ -207,8 +207,8 @@ class IntervalScalaCheckSuite extends munit.ScalaCheckSuite { property("(-inf, inf) does not compare with [a, b]") { import spire.optional.intervalGeometricPartialOrder._ forAll { (a: Rational, b: Rational) => - val i = Interval.all[Rational] - val j = Interval.closed(a, b) + val i: Interval[Rational] = Interval.all[Rational] + val j: Interval[Rational] = Interval.closed(a, b) i.partialCompare(j).isNaN && j.partialCompare(i).isNaN } diff --git a/tests/shared/src/test/scala/spire/math/IntervalSubsetPartialOrderSuite.scala b/tests/shared/src/test/scala/spire/math/IntervalSubsetPartialOrderSuite.scala index 766c1f12b..3d6d60f7b 100644 --- a/tests/shared/src/test/scala/spire/math/IntervalSubsetPartialOrderSuite.scala +++ b/tests/shared/src/test/scala/spire/math/IntervalSubsetPartialOrderSuite.scala @@ -25,7 +25,7 @@ class IntervalSubsetPartialOrderSuite extends munit.FunSuite { test("Minimal and maximal elements of {[1, 3], [3], [2], [1]} by subset partial order") { val intervals = Seq(closed(1, 3), point(3), point(2), point(1)) - assertEquals(intervals.pmin.toSet, Set(point(1), point(2), point(3))) - assertEquals(intervals.pmax.toSet, Set(closed(1, 3))) + assertEquals(intervals.pmin.toSet, Set(point(1), point(2), point(3)): Set[Interval[Int]]) + assertEquals(intervals.pmax.toSet, Set(closed(1, 3)): Set[Interval[Int]]) } }