Making the mapPartitions function pipelinable
grundprinzip committed May 20, 2024
1 parent f2013b7 commit 7895817
Showing 3 changed files with 133 additions and 64 deletions.
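The gist of the change: a chain of `map`/`mapPartitions` calls on the RDD adapter no longer triggers one `mapInArrow` pass per call. The new `Pipeline` class composes the previous partition function with the next one, so the whole chain runs as a single `mapInArrow` over the original source DataFrame. A minimal, Spark-free sketch of that composition idea (all names below are illustrative, not the adapter's API):

```python
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")


def compose(
    prev: Callable[[Iterable[T]], Iterable[U]],
    nxt: Callable[[Iterable[U]], Iterable[V]],
) -> Callable[[Iterable[T]], Iterable[V]]:
    # Feed the output of the previous partition function into the next one,
    # so a chain of mapPartitions calls collapses into a single pass.
    return lambda it: nxt(prev(it))


double = lambda it: (x * 2 for x in it)
keep_multiples_of_four = lambda it: (x for x in it if x % 4 == 0)

pipeline = compose(double, keep_multiples_of_four)
print(list(pipeline(range(6))))  # [0, 4, 8]
```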
12 changes: 6 additions & 6 deletions README.md
@@ -125,15 +125,15 @@ that open a pull request and we will review it.
| persist | :x: | |
| pipe | :x: | |
| randomSplit | :x: | |
- | reduce | :x: | |
+ | reduce | :white_check_mark: | |
| reduceByKey | :x: | |
| repartition | :x: | |
| repartitionAndSortWithinPartition | :x: | |
| rightOuterJoin | :x: | |
| sample | :x: | |
| sampleByKey | :x: | |
- | sampleStdev | :x: | |
- | sampleVariance | :x: | |
+ | sampleStdev | :white_check_mark: | |
+ | sampleVariance | :white_check_mark: | |
| saveAsHadoopDataset | :x: | |
| saveAsHadoopFile | :x: | |
| saveAsNewAPIHadoopDataset | :x: | |
@@ -143,8 +143,8 @@ that open a pull request and we will review it.
| setName | :x: | |
| sortBy | :x: | |
| sortByKey | :x: | |
- | stats | :x: | |
- | stdev | :x: | |
+ | stats | :white_check_mark: | |
+ | stdev | :white_check_mark: | |
| subtract | :x: | |
| subtractByKey | :x: | |
| sum | :white_check_mark: | First version. |
@@ -161,7 +161,7 @@ that open a pull request and we will review it.
| union | :x: | |
| unpersist | :x: | |
| values | :white_check_mark: | |
- | variance | :x: | |
+ | variance | :white_check_mark: | |
| withResources | :x: | |
| zip | :x: | |
| zipWithIndex | :x: | |
131 changes: 73 additions & 58 deletions congruity/rdd_adapter.py
@@ -30,6 +30,9 @@
Generic,
)

from pyspark.serializers import CloudPickleSerializer
from pyspark.statcounter import StatCounter

T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
U = TypeVar("U")
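The two imports added above back the newly delegated statistics methods: pyspark's `RDD.stats()` builds a `StatCounter` per partition and merges the results, so the adapter only needs working `mapPartitions` and `reduce` to inherit it (`CloudPickleSerializer` is presumably used elsewhere in the adapter; it is not exercised in the hunks shown here). A quick standalone illustration of `StatCounter`, matching the values asserted in the new tests:

```python
from pyspark.statcounter import StatCounter

# Aggregate the numbers 0..9 the same way RDD.stats() does per partition.
s = StatCounter(range(10))
print(s.mean(), s.variance(), s.sampleVariance())
# 4.5 8.25 9.166666666666666
```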
@@ -236,75 +239,56 @@ def take(self: "RDDAdapter[T]", num: int) -> List[T]:
def map(
self: "RDDAdapter[T]", f: Callable[[T], U], preservePartitioning=None
) -> "RDDAdapter[U]":
needs_conversion = self._first_field
schema = RDDAdapter.PA_SCHEMA

def mapper(iter: Iterable[RecordBatch]):
for b in iter:
result = []
rows = b.to_pylist()
for r in rows:
if needs_conversion:
val = loads(r["__bin_field__"])
else:
val = Row(**r)
result.append({"__bin_field__": dumps(f(val))})
yield RecordBatch.from_pylist(result, schema=schema)
def func(iterator: Iterable[T]) -> Iterable[U]:
return map(fail_on_stopiteration(f), iterator)

result = self._df.mapInArrow(mapper, RDDAdapter.BIN_SCHEMA)
assert len(result.schema.fields) == 1
return RDDAdapter(result, True)
# This differs from the regular map implementation because we don't have
# access to mapPartitionsWithIndex.
return self.mapPartitions(func, preservePartitioning)

map.__doc__ = RDD.map.__doc__
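A note on the wrapper used in `map` above: `fail_on_stopiteration` is pyspark's helper that guards against user functions raising `StopIteration`, which would otherwise silently end the partition iterator instead of failing the task. A rough sketch of what it does:

```python
def fail_on_stopiteration_sketch(f):
    # Rough equivalent of pyspark's fail_on_stopiteration: turn a StopIteration
    # escaping the user function into a hard error so it cannot be swallowed by
    # the surrounding generator machinery.
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError("user code raised StopIteration") from exc

    return wrapper
```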

def count(self) -> int:
return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()

count = RDD.count
count.__doc__ = RDD.count.__doc__

def sum(self: "RDDAdapter") -> int:
return self.mapPartitions(lambda x: [sum(x)]).fold( # type: ignore[return-value]
0, operator.add
)

sum = RDD.sum
sum.__doc__ = RDD.sum.__doc__

def fold(self: "RDDAdapter[T]", zeroValue: T, op: Callable[[T, T], T]) -> T:
op = fail_on_stopiteration(op)

def func(iterator: Iterable[T]) -> Iterable[T]:
acc = zeroValue
for obj in iterator:
acc = op(acc, obj)
yield acc
fold = RDD.fold
fold.__doc__ = RDD.fold.__doc__

vals = self.mapPartitions(func).collect()
return reduce(op, vals, zeroValue)
keys = RDD.keys
keys.__doc__ = RDD.keys.__doc__

fold.__doc__ = RDD.fold.__doc__
values = RDD.values
values.__doc__ = RDD.values.__doc__

def keys(self: "RDDAdapter[Tuple[K, V]]") -> "RDDAdapter[K]":
return self.map(lambda x: x[0])
glom = RDD.glom
glom.__doc__ = RDD.glom.__doc__

keys.__doc__ = RDD.keys.__doc__
keyBy = RDD.keyBy
keyBy.__doc__ = RDD.keyBy.__doc__

def values(self: "RDDAdapter[Tuple[K, V]]") -> "RDDAdapter[V]":
return self.map(lambda x: x[1])
reduce = RDD.reduce
reduce.__doc__ = RDD.reduce.__doc__

values.__doc__ = RDD.values.__doc__
stats = RDD.stats
stats.__doc__ = RDD.stats.__doc__

def glom(self: "RDDAdapter[T]") -> "RDDAdapter[List[T]]":
def func(iterator: Iterable[T]) -> Iterable[List[T]]:
yield list(iterator)
stdev = RDD.stdev
stdev.__doc__ = RDD.stdev.__doc__

return self.mapPartitions(func)
sampleStdev = RDD.sampleStdev
sampleStdev.__doc__ = RDD.sampleStdev.__doc__

glom.__doc__ = RDD.glom.__doc__
sampleVariance = RDD.sampleVariance
sampleVariance.__doc__ = RDD.sampleVariance.__doc__

def keyBy(self: "RDDAdapter[T]", f: Callable[[T], K]) -> "RDDAdapter[Tuple[K, T]]":
return self.map(lambda x: (f(x), x))
variance = RDD.variance
variance.__doc__ = RDD.variance.__doc__

keyBy.__doc__ = RDD.keyBy.__doc__
aggregate = RDD.aggregate
aggregate.__doc__ = RDD.aggregate.__doc__
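The block above replaces hand-written implementations with methods borrowed directly from `pyspark.rdd.RDD`. This works because those base implementations are written purely in terms of primitives the adapter now provides (`mapPartitions`, `collect`, `fold`, `reduce`, ...). A tiny sketch of the borrowing pattern, independent of Spark:

```python
class Base:
    def total(self):
        # Written only against self.parts(); knows nothing about the subclass.
        return sum(sum(p) for p in self.parts())


class Adapter:
    def parts(self):
        return [[1, 2], [3, 4]]

    # Reuse the generic implementation, analogous to `count = RDD.count`.
    total = Base.total


print(Adapter().total())  # 10
```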

class WrappedIterator(Iterable):
"""This is a helper class that wraps the iterator of RecordBatches as returned by
@@ -336,8 +320,45 @@ def __iter__(self):
def mapPartitions(
self, f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning=False
) -> "RDDAdapter[U]":
# Every pipeline becomes mapPartitions in the end. So we pass the current RDD as the
# previous reference and the next transformation function.
return Pipeline(self, f)

mapPartitions.__doc__ = RDD.mapPartitions.__doc__


class Pipeline(RDDAdapter):

def __init__(self, input: "RDDAdapter", f):

if isinstance(input, Pipeline):
source = input._prev_source
self._prev_source = source

prev_fun = input._prev_fun
next_fun = lambda ite: f(prev_fun(ite))

self._prev_fun = next_fun
self._prev_first_field = input._prev_first_field
first_field = input._prev_first_field
else:
# Cache the input DF and functions before mapping it.
next_fun = f
self._prev_source = input._df
self._prev_fun = next_fun
self._prev_first_field = input._first_field
first_field = input._first_field
source = input._df

mapper = self._build_mapper(next_fun, first_field)
# These are the outputs of the composed operations; they are only evaluated
# when a terminal operation is called.
self._df = source.mapInArrow(mapper, RDDAdapter.BIN_SCHEMA)
self._first_field = True

def _build_mapper(self, f, needs_conversion):
# Fixed constants for the mapPartitions implementation.
schema = RDDAdapter.PA_SCHEMA
needs_conversion = self._first_field
max_rows_per_batch = 1000

def mapper(iter: Iterable[RecordBatch]):
@@ -355,10 +376,4 @@ def mapper(iter: Iterable[RecordBatch]):
if len(result) > 0:
yield RecordBatch.from_pylist(result, schema=schema)

# MapInArrow is effectively mapPartitions, but streams the rows as batches to the RDD,
# we leverage this fact here and build wrappers for that.
result = self._df.mapInArrow(mapper, RDDAdapter.BIN_SCHEMA)
assert len(result.schema.fields) == 1
return RDDAdapter(result, True)

mapPartitions.__doc__ = RDD.mapPartitions.__doc__
return mapper
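The mapper produced by `_build_mapper` streams Arrow record batches: it decodes each row (unpickling the single binary column, or building a `Row` when the source is a plain DataFrame), runs the composed partition function, and pickles the results back into batches, capping them at `max_rows_per_batch` rows. A standalone sketch of that round-trip for the pickled path, assuming pyarrow is installed; it mirrors, but is not, the adapter's exact code:

```python
from pickle import dumps, loads

import pyarrow as pa

schema = pa.schema([("__bin_field__", pa.binary())])


def apply_to_batch(batch: pa.RecordBatch, f) -> pa.RecordBatch:
    # Unpickle every row, apply the composed partition function, re-pickle.
    values = [loads(row["__bin_field__"]) for row in batch.to_pylist()]
    out = [{"__bin_field__": dumps(v)} for v in f(iter(values))]
    return pa.RecordBatch.from_pylist(out, schema=schema)


batch = pa.RecordBatch.from_pylist(
    [{"__bin_field__": dumps(i)} for i in range(3)], schema=schema
)
result = apply_to_batch(batch, lambda it: (x + 1 for x in it))
print([loads(row["__bin_field__"]) for row in result.to_pylist()])  # [1, 2, 3]
```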
54 changes: 54 additions & 0 deletions tests/test_rdd_adapter.py
@@ -174,3 +174,57 @@ def test_rdd_key_by(spark_session: "SparkSession"):
(16, 8),
(18, 9),
]


def test_rdd_stats(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize(range(10))
stats = rdd.stats()
assert stats.count() == 10
assert stats.mean() == 4.5
assert stats.sum() == 45
assert stats.min() == 0
assert stats.max() == 9
assert stats.stdev() == 2.8722813232690143
assert stats.variance() == 8.25
assert stats.sampleStdev() == 3.0276503540974917
assert stats.sampleVariance() == 9.166666666666666


def test_rdd_stddev(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize(range(10))
assert rdd.stdev() == 2.8722813232690143


def test_rdd_sample_stddev(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize(range(10))
assert rdd.sampleStdev() == 3.0276503540974917


def test_rdd_sample_variance(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize(range(10))
assert rdd.sampleVariance() == 9.166666666666666


def test_rdd_variance(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize(range(10))
assert rdd.variance() == 8.25


def test_rdd_aggregate(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize(range(10))
assert rdd.aggregate(0, lambda x, y: x + y, lambda x, y: x + y) == 45

seqOp = lambda x, y: (x[0] + y, x[1] + 1)
combOp = lambda x, y: (x[0] + y[0], x[1] + y[1])
res = spark_session.sparkContext.parallelize([1, 2, 3, 4]).aggregate((0, 0), seqOp, combOp)
assert res == (10, 4)

# TODO empty
# res = spark_session.sparkContext.parallelize([]).aggregate((0, 0), seqOp, combOp)
# assert res == (0, 0)
