Skip to content

Commit 39f9ec6

Browse files
splitWhen & splitWhenM for Foldable
1 parent be4a99d commit 39f9ec6

3 files changed

Lines changed: 74 additions & 0 deletions

File tree

core/src/main/scala/cats/Foldable.scala

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,49 @@ trait Foldable[F[_]] extends UnorderedFoldable[F] with FoldableNFunctions[F] { s
960960
import cats.instances.either.*
961961
partitionBifoldM[G, Either, A, B, C](fa)(f)(A, M, Bifoldable[Either])
962962
}
963+
964+
/**
965+
* Split this Foldable into a List of Lists based on a predicate.
966+
* The behaviour is aimed to be identical to that of haskell's `splitWhen`
967+
*
968+
* {{{
969+
* scala> import cats.syntax.all._, cats.Foldable
970+
* scala> Foldable[List].splitWhen(List(1,1))(_ == 1)
971+
* res1: List[List[Int]] = List(List(), List(), List())
972+
* scala> Foldable[List].splitWhen(Nil)(_ == 1)
973+
* res2: List[List[Nothing]] = List(List())
974+
* scala> Foldable[List].splitWhen(List(1, 2, 3, 1, 4, 5))(_ == 1)
975+
* res3: List[List[Int]] = List(List(), List(2, 3), List(4, 5))
976+
* }}}
977+
*/
978+
979+
def splitWhen[A](fa: F[A])(f: A => Boolean): List[List[A]] = {
980+
toList(fa).reverse.foldLeft(List.empty[A] :: Nil) { case (lst, e) =>
981+
if (f(e)) Nil :: lst else (e :: lst.head) :: lst.tail
982+
}
983+
}
984+
985+
/**
986+
* Split this Foldable into a List of Lists based on the effectufl predicate. Monadic version of `splitWhen`
987+
*
988+
* {{{
989+
* scala> import cats.syntax.all._, cats.Foldable, cats.Eval
990+
* scala> Foldable[List].splitWhenM(List(1,1))(x => Eval.now(x == 1)).value
991+
* res1: List[List[Int]] = List(List(), List(), List())
992+
* scala> Foldable[List].splitWhenM(List.empty[Int])(x => Eval.now(x == 1)).value
993+
* res2: List[List[Int]] = List(List())
994+
* scala> Foldable[List].splitWhenM(List(1, 2, 3, 1, 4, 5))(x => Eval.now(x == 1)).value
995+
* val res3: List[List[Int]] = List(List(), List(2, 3), List(4, 5))
996+
* }}}
997+
*/
998+
999+
def splitWhenM[G[_], A](fa: F[A])(f: A => G[Boolean])(implicit M: Monad[G]): G[List[List[A]]] = {
1000+
toList(fa).reverse.foldLeft(M.pure(List.empty[A] :: Nil)) { case (acc, e) =>
1001+
M.flatMap(acc) { case lst =>
1002+
M.map(f(e))(if (_) Nil :: lst else (e :: lst.head) :: lst.tail)
1003+
}
1004+
}
1005+
}
9631006
}
9641007

9651008
object Foldable {

core/src/main/scala/cats/syntax/foldable.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,14 @@ final class FoldableOps0[F[_], A](private val fa: F[A]) extends AnyVal {
304304
)(implicit A: Alternative[F], F: Foldable[F], M: Monad[G]): G[(F[B], F[C])] =
305305
F.partitionEitherM[G, A, B, C](fa)(f)(A, M)
306306

307+
def splitWhen(f: A => Boolean)(implicit F: Foldable[F]): List[List[A]] = {
308+
F.splitWhen[A](fa)(f)
309+
}
310+
311+
def splitWhenM[G[_]](f: A => G[Boolean])(implicit F: Foldable[F], G: Monad[G]): G[List[List[A]]] = {
312+
F.splitWhenM[G, A](fa)(f)(G)
313+
}
314+
307315
def sliding2(implicit F: Foldable[F]): List[(A, A)] =
308316
F.sliding2(fa)
309317
def sliding3(implicit F: Foldable[F]): List[(A, A, A)] =

tests/shared/src/test/scala/cats/tests/FoldableSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,29 @@ abstract class FoldableSuite[F[_]: Foldable](name: String)(implicit
333333
}
334334
}
335335

336+
test(s"Foldable[$name].splitWhen") {
337+
forAll { (fa: F[Int]) =>
338+
val pred = (x: Int) => x > 0
339+
val res = fa.splitWhen(pred)
340+
val expectedFiltered = iterator(fa).filterNot(pred).toList
341+
val expectedSize = fa.size - expectedFiltered.size + 1
342+
assert(res.size.toLong === expectedSize)
343+
assert(res.flatten === expectedFiltered)
344+
}
345+
}
346+
347+
test(s"Foldable[$name].splitWhenM") {
348+
forAll { (fa: F[Int]) =>
349+
val pred = (x: Int) => x > 0
350+
val predM = (x: Int) => Eval.now(pred(x))
351+
val res = fa.splitWhenM(predM)
352+
val expectedFiltered = iterator(fa).filterNot(pred).toList
353+
val expectedSize = fa.size - expectedFiltered.size + 1
354+
assert(res.value.size.toLong === expectedSize)
355+
assert(res.value.flatten === expectedFiltered)
356+
}
357+
}
358+
336359
test(s"Foldable[$name].sliding2 consistent with List#sliding(2)") {
337360
forAll { (fi: F[Int]) =>
338361
val n = 2

0 commit comments

Comments
 (0)