diff --git a/core/src/main/scala/cats/MonadError.scala b/core/src/main/scala/cats/MonadError.scala index bed0990728..4541a44751 100644 --- a/core/src/main/scala/cats/MonadError.scala +++ b/core/src/main/scala/cats/MonadError.scala @@ -1,13 +1,79 @@ package cats -/** A monad that also allows you to raise and or handle an error value. - * - * This type class allows one to abstract over error-handling monads. - */ +import cats.data.{Xor, XorT} + +/** + * A monad that also allows you to raise and or handle an error value. + * + * This type class allows one to abstract over error-handling monads. + */ trait MonadError[F[_], E] extends Monad[F] { + /** + * Lift an error into the `F` context. + */ def raiseError[A](e: E): F[A] - def handleError[A](fea: F[A])(f: E => F[A]): F[A] + /** + * Handle any error, potentially recovering from it, by mapping it to an + * `F[A]` value. + * + * @see [[handleError]] to handle any error by simply mapping it to an `A` + * value instead of an `F[A]`. + * + * @see [[recoverWith]] to recover from only certain errors. + */ + def handleErrorWith[A](fa: F[A])(f: E => F[A]): F[A] + + /** + * Handle any error, by mapping it to an `A` value. + * + * @see [[handleErrorWith]] to map to an `F[A]` value instead of simply an + * `A` value. + * + * @see [[recover]] to only recover from certain errors. + */ + def handleError[A](fa: F[A])(f: E => A): F[A] = handleErrorWith(fa)(f andThen pure) + + /** + * Handle errors by turning them into [[cats.data.Xor.Left]] values. + * + * If there is no error, then an [[cats.data.Xor.Right]] value will be returned instead. + * + * All non-fatal errors should be handled by this method. + */ + def attempt[A](fa: F[A]): F[E Xor A] = handleErrorWith( + map(fa)(Xor.right[E, A]) + )(e => pure(Xor.left(e))) + + /** + * Similar to [[attempt]], but wraps the result in a [[cats.data.XorT]] for + * convenience. + */ + def attemptT[A](fa: F[A]): XorT[F, E, A] = XorT(attempt(fa)) + + /** + * Recover from certain errors by mapping them to an `A` value. + * + * @see [[handleError]] to handle any/all errors. + * + * @see [[recoverWith]] to recover from certain errors by mapping them to + * `F[A]` values. + */ + def recover[A](fa: F[A])(pf: PartialFunction[E, A]): F[A] = + handleErrorWith(fa)(e => + (pf andThen pure) applyOrElse(e, raiseError)) + + /** + * Recover from certain errors by mapping them to an `F[A]` value. + * + * @see [[handleErrorWith]] to handle any/all errors. + * + * @see [[recover]] to recover from certain errors by mapping them to `A` + * values. + */ + def recoverWith[A](fa: F[A])(pf: PartialFunction[E, F[A]]): F[A] = + handleErrorWith(fa)(e => + pf applyOrElse(e, raiseError)) } object MonadError { diff --git a/core/src/main/scala/cats/data/Xor.scala b/core/src/main/scala/cats/data/Xor.scala index ee15175da0..083d4fbd51 100644 --- a/core/src/main/scala/cats/data/Xor.scala +++ b/core/src/main/scala/cats/data/Xor.scala @@ -176,13 +176,18 @@ sealed abstract class XorInstances extends XorInstances1 { def foldRight[B, C](fa: A Xor B, lc: Eval[C])(f: (B, Eval[C]) => Eval[C]): Eval[C] = fa.foldRight(lc)(f) def flatMap[B, C](fa: A Xor B)(f: B => A Xor C): A Xor C = fa.flatMap(f) def pure[B](b: B): A Xor B = Xor.right(b) - def handleError[B](fea: Xor[A, B])(f: A => Xor[A, B]): Xor[A, B] = + def handleErrorWith[B](fea: Xor[A, B])(f: A => Xor[A, B]): Xor[A, B] = fea match { case Xor.Left(e) => f(e) case r @ Xor.Right(_) => r } def raiseError[B](e: A): Xor[A, B] = Xor.left(e) override def map[B, C](fa: A Xor B)(f: B => C): A Xor C = fa.map(f) + override def attempt[B](fab: A Xor B): A Xor (A Xor B) = Xor.right(fab) + override def recover[B](fab: A Xor B)(pf: PartialFunction[A, B]): A Xor B = + fab recover pf + override def recoverWith[B](fab: A Xor B)(pf: PartialFunction[A, A Xor B]): A Xor B = + fab recoverWith pf } } diff --git a/core/src/main/scala/cats/data/XorT.scala b/core/src/main/scala/cats/data/XorT.scala index 45ffda62f8..0fab1b522b 100644 --- a/core/src/main/scala/cats/data/XorT.scala +++ b/core/src/main/scala/cats/data/XorT.scala @@ -216,14 +216,22 @@ private[data] trait XorTMonadError[F[_], L] extends MonadError[XorT[F, L, ?], L] implicit val F: Monad[F] def pure[A](a: A): XorT[F, L, A] = XorT.pure[F, L, A](a) def flatMap[A, B](fa: XorT[F, L, A])(f: A => XorT[F, L, B]): XorT[F, L, B] = fa flatMap f - def handleError[A](fea: XorT[F, L, A])(f: L => XorT[F, L, A]): XorT[F, L, A] = + def handleErrorWith[A](fea: XorT[F, L, A])(f: L => XorT[F, L, A]): XorT[F, L, A] = XorT(F.flatMap(fea.value) { - _ match { - case Xor.Left(e) => f(e).value - case r @ Xor.Right(_) => F.pure(r) - } + case Xor.Left(e) => f(e).value + case r @ Xor.Right(_) => F.pure(r) + }) + override def handleError[A](fea: XorT[F, L, A])(f: L => A): XorT[F, L, A] = + XorT(F.flatMap(fea.value) { + case Xor.Left(e) => F.pure(Xor.Right(f(e))) + case r @ Xor.Right(_) => F.pure(r) }) def raiseError[A](e: L): XorT[F, L, A] = XorT.left(F.pure(e)) + override def attempt[A](fla: XorT[F, L, A]): XorT[F, L, L Xor A] = XorT.right(fla.value) + override def recover[A](fla: XorT[F, L, A])(pf: PartialFunction[L, A]): XorT[F, L, A] = + fla.recover(pf) + override def recoverWith[A](fla: XorT[F, L, A])(pf: PartialFunction[L, XorT[F, L, A]]): XorT[F, L, A] = + fla.recoverWith(pf) } private[data] trait XorTSemigroupK[F[_], L] extends SemigroupK[XorT[F, L, ?]] { diff --git a/core/src/main/scala/cats/std/future.scala b/core/src/main/scala/cats/std/future.scala index 26ef6638e1..e689ed822c 100644 --- a/core/src/main/scala/cats/std/future.scala +++ b/core/src/main/scala/cats/std/future.scala @@ -2,7 +2,9 @@ package cats package std import cats.syntax.all._ +import cats.data.Xor +import scala.util.control.NonFatal import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.FiniteDuration @@ -16,9 +18,17 @@ trait FutureInstances extends FutureInstances1 { def flatMap[A, B](fa: Future[A])(f: A => Future[B]): Future[B] = fa.flatMap(f) - def handleError[A](fea: Future[A])(f: Throwable => Future[A]): Future[A] = fea.recoverWith { case t => f(t) } + def handleErrorWith[A](fea: Future[A])(f: Throwable => Future[A]): Future[A] = fea.recoverWith { case t => f(t) } def raiseError[A](e: Throwable): Future[A] = Future.failed(e) + override def handleError[A](fea: Future[A])(f: Throwable => A): Future[A] = fea.recover { case t => f(t) } + + override def attempt[A](fa: Future[A]): Future[Throwable Xor A] = + (fa map Xor.right) recover { case NonFatal(t) => Xor.left(t) } + + override def recover[A](fa: Future[A])(pf: PartialFunction[Throwable, A]): Future[A] = fa.recover(pf) + + override def recoverWith[A](fa: Future[A])(pf: PartialFunction[Throwable, Future[A]]): Future[A] = fa.recoverWith(pf) override def map[A, B](fa: Future[A])(f: A => B): Future[B] = fa.map(f) } diff --git a/js/src/test/scala/cats/tests/FutureTests.scala b/js/src/test/scala/cats/tests/FutureTests.scala index 47e6cf484a..9460682575 100644 --- a/js/src/test/scala/cats/tests/FutureTests.scala +++ b/js/src/test/scala/cats/tests/FutureTests.scala @@ -10,7 +10,6 @@ import cats.tests.CatsSuite import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ -import scala.util.control.NonFatal import scala.scalajs.concurrent.JSExecutionContext.Implicits.runNow diff --git a/jvm/src/test/scala/cats/tests/FutureTests.scala b/jvm/src/test/scala/cats/tests/FutureTests.scala index 9b3f68c169..9b5e9bba0d 100644 --- a/jvm/src/test/scala/cats/tests/FutureTests.scala +++ b/jvm/src/test/scala/cats/tests/FutureTests.scala @@ -10,7 +10,6 @@ import cats.tests.CatsSuite import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ import scala.concurrent.ExecutionContext.Implicits.global -import scala.util.control.NonFatal import org.scalacheck.Arbitrary import org.scalacheck.Arbitrary.arbitrary diff --git a/laws/src/main/scala/cats/laws/MonadErrorLaws.scala b/laws/src/main/scala/cats/laws/MonadErrorLaws.scala index 26723da911..3a2acbaf44 100644 --- a/laws/src/main/scala/cats/laws/MonadErrorLaws.scala +++ b/laws/src/main/scala/cats/laws/MonadErrorLaws.scala @@ -1,6 +1,8 @@ package cats package laws +import cats.data.{Xor, XorT} + // Taken from http://functorial.com/psc-pages/docs/Control/Monad/Error/Class/index.html trait MonadErrorLaws[F[_], E] extends MonadLaws[F] { implicit override def F: MonadError[F, E] @@ -8,11 +10,35 @@ trait MonadErrorLaws[F[_], E] extends MonadLaws[F] { def monadErrorLeftZero[A, B](e: E, f: A => F[B]): IsEq[F[B]] = F.flatMap(F.raiseError[A](e))(f) <-> F.raiseError[B](e) - def monadErrorHandle[A](e: E, f: E => F[A]): IsEq[F[A]] = - F.handleError(F.raiseError[A](e))(f) <-> f(e) + def monadErrorHandleWith[A](e: E, f: E => F[A]): IsEq[F[A]] = + F.handleErrorWith(F.raiseError[A](e))(f) <-> f(e) + + def monadErrorHandle[A](e: E, f: E => A): IsEq[F[A]] = + F.handleError(F.raiseError[A](e))(f) <-> F.pure(f(e)) + + def handleErrorWithPure[A](a: A, f: E => F[A]): IsEq[F[A]] = + F.handleErrorWith(F.pure(a))(f) <-> F.pure(a) - def monadErrorPure[A](a: A, f: E => F[A]): IsEq[F[A]] = + def handleErrorPure[A](a: A, f: E => A): IsEq[F[A]] = F.handleError(F.pure(a))(f) <-> F.pure(a) + + def raiseErrorAttempt(e: E): IsEq[F[E Xor Unit]] = + F.attempt(F.raiseError[Unit](e)) <-> F.pure(Xor.left(e)) + + def pureAttempt[A](a: A): IsEq[F[E Xor A]] = + F.attempt(F.pure(a)) <-> F.pure(Xor.right(a)) + + def handleErrorWithConsistentWithRecoverWith[A](fa: F[A], f: E => F[A]): IsEq[F[A]] = + F.handleErrorWith(fa)(f) <-> F.recoverWith(fa)(PartialFunction(f)) + + def handleErrorConsistentWithRecover[A](fa: F[A], f: E => A): IsEq[F[A]] = + F.handleError(fa)(f) <-> F.recover(fa)(PartialFunction(f)) + + def recoverConsistentWithRecoverWith[A](fa: F[A], pf: PartialFunction[E, A]): IsEq[F[A]] = + F.recover(fa)(pf) <-> F.recoverWith(fa)(pf andThen F.pure) + + def attemptConsistentWithAttemptT[A](fa: F[A]): IsEq[XorT[F, E, A]] = + XorT(F.attempt(fa)) <-> F.attemptT(fa) } object MonadErrorLaws { diff --git a/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala b/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala index 06397ac4b3..25f7320c89 100644 --- a/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala +++ b/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala @@ -78,4 +78,8 @@ object arbitrary { implicit def writerTArbitrary[F[_], L, V](implicit F: Arbitrary[F[(L, V)]]): Arbitrary[WriterT[F, L, V]] = Arbitrary(F.arbitrary.map(WriterT(_))) + + // until this is provided by scalacheck + implicit def partialFunctionArbitrary[A, B](implicit F: Arbitrary[A => Option[B]]): Arbitrary[PartialFunction[A, B]] = + Arbitrary(F.arbitrary.map(Function.unlift)) } diff --git a/laws/src/main/scala/cats/laws/discipline/MonadErrorTests.scala b/laws/src/main/scala/cats/laws/discipline/MonadErrorTests.scala index e9cf492aca..5bde73b282 100644 --- a/laws/src/main/scala/cats/laws/discipline/MonadErrorTests.scala +++ b/laws/src/main/scala/cats/laws/discipline/MonadErrorTests.scala @@ -2,6 +2,8 @@ package cats package laws package discipline +import cats.laws.discipline.arbitrary.partialFunctionArbitrary +import cats.laws.discipline.eq.unitEq import org.scalacheck.{Arbitrary, Prop} import org.scalacheck.Prop.forAll @@ -15,10 +17,8 @@ trait MonadErrorTests[F[_], E] extends MonadTests[F] { implicit def eqE: Eq[E] def monadError[A: Arbitrary: Eq, B: Arbitrary: Eq, C: Arbitrary: Eq]: RuleSet = { - implicit def ArbFEA: Arbitrary[F[A]] = arbitraryK.synthesize[A] - implicit def ArbFEB: Arbitrary[F[B]] = arbitraryK.synthesize[B] - implicit def EqFEA: Eq[F[A]] = eqK.synthesize[A] - implicit def EqFEB: Eq[F[B]] = eqK.synthesize[B] + implicit def arbFT[T:Arbitrary]: Arbitrary[F[T]] = arbitraryK.synthesize + implicit def eqFT[T:Eq]: Eq[F[T]] = eqK.synthesize new RuleSet { def name: String = "monadError" @@ -26,8 +26,16 @@ trait MonadErrorTests[F[_], E] extends MonadTests[F] { def parents: Seq[RuleSet] = Seq(monad[A, B, C]) def props: Seq[(String, Prop)] = Seq( "monadError left zero" -> forAll(laws.monadErrorLeftZero[A, B] _), - "monadError handle" -> forAll(laws.monadErrorHandle[A] _), - "monadError pure" -> forAll(laws.monadErrorPure[A] _) + "monadError handleWith" -> forAll(laws.monadErrorHandleWith[A] _), + "monadError handle" -> forAll(laws.monadErrorHandleWith[A] _), + "monadError handleErrorWith pure" -> forAll(laws.handleErrorWithPure[A] _), + "monadError handleError pure" -> forAll(laws.handleErrorPure[A] _), + "monadError raiseError attempt" -> forAll(laws.raiseErrorAttempt _), + "monadError pure attempt" -> forAll(laws.pureAttempt[A] _), + "monadError handleErrorWith consistent with recoverWith" -> forAll(laws.handleErrorWithConsistentWithRecoverWith[A] _), + "monadError handleError consistent with recover" -> forAll(laws.handleErrorConsistentWithRecover[A] _), + "monadError recover consistent with recoverWith" -> forAll(laws.recoverConsistentWithRecoverWith[A] _), + "monadError attempt consistent with attemptT" -> forAll(laws.attemptConsistentWithAttemptT[A] _) ) } }