diff --git a/README.md b/README.md
index 9219636..073d2f8 100644
--- a/README.md
+++ b/README.md
@@ -125,15 +125,15 @@ that open a pull request and we will review it.
 | persist | :x: | |
 | pipe | :x: | |
 | randomSplit | :x: | |
-| reduce | :x: | |
+| reduce | :white_check_mark: | |
 | reduceByKey | :x: | |
 | repartition | :x: | |
 | repartitionAndSortWithinPartition | :x: | |
 | rightOuterJoin | :x: | |
 | sample | :x: | |
 | sampleByKey | :x: | |
-| sampleStdev | :x: | |
-| sampleVariance | :x: | |
+| sampleStdev | :white_check_mark: | |
+| sampleVariance | :white_check_mark: | |
 | saveAsHadoopDataset | :x: | |
 | saveAsHadoopFile | :x: | |
 | saveAsNewAPIHadoopDataset | :x: | |
@@ -143,8 +143,8 @@ that open a pull request and we will review it.
 | setName | :x: | |
 | sortBy | :x: | |
 | sortByKey | :x: | |
-| stats | :x: | |
-| stdev | :x: | |
+| stats | :white_check_mark: | |
+| stdev | :white_check_mark: | |
 | subtract | :x: | |
 | substractByKey | :x: | |
 | sum | :white_check_mark: | First version. |
@@ -161,7 +161,7 @@ that open a pull request and we will review it.
 | union | :x: | |
 | unpersist | :x: | |
 | values | :white_check_mark: | |
-| variance | :x: | |
+| variance | :white_check_mark: | |
 | withResources | :x: | |
 | zip | :x: | |
 | zipWithIndex | :x: | |
diff --git a/congruity/rdd_adapter.py b/congruity/rdd_adapter.py
index 3e8677f..3f1f80c 100644
--- a/congruity/rdd_adapter.py
+++ b/congruity/rdd_adapter.py
@@ -30,6 +30,9 @@
     Generic,
 )
 
+from pyspark.serializers import CloudPickleSerializer
+from pyspark.statcounter import StatCounter
+
 T = TypeVar("T")
 T_co = TypeVar("T_co", covariant=True)
 U = TypeVar("U")
@@ -236,75 +239,56 @@ def take(self: "RDDAdapter[T]", num: int) -> List[T]:
     def map(
         self: "RDDApapter[T]", f: Callable[[T], U], preservePartitioning=None
     ) -> "RDDAdapter[U]":
-        needs_conversion = self._first_field
-        schema = RDDAdapter.PA_SCHEMA
-
-        def mapper(iter: Iterable[RecordBatch]):
-            for b in iter:
-                result = []
-                rows = b.to_pylist()
-                for r in rows:
-                    if needs_conversion:
-                        val = loads(r["__bin_field__"])
-                    else:
-                        val = Row(**r)
-                    result.append({"__bin_field__": dumps(f(val))})
-                yield RecordBatch.from_pylist(result, schema=schema)
+        def func(iterator: Iterable[T]) -> Iterable[U]:
+            return map(fail_on_stopiteration(f), iterator)
 
-        result = self._df.mapInArrow(mapper, RDDAdapter.BIN_SCHEMA)
-        assert len(result.schema.fields) == 1
-        return RDDAdapter(result, True)
+        # This differs from the regular map implementation because we don't have
+        # access to mapPartitionsWithIndex.
+        return self.mapPartitions(func, preservePartitioning)
 
     map.__doc__ = RDD.map.__doc__
 
-    def count(self) -> int:
-        return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
-
+    count = RDD.count
     count.__doc__ = RDD.count.__doc__
 
-    def sum(self: "RDDAdapter") -> int:
-        return self.mapPartitions(lambda x: [sum(x)]).fold(  # type: ignore[return-value]
-            0, operator.add
-        )
-
+    sum = RDD.sum
     sum.__doc__ = RDD.sum.__doc__
 
-    def fold(self: "RDDAdapter[T]", zeroValue: T, op: Callable[[T, T], T]) -> T:
-        op = fail_on_stopiteration(op)
-
-        def func(iterator: Iterable[T]) -> Iterable[T]:
-            acc = zeroValue
-            for obj in iterator:
-                acc = op(acc, obj)
-            yield acc
+    fold = RDD.fold
+    fold.__doc__ = RDD.fold.__doc__
 
-        vals = self.mapPartitions(func).collect()
-        return reduce(op, vals, zeroValue)
+    keys = RDD.keys
+    keys.__doc__ = RDD.keys.__doc__
 
-    fold.__doc__ = RDD.fold.__doc__
+    values = RDD.values
+    values.__doc__ = RDD.values.__doc__
 
-    def keys(self: "RDDAdapter[Tuple[K, V]]") -> "RDDAdapter[K]":
-        return self.map(lambda x: x[0])
+    glom = RDD.glom
+    glom.__doc__ = RDD.glom.__doc__
 
-    keys.__doc__ = RDD.keys.__doc__
+    keyBy = RDD.keyBy
+    keyBy.__doc__ = RDD.keyBy.__doc__
 
-    def values(self: "RDDAdapter[Tuple[K, V]]") -> "RDDAdapter[V]":
-        return self.map(lambda x: x[1])
+    reduce = RDD.reduce
+    reduce.__doc__ = RDD.reduce.__doc__
 
-    values.__doc__ = RDD.values.__doc__
+    stats = RDD.stats
+    stats.__doc__ = RDD.stats.__doc__
 
-    def glom(self: "RDDAdapter[T]") -> "RDDAdapter[List[T]]":
-        def func(iterator: Iterable[T]) -> Iterable[List[T]]:
-            yield list(iterator)
+    stdev = RDD.stdev
+    stdev.__doc__ = RDD.stdev.__doc__
 
-        return self.mapPartitions(func)
+    sampleStdev = RDD.sampleStdev
+    sampleStdev.__doc__ = RDD.sampleStdev.__doc__
 
-    glom.__doc__ = RDD.glom.__doc__
+    sampleVariance = RDD.sampleVariance
+    sampleVariance.__doc__ = RDD.sampleVariance.__doc__
 
-    def keyBy(self: "RDDAdapter[T]", f: Callable[[T], K]) -> "RDDAdapter[Tuple[K, T]]":
-        return self.map(lambda x: (f(x), x))
+    variance = RDD.variance
+    variance.__doc__ = RDD.variance.__doc__
 
-    keyBy.__doc__ = RDD.keyBy.__doc__
+    aggregate = RDD.aggregate
+    aggregate.__doc__ = RDD.aggregate.__doc__
 
     class WrappedIterator(Iterable):
         """This is a helper class that wraps the iterator of RecordBatches as returned by
@@ -336,8 +320,45 @@ def __iter__(self):
     def mapPartitions(
         self, f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning=False
     ) -> "RDDAdapter[U]":
+        # Every transformation eventually becomes a mapPartitions call, so we pass the
+        # current RDD as the previous reference together with the next transformation function.
+        return Pipeline(self, f)
+
+    mapPartitions.__doc__ = RDD.mapPartitions.__doc__
+
+
+class Pipeline(RDDAdapter):
+
+    def __init__(self, input: "RDDAdapter", f):
+
+        if isinstance(input, Pipeline):
+            source = input._prev_source
+            self._prev_source = source
+
+            prev_fun = input._prev_fun
+            next_fun = lambda ite: f(prev_fun(ite))
+
+            self._prev_fun = next_fun
+            self._prev_first_field = input._prev_first_field
+            first_field = input._prev_first_field
+        else:
+            # Cache the input DF and functions before mapping it.
+            next_fun = f
+            self._prev_source = input._df
+            self._prev_fun = next_fun
+            self._prev_first_field = input._first_field
+            first_field = input._first_field
+            source = input._df
+
+        mapper = self._build_mapper(next_fun, first_field)
+        # These are the output values of the operations; they are only evaluated once a
+        # terminal operation is called.
+        self._df = source.mapInArrow(mapper, RDDAdapter.BIN_SCHEMA)
+        self._first_field = True
+
+    def _build_mapper(self, f, needs_conversion):
+        # Fixed constants for the mapPartitions implementation.
         schema = RDDAdapter.PA_SCHEMA
-        needs_conversion = self._first_field
         max_rows_per_batch = 1000
 
         def mapper(iter: Iterable[RecordBatch]):
@@ -355,10 +376,4 @@ def mapper(iter: Iterable[RecordBatch]):
             if len(result) > 0:
                 yield RecordBatch.from_pylist(result, schema=schema)
 
-        # MapInArrow is effectively mapPartitions, but streams the rows as batches to the RDD,
-        # we leverage this fact here and build wrappers for that.
-        result = self._df.mapInArrow(mapper, RDDAdapter.BIN_SCHEMA)
-        assert len(result.schema.fields) == 1
-        return RDDAdapter(result, True)
-
-    mapPartitions.__doc__ = RDD.mapPartitions.__doc__
+        return mapper
diff --git a/tests/test_rdd_adapter.py b/tests/test_rdd_adapter.py
index 2fbc588..28e5973 100644
--- a/tests/test_rdd_adapter.py
+++ b/tests/test_rdd_adapter.py
@@ -174,3 +174,57 @@ def test_rdd_key_by(spark_session: "SparkSession"):
         (16, 8),
         (18, 9),
     ]
+
+
+def test_rdd_stats(spark_session: "SparkSession"):
+    monkey_patch_spark()
+    rdd = spark_session.sparkContext.parallelize(range(10))
+    stats = rdd.stats()
+    assert stats.count() == 10
+    assert stats.mean() == 4.5
+    assert stats.sum() == 45
+    assert stats.min() == 0
+    assert stats.max() == 9
+    assert stats.stdev() == 2.8722813232690143
+    assert stats.variance() == 8.25
+    assert stats.sampleStdev() == 3.0276503540974917
+    assert stats.sampleVariance() == 9.166666666666666
+
+
+def test_rdd_stddev(spark_session: "SparkSession"):
+    monkey_patch_spark()
+    rdd = spark_session.sparkContext.parallelize(range(10))
+    assert rdd.stdev() == 2.8722813232690143
+
+
+def test_rdd_sample_stddev(spark_session: "SparkSession"):
+    monkey_patch_spark()
+    rdd = spark_session.sparkContext.parallelize(range(10))
+    assert rdd.sampleStdev() == 3.0276503540974917
+
+
+def test_rdd_sample_variance(spark_session: "SparkSession"):
+    monkey_patch_spark()
+    rdd = spark_session.sparkContext.parallelize(range(10))
+    assert rdd.sampleVariance() == 9.166666666666666
+
+
+def test_rdd_variance(spark_session: "SparkSession"):
+    monkey_patch_spark()
+    rdd = spark_session.sparkContext.parallelize(range(10))
+    assert rdd.variance() == 8.25
+
+
+def test_rdd_aggregate(spark_session: "SparkSession"):
+    monkey_patch_spark()
+    rdd = spark_session.sparkContext.parallelize(range(10))
+    assert rdd.aggregate(0, lambda x, y: x + y, lambda x, y: x + y) == 45
+
+    seqOp = lambda x, y: (x[0] + y, x[1] + 1)
+    combOp = lambda x, y: (x[0] + y[0], x[1] + y[1])
+    res = spark_session.sparkContext.parallelize([1, 2, 3, 4]).aggregate((0, 0), seqOp, combOp)
+    assert res == (10, 4)
+
+    # TODO empty
+    # res = spark_session.sparkContext.parallelize([]).aggregate((0, 0), seqOp, combOp)
+    # assert res == (0, 0)
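
Reviewer note (not part of the patch): reassigning the upstream implementations (count = RDD.count, stats = RDD.stats, aggregate = RDD.aggregate, ...) works because those PySpark methods are written almost entirely in terms of mapPartitions plus a driver-side merge of the per-partition results, so any class that exposes a compatible mapPartitions/collect can reuse them. The sketch below illustrates that pattern for stats(); it is a standalone toy, not the congruity code, and the MiniRDD class and its partition layout are invented for illustration. StatCounter is the same helper the patch imports above.

    from functools import reduce

    from pyspark.statcounter import StatCounter  # same helper imported in rdd_adapter.py above


    class MiniRDD:
        """Toy stand-in with just enough surface (mapPartitions/collect) to reuse the
        driver-side merge pattern that RDD.stats() relies on."""

        def __init__(self, partitions):
            # A list of lists, one inner list per "partition".
            self._partitions = partitions

        def mapPartitions(self, f):
            # Apply f to each partition's iterator and keep the partitioning.
            return MiniRDD([list(f(iter(p))) for p in self._partitions])

        def collect(self):
            # Flatten all partitions into a single driver-side list.
            return [x for p in self._partitions for x in p]

        def stats(self):
            # One StatCounter per partition, merged on the driver -- roughly the shape
            # of pyspark.RDD.stats(), which the adapter now borrows directly.
            counters = self.mapPartitions(lambda it: [StatCounter(it)]).collect()
            return reduce(lambda left, right: left.mergeStats(right), counters)


    if __name__ == "__main__":
        rdd = MiniRDD([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
        s = rdd.stats()
        # Matches the expectations in test_rdd_stats above.
        print(s.mean(), s.variance(), s.sampleStdev())  # 4.5 8.25 3.0276503540974917

The Pipeline class in the patch applies the same idea one level lower: each chained mapPartitions is composed into the previous function (lambda ite: f(prev_fun(ite))), so the whole chain executes as a single mapInArrow pass over the source DataFrame.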