Skip to content

Commit

Permalink
Add take/takeRight/drop/dropRight to Chain
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Jan 3, 2025
1 parent d702505 commit 3b2c4bd
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 0 deletions.
164 changes: 164 additions & 0 deletions core/src/main/scala/cats/data/Chain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,90 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
result
}

/**
* take a certain amount of items from the front of the Chain
*/
final def take(count: Int): Chain[A] = {
// invariant count >= 1
@tailrec
def go(lhs: Chain[A], count: Int, arg: Chain[A], rhs: Chain[A]): Chain[A] =
arg match {
case Wrap(seq) =>
if (count == 1) {
lhs.append(seq(0))
} else {
// count > 1
val taken = seq.take(count)
// we may have not takeped all of count
val newCount = count - taken.length
val newLhs = lhs.concat(Wrap(taken))
if (newCount > 0) {
// we have to keep takeping on the rhs
go(newLhs, newCount, rhs, Chain.nil)
} else {
// newCount == 0, we have taken enough
newLhs
}
}
case Append(l, r) =>
go(lhs, count, l, r.concat(rhs))
case s @ Singleton(_) =>
// due to the invariant count >= 1
val newLhs = if (lhs.isEmpty) s else Append(lhs, s)
if (count > 1) {
go(newLhs, count - 1, rhs, Chain.nil)
} else newLhs
case Empty =>
if (rhs.isEmpty) lhs
else go(lhs, count, rhs, Chain.nil)
}

if (count <= 0) Empty
else go(Empty, count, this, Empty)
}

/**
* take a certain amount of items from the back of the Chain
*/
final def takeRight(count: Int): Chain[A] = {
// invariant count >= 1
@tailrec
def go(lhs: Chain[A], count: Int, arg: Chain[A], rhs: Chain[A]): Chain[A] =
arg match {
case Wrap(seq) =>
if (count == 1) {
lhs.append(seq.last)
} else {
// count > 1
val taken = seq.takeRight(count)
// we may have not takeped all of count
val newCount = count - taken.length
val newRhs = Wrap(taken).concat(rhs)
if (newCount > 0) {
// we have to keep takeping on the rhs
go(Chain.nil, newCount, lhs, newRhs)
} else {
// newCount == 0, we have taken enough
newRhs
}
}
case Append(l, r) =>
go(lhs.concat(l), count, r, rhs)
case s @ Singleton(_) =>
// due to the invariant count >= 1
val newRhs = if (rhs.isEmpty) s else Append(s, rhs)
if (count > 1) {
go(Empty, count - 1, lhs, newRhs)
} else newRhs
case Empty =>
if (lhs.isEmpty) rhs
else go(Chain.nil, count, lhs, rhs)
}

if (count <= 0) Empty
else go(Empty, count, this, Empty)
}

/**
* Drops longest prefix of elements that satisfy a predicate.
*
Expand All @@ -275,6 +359,86 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
go(this)
}

/**
* Drop a certain amount of items from the front of the Chain
*/
final def drop(count: Int): Chain[A] = {
// invariant count >= 1
@tailrec
def go(count: Int, arg: Chain[A], rhs: Chain[A]): Chain[A] =
arg match {
case Wrap(seq) =>
val dropped = seq.drop(count)
if (dropped.isEmpty) {
// we may have not dropped all of count
val newCount = count - seq.length
if (newCount > 0) {
// we have to keep dropping on the rhs
go(newCount, rhs, Chain.nil)
} else {
// we know that count >= seq.length else we wouldn't be empty
// so in this case, it is exactly count == seq.length
rhs
}
} else {
// we must be done
Chain.fromSeq(dropped).concat(rhs)
}
case Append(l, r) =>
go(count, l, r.concat(rhs))
case Singleton(_) =>
// due to the invariant count >= 1
if (count > 1) go(count - 1, rhs, Chain.nil)
else rhs
case Empty =>
if (rhs.isEmpty) Empty
else go(count, rhs, Chain.nil)
}

if (count <= 0) this
else go(count, this, Empty)
}

/**
* Drop a certain amount of items from the back of the Chain
*/
final def dropRight(count: Int): Chain[A] = {
// invariant count >= 1
@tailrec
def go(lhs: Chain[A], count: Int, arg: Chain[A]): Chain[A] =
arg match {
case Wrap(seq) =>
val dropped = seq.dropRight(count)
if (dropped.isEmpty) {
// we may have not dropped all of count
val newCount = count - seq.length
if (newCount > 0) {
// we have to keep dropping on the rhs
go(Chain.nil, newCount, lhs)
} else {
// we know that count >= seq.length else we wouldn't be empty
// so in this case, it is exactly count == seq.length
lhs
}
} else {
// we must be done
lhs.concat(Chain.fromSeq(dropped))
}
case Append(l, r) =>
go(lhs.concat(l), count, r)
case Singleton(_) =>
// due to the invariant count >= 1
if (count > 1) go(Chain.nil, count - 1, lhs)
else lhs
case Empty =>
if (lhs.isEmpty) Empty
else go(Chain.nil, count, lhs)
}

if (count <= 0) this
else go(Empty, count, this)
}

/**
* Folds over the elements from right to left using the supplied initial value and function.
*/
Expand Down
23 changes: 23 additions & 0 deletions tests/shared/src/test/scala/cats/tests/ChainSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -448,4 +448,27 @@ class ChainSuite extends CatsSuite {
assert(chain.foldRight(init)(fn) == chain.toList.foldRight(init)(fn))
}
}

test("drop(cnt).toList == toList.drop(cnt)") {
forAll { (chain: Chain[Int], count: Int) =>
assert(chain.drop(count).toList == chain.toList.drop(count))
}
}

test("dropRight(cnt).toList == toList.dropRight(cnt)") {
forAll { (chain: Chain[Int], count: Int) =>
assert(chain.dropRight(count).toList == chain.toList.dropRight(count))
}
}
test("take(cnt).toList == toList.take(cnt)") {
forAll { (chain: Chain[Int], count: Int) =>
assert(chain.take(count).toList == chain.toList.take(count))
}
}

test("takeRight(cnt).toList == toList.takeRight(cnt)") {
forAll { (chain: Chain[Int], count: Int) =>
assert(chain.takeRight(count).toList == chain.toList.takeRight(count))
}
}
}

0 comments on commit 3b2c4bd

Please sign in to comment.