diff --git a/core/src/main/scala/cats/data/Chain.scala b/core/src/main/scala/cats/data/Chain.scala index c93a1719e9..ce4e1eb9b1 100644 --- a/core/src/main/scala/cats/data/Chain.scala +++ b/core/src/main/scala/cats/data/Chain.scala @@ -256,6 +256,99 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { result } + /** + * take a certain amount of items from the front of the Chain + */ + final def take(count: Long): Chain[A] = { + // invariant count >= 1 + @tailrec + def go(lhs: Chain[A], count: Long, arg: NonEmpty[A], rhs: Chain[A]): Chain[A] = + arg match { + case Wrap(seq) => + if (count == 1) { + lhs.append(seq.head) + } else { + // count > 1 + val taken = + if (count < Int.MaxValue) seq.take(count.toInt) + else seq.take(Int.MaxValue) + // we may have not taken all of count + val newCount = count - taken.length + val wrapped = Wrap(taken) + // this is more efficient than using concat + val newLhs = if (lhs.isEmpty) wrapped else Append(lhs, wrapped) + rhs match { + case rhsNE: NonEmpty[A] if newCount > 0L => + // we have to keep taking on the rhs + go(newLhs, newCount, rhsNE, Empty) + case _ => + newLhs + } + } + case Append(l, r) => + go(lhs, count, l, if (rhs.isEmpty) r else Append(r, rhs)) + case s @ Singleton(_) => + // due to the invariant count >= 1 + val newLhs = if (lhs.isEmpty) s else Append(lhs, s) + rhs match { + case rhsNE: NonEmpty[A] if count > 1L => + go(newLhs, count - 1L, rhsNE, Empty) + case _ => newLhs + } + } + + this match { + case ne: NonEmpty[A] if count > 0L => + go(Empty, count, ne, Empty) + case _ => Empty + } + } + + /** + * take a certain amount of items from the back of the Chain + */ + final def takeRight(count: Long): Chain[A] = { + // invariant count >= 1 + @tailrec + def go(lhs: Chain[A], count: Long, arg: NonEmpty[A], rhs: Chain[A]): Chain[A] = + arg match { + case Wrap(seq) => + if (count == 1L) { + seq.last +: rhs + } else { + // count > 1 + val taken = + if (count < Int.MaxValue) seq.takeRight(count.toInt) + else seq.takeRight(Int.MaxValue) + // we may have not taken all of count + val newCount = count - taken.length + val wrapped = Wrap(taken) + val newRhs = if (rhs.isEmpty) wrapped else Append(wrapped, rhs) + lhs match { + case lhsNE: NonEmpty[A] if newCount > 0 => + go(Empty, newCount, lhsNE, newRhs) + case _ => newRhs + } + } + case Append(l, r) => + go(if (lhs.isEmpty) l else Append(lhs, l), count, r, rhs) + case s @ Singleton(_) => + // due to the invariant count >= 1 + val newRhs = if (rhs.isEmpty) s else Append(s, rhs) + lhs match { + case lhsNE: NonEmpty[A] if count > 1 => + go(Empty, count - 1, lhsNE, newRhs) + case _ => newRhs + } + } + + this match { + case ne: NonEmpty[A] if count > 0L => + go(Empty, count, ne, Empty) + case _ => Empty + } + } + /** * Drops longest prefix of elements that satisfy a predicate. * @@ -275,6 +368,101 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { go(this) } + /** + * Drop a certain amount of items from the front of the Chain + */ + final def drop(count: Long): Chain[A] = { + // invariant count >= 1 + @tailrec + def go(count: Long, arg: NonEmpty[A], rhs: Chain[A]): Chain[A] = + arg match { + case Wrap(seq) => + val dropped = if (count < Int.MaxValue) seq.drop(count.toInt) else seq.drop(Int.MaxValue) + if (dropped.isEmpty) { + // we may have not dropped all of count + val newCount = count - seq.length + rhs match { + case rhsNE: NonEmpty[A] if newCount > 0 => + // we have to keep dropping on the rhs + go(newCount, rhsNE, Empty) + case _ => + // we know that count >= seq.length else we wouldn't be empty + // so in this case, it is exactly count == seq.length + rhs + } + } else { + // dropped is not empty + val wrapped = Wrap(dropped) + // we must be done + if (rhs.isEmpty) wrapped else Append(wrapped, rhs) + } + case Append(l, r) => + go(count, l, if (rhs.isEmpty) r else Append(r, rhs)) + case Singleton(_) => + // due to the invariant count >= 1 + rhs match { + case rhsNE: NonEmpty[A] if count > 1L => + go(count - 1L, rhsNE, Empty) + case _ => + rhs + } + } + + this match { + case ne: NonEmpty[A] if count > 0L => + go(count, ne, Empty) + case _ => this + } + } + + /** + * Drop a certain amount of items from the back of the Chain + */ + final def dropRight(count: Long): Chain[A] = { + // invariant count >= 1 + @tailrec + def go(lhs: Chain[A], count: Long, arg: NonEmpty[A]): Chain[A] = + arg match { + case Wrap(seq) => + val dropped = if (count < Int.MaxValue) seq.dropRight(count.toInt) else seq.dropRight(Int.MaxValue) + if (dropped.isEmpty) { + // we may have not dropped all of count + val newCount = count - seq.length + lhs match { + case lhsNE: NonEmpty[A] if newCount > 0L => + // we have to keep dropping on the lhs + go(Empty, newCount, lhsNE) + case _ => + // we know that count >= seq.length else we wouldn't be empty + // so in this case, it is exactly count == seq.length + lhs + } + } else { + // we must be done + // note: dropped.nonEmpty + val wrapped = Wrap(dropped) + if (lhs.isEmpty) wrapped else Append(lhs, wrapped) + } + case Append(l, r) => + go(if (lhs.isEmpty) l else Append(lhs, l), count, r) + case Singleton(_) => + // due to the invariant count >= 1 + lhs match { + case lhsNE: NonEmpty[A] if count > 1L => + go(Empty, count - 1L, lhsNE) + case _ => + lhs + } + } + + this match { + case ne: NonEmpty[A] if count > 0L => + go(Empty, count, ne) + case _ => + this + } + } + /** * Folds over the elements from right to left using the supplied initial value and function. */ diff --git a/tests/shared/src/test/scala/cats/tests/ChainSuite.scala b/tests/shared/src/test/scala/cats/tests/ChainSuite.scala index 4a920ad4c4..b25e2022e3 100644 --- a/tests/shared/src/test/scala/cats/tests/ChainSuite.scala +++ b/tests/shared/src/test/scala/cats/tests/ChainSuite.scala @@ -448,4 +448,38 @@ class ChainSuite extends CatsSuite { assert(chain.foldRight(init)(fn) == chain.toList.foldRight(init)(fn)) } } + + private val genChainDropTakeArgs = + Arbitrary.arbitrary[Chain[Int]].flatMap { chain => + // Bias to values close to the length + Gen + .oneOf( + Gen.choose(Int.MinValue, Int.MaxValue), + Gen.choose(-1, chain.length.toInt + 1) + ) + .map((chain, _)) + } + + test("drop(cnt).toList == toList.drop(cnt)") { + forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => + assertEquals(chain.drop(count).toList, chain.toList.drop(count)) + } + } + + test("dropRight(cnt).toList == toList.dropRight(cnt)") { + forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => + assertEquals(chain.dropRight(count).toList, chain.toList.dropRight(count)) + } + } + test("take(cnt).toList == toList.take(cnt)") { + forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => + assertEquals(chain.take(count).toList, chain.toList.take(count)) + } + } + + test("takeRight(cnt).toList == toList.takeRight(cnt)") { + forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => + assertEquals(chain.takeRight(count).toList, chain.toList.takeRight(count)) + } + } }