diff --git a/summingbird-core-test/src/test/scala/com/twitter/summingbird/memory/MemoryLaws.scala b/summingbird-core-test/src/test/scala/com/twitter/summingbird/memory/MemoryLaws.scala
index 3460d98cc..b9ccd3f07 100644
--- a/summingbird-core-test/src/test/scala/com/twitter/summingbird/memory/MemoryLaws.scala
+++ b/summingbird-core-test/src/test/scala/com/twitter/summingbird/memory/MemoryLaws.scala
@@ -16,7 +16,7 @@ limitations under the License.
 
 package com.twitter.summingbird.memory
 
-import com.twitter.algebird.{ MapAlgebra, Monoid, Semigroup }
+import com.twitter.algebird.{ Aggregator, MapAlgebra, Monoid, Semigroup }
 import com.twitter.summingbird._
 import com.twitter.summingbird.option.JobId
 import org.scalacheck.{ Arbitrary, _ }
@@ -214,6 +214,38 @@ class MemoryLaws extends WordSpec {
       assert(store1.toMap == ((0 to 100).groupBy(_ % 3).mapValues(_.sum)))
       assert(store2.toMap == ((0 to 100).groupBy(_ % 3).mapValues(_.sum)))
     }
+    "aggregate should work" in {
+      // Descending input makes every element after the first smaller than the stored
+      // maximum, so this only passes if the delta is combined with the previous aggregate.
+      val source = Memory.toSource((0 to 100).reverse)
+      val store = MutableMap.empty[Int, Int]
+      val buf = MutableMap.empty[Int, List[(Option[Int], Int)]]
+      val prod = source.map { t => (t % 2, t) }
+        .aggregate(store, Aggregator.max[Int].andThenPresent(_ * 2).composePrepare(_ / 2))
+        .write { kv =>
+          val (k, vs) = kv
+          buf(k) = vs :: buf.getOrElse(k, Nil)
+        }
+      val mem = new Memory
+      mem.run(mem.plan(prod))
+
+      // The store holds the un-presented aggregate: the max of the prepared (halved) values.
+      assert(store.keySet == Set(0, 1))
+      assert(store(0) == (0 to 100).filter(_ % 2 == 0).map(_ / 2).max)
+      assert(store(1) == (0 to 100).filter(_ % 2 == 1).map(_ / 2).max)
+      assert(buf.keySet == Set(0, 1))
+
+      // Per key, every write carries the presented maximum as its current value; the previous
+      // value is None for the first write only. buf is built by prepending, so that first
+      // write ends up last.
+      def expected(key: Int): List[(Option[Int], Int)] = {
+        val inputs = (0 to 100).filter(_ % 2 == key)
+        val presented = inputs.map(_ / 2).max * 2
+        (List.fill(inputs.size - 1)(Option(presented)) :+ None).map { prev => (prev, presented) }
+      }
+      assert(buf(0) == expected(0))
+      assert(buf(1) == expected(1))
+    }
     "self also shouldn't duplicate work" in {
       val platform = new Memory
 
diff --git a/summingbird-core/src/main/scala/com/twitter/summingbird/Producer.scala b/summingbird-core/src/main/scala/com/twitter/summingbird/Producer.scala
index 6ff703cd3..9aa7fe4bf 100644
--- a/summingbird-core/src/main/scala/com/twitter/summingbird/Producer.scala
+++ b/summingbird-core/src/main/scala/com/twitter/summingbird/Producer.scala
@@ -16,7 +16,7 @@
 
 package com.twitter.summingbird
 
-import com.twitter.algebird.Semigroup
+import com.twitter.algebird.{ Aggregator, Semigroup }
 
 object Producer {
 
@@ -251,6 +251,28 @@ case class Summer[P <: Platform[P], K, V](
  */
 sealed trait KeyedProducer[P <: Platform[P], K, V] extends Producer[P, (K, V)] {
 
+  /**
+   * Applies an Aggregator to the values of this producer: each value is prepared with
+   * agg.prepare, the prepared values are summed by key into `store` using agg.semigroup,
+   * and the results are passed through agg.present.
+   *
+   * The result type is similar to that of sumByKey, with a crucial difference: each tuple is
+   * (Option(previous aggregated value), current aggregated value). sumByKey gives you the
+   * previous value and the delta, but after agg.present the delta cannot be combined and is
+   * not meaningful in the general case.
+   */
+  def aggregate[V1, V2](store: P#Store[K, V1], agg: Aggregator[V, V1, V2]): KeyedProducer[P, K, (Option[V2], V2)] = {
+    val sg = agg.semigroup
+    mapValues(agg.prepare)
+      .sumByKey(store)(sg)
+      .mapValues {
+        case (prev, delta) =>
+          // Fold the delta into the previous aggregate before presenting it.
+          val current = prev.map(sg.plus(_, delta)).getOrElse(delta)
+          (prev.map(agg.present), agg.present(current))
+      }
+  }
+
   /** Builds a new KeyedProvider by applying a partial function to keys of elements of this one on which the function is defined.*/
   def collectKeys[K2](pf: PartialFunction[K, K2]): KeyedProducer[P, K2, V] =
     IdentityKeyedProducer(collect { case (k, v) if pf.isDefinedAt(k) => (pf(k), v) })