diff --git a/algebird-core/src/main/scala/com/twitter/algebird/Monad.scala b/algebird-core/src/main/scala/com/twitter/algebird/Monad.scala index 16fecd6e4..04a56f310 100644 --- a/algebird-core/src/main/scala/com/twitter/algebird/Monad.scala +++ b/algebird-core/src/main/scala/com/twitter/algebird/Monad.scala @@ -19,6 +19,7 @@ import java.lang.{Integer => JInt, Short => JShort, Long => JLong, Float => JFlo import java.util.{List => JList, Map => JMap} import scala.annotation.implicitNotFound +import collection.GenTraversable /** * Simple implementation of a Monad type-class */ @@ -37,6 +38,7 @@ trait Monad[M[_]] { // flatMap(flatMap(m)(f))(g) == flatMap(m) { x => flatMap(f(x))(g) } } + /** For use from Java/minimizing code bloat in scala */ abstract class AbstractMonad[M[_]] extends Monad[M] @@ -48,6 +50,11 @@ object Monad { def apply[M[_]](implicit monad: Monad[M]): Monad[M] = monad def flatMap[M[_],T,U](m: M[T])(fn: (T) => M[U])(implicit monad: Monad[M]) = monad.flatMap(m)(fn) def map[M[_],T,U](m: M[T])(fn: (T) => U)(implicit monad: Monad[M]) = monad.map(m)(fn) + def foldM[M[_],T,U](acc: T, xs: GenTraversable[U])(fn: (T,U)=>M[T])(implicit monad: Monad[M]) : M[T] = + if(xs.isEmpty) + monad.apply(acc) + else + monad.flatMap(fn(acc,xs.head)){t: T => foldM(t, xs.tail)(fn)} // Some instances of the Monad typeclass (case for a macro): implicit val list: Monad[List] = new Monad[List] { diff --git a/algebird-test/src/test/scala/com/twitter/algebird/MonadFoldMTest.scala b/algebird-test/src/test/scala/com/twitter/algebird/MonadFoldMTest.scala new file mode 100644 index 000000000..5e7176a39 --- /dev/null +++ b/algebird-test/src/test/scala/com/twitter/algebird/MonadFoldMTest.scala @@ -0,0 +1,26 @@ +package com.twitter.algebird + +import org.specs._ + +class MonadFoldMTest extends Specification { + noDetailedDiffs() + + def binSmalls(x: Int, y: Int) : Option[Int] = if(y > 9) None else Some(x+y) + "A monad foldM" should { + "fold correctly" in { + + // nice easy example from Learn You a Haskell + + val first = Monad.foldM(0,List(2,8,3,1))(binSmalls) + first must be_==(Some(14)) + def binSmalls2(x: Int, y: String) : Option[Int] = if(y == "11") None else Some(x+y.toInt) + + val second = Monad.foldM(0, List("2","11","3","1"))(binSmalls2) + second must be_==(None) + } + "handle an empty list" in { + val third = Monad.foldM(0,List.empty)(binSmalls) + third must be_==(Some(0)) + } + } +} \ No newline at end of file