adding keys() and values()
grundprinzip committed May 17, 2024
1 parent 103913e commit 9e53506
Showing 4 changed files with 28 additions and 2 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -106,7 +106,7 @@ that open a pull request and we will review it.
| isEmpty | :x: | |
| isLocallyCheckpointed | :x: | |
| join | :x: | |
- | keys | :x: | |
+ | keys | :white_check_mark: | |
| leftOuterJoin | :x: | |
| localCheckpoint | :x: | |
| lookup | :x: | |
@@ -159,7 +159,7 @@ that open a pull request and we will review it.
| treeReduce | :x: | |
| union | :x: | |
| unpersist | :x: | |
- | values | :x: | |
+ | values | :white_check_mark: | |
| variance | :x: | |
| withResources | :x: | |
| zip | :x: | |
10 changes: 10 additions & 0 deletions congruity/rdd_adapter.py
@@ -274,6 +274,16 @@ def func(iterator: Iterable[T]) -> Iterable[T]:

    fold.__doc__ = RDD.fold.__doc__
+
+    def keys(self) -> "RDDAdapter":
+        return self.map(lambda x: x[0])
+
+    keys.__doc__ = RDD.keys.__doc__
+
+    def values(self) -> "RDDAdapter":
+        return self.map(lambda x: x[1])
+
+    values.__doc__ = RDD.values.__doc__

    class WrappedIterator(Iterable):
        """This is a helper class that wraps the iterator of RecordBatches as returned by
        mapInArrow and converts it into an iterator of the underlying values."""
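Both new methods are thin wrappers around map, which mirrors how PySpark's own RDD.keys and RDD.values are defined upstream. A minimal usage sketch follows; it assumes monkey_patch_spark is exported from the congruity package (as the tests below suggest) and a Spark Connect server reachable at sc://localhost:

from congruity import monkey_patch_spark  # assumed import path
from pyspark.sql import SparkSession

# Patch the Spark Connect session so RDD-style calls are routed to RDDAdapter.
monkey_patch_spark()

# sc://localhost is a placeholder endpoint; point it at your Connect server.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
pairs = spark.sparkContext.parallelize([("a", 1), ("b", 2)])

print(pairs.keys().collect())    # ['a', 'b'] -- first element of each pair
print(pairs.values().collect())  # [1, 2]     -- second element of each pair

Because both are defined via map, they stay lazy: no job runs until an action such as collect() is invoked.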
4 changes: 4 additions & 0 deletions tests/test_public_interface.py
@@ -34,3 +34,7 @@ def test_spark_context_parallelize(spark_session: "SparkSession"):
        .collect()
    )
    assert result == [20000, 100000, 3000]
+
+    val = spark_session.sparkContext.parallelize(range(0, 5))
+    assert val.count() == 5
+    assert val.collect() == [0, 1, 2, 3, 4]
12 changes: 12 additions & 0 deletions tests/test_rdd_adapter.py
@@ -130,3 +130,15 @@ def test_rdd_sum(spark_session: "SparkSession"):

    vals = df.rdd.map(lambda x: x[0]).sum()
    assert vals == 45
+
+
+def test_rdd_keys(spark_session: "SparkSession"):
+    monkey_patch_spark()
+    rdd = spark_session.sparkContext.parallelize([(1, 2), (3, 4)]).keys()
+    assert rdd.collect() == [1, 3]
+
+
+def test_rdd_values(spark_session: "SparkSession"):
+    monkey_patch_spark()
+    rdd = spark_session.sparkContext.parallelize([(1, 2), (3, 4)]).values()
+    assert rdd.collect() == [2, 4]