From 3f7dd4ddc49fa6f1fbcc0a9936589418c8a3e935 Mon Sep 17 00:00:00 2001 From: "David R. Bild" Date: Wed, 9 Aug 2017 12:59:53 -0500 Subject: [PATCH] Add iterateWhileM and iterateUntilM --- core/src/main/scala/cats/Monad.scala | 19 +++++++++++++++ core/src/main/scala/cats/syntax/monad.scala | 16 +++++++++++++ .../src/test/scala/cats/tests/MonadTest.scala | 24 +++++++++++++++++++ 3 files changed, 59 insertions(+) diff --git a/core/src/main/scala/cats/Monad.scala b/core/src/main/scala/cats/Monad.scala index 3b65ef763c..f3f364a2a0 100644 --- a/core/src/main/scala/cats/Monad.scala +++ b/core/src/main/scala/cats/Monad.scala @@ -100,4 +100,23 @@ import simulacrum.typeclass } } + /** + * Apply a monadic function iteratively until its result fails + * to satisfy the given predicate and return that result. + */ + def iterateWhileM[A](init: A)(f: A => F[A])(p: A => Boolean): F[A] = + tailRecM(init) { a => + if (p(a)) + map(f(a))(Left(_)) + else + pure(Right(a)) + } + + /** + * Apply a monadic function iteratively until its result satisfies + * the given predicate and return that result. + */ + def iterateUntilM[A](init: A)(f: A => F[A])(p: A => Boolean): F[A] = + iterateWhileM(init)(f)(!p(_)) + } diff --git a/core/src/main/scala/cats/syntax/monad.scala b/core/src/main/scala/cats/syntax/monad.scala index 2772c5a8ed..efdd15235f 100644 --- a/core/src/main/scala/cats/syntax/monad.scala +++ b/core/src/main/scala/cats/syntax/monad.scala @@ -3,6 +3,9 @@ package syntax trait MonadSyntax { implicit final def catsSyntaxMonad[F[_], A](fa: F[A]): MonadOps[F, A] = new MonadOps(fa) + + implicit final def catsSyntaxMonadIdOps[A](a: A): MonadIdOps[A] = + new MonadIdOps[A](a) } final class MonadOps[F[_], A](val fa: F[A]) extends AnyVal { @@ -13,3 +16,16 @@ final class MonadOps[F[_], A](val fa: F[A]) extends AnyVal { def iterateWhile(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateWhile(fa)(p) def iterateUntil(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateUntil(fa)(p) } + +final class MonadIdOps[A](val a: A) extends AnyVal { + + /** + * Iterative application of `f` while `p` holds. + */ + def iterateWhileM[F[_]](f: A => F[A])(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateWhileM(a)(f)(p) + + /** + * Iterative application of `f` until `p` holds. + */ + def iterateUntilM[F[_]](f: A => F[A])(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateUntilM(a)(f)(p) +} diff --git a/tests/src/test/scala/cats/tests/MonadTest.scala b/tests/src/test/scala/cats/tests/MonadTest.scala index 75c9ccdc55..d38287556d 100644 --- a/tests/src/test/scala/cats/tests/MonadTest.scala +++ b/tests/src/test/scala/cats/tests/MonadTest.scala @@ -76,4 +76,28 @@ class MonadTest extends CatsSuite { result should ===(50000) } + test("iterateWhileM") { + forAll(smallPosInt) { (max: Int) => + val (n, sum) = 0.iterateWhileM(s => incrementAndGet map (_ + s))(_ < max).run(0) + sum should ===(n * (n + 1) / 2) + } + } + + test("iterateWhileM is stack safe") { + val (n, sum) = 0.iterateWhileM(s => incrementAndGet map (_ + s))(_ < 50000000).run(0) + sum should ===(n * (n + 1) / 2) + } + + test("iterateUntilM") { + forAll(smallPosInt) { (max: Int) => + val (n, sum) = 0.iterateUntilM(s => incrementAndGet map (_ + s))(_ > max).run(0) + sum should ===(n * (n + 1) / 2) + } + } + + test("iterateUntilM is stack safe") { + val (n, sum) = 0.iterateUntilM(s => incrementAndGet map (_ + s))(_ > 50000000).run(0) + sum should ===(n * (n + 1) / 2) + } + }