diff --git a/pysparkling/rdd.py b/pysparkling/rdd.py index 56cd3dc09..5eb1e8153 100644 --- a/pysparkling/rdd.py +++ b/pysparkling/rdd.py @@ -656,6 +656,39 @@ def pipe(self, command, env={}): [command]+x if isinstance(x, list) else [command, x] ) for x in self.collect()) + def randomSplit(self, weights, seed=None): + """ + Split the RDD into a few RDDs according to the given weights. + + .. note:: + Creating the new RDDs is currently implemented as a local + operation. + + :param weights: + Determines the relative lengths of the resulting RDDs. + + :param seed: + Seed for random number generator. + + :returns: + A list of RDDs. + + """ + sum_weights = sum(weights) + boundaries = [0] + for w in weights: + boundaries.append(boundaries[-1] + w/sum_weights) + random.seed(seed) + + lists = [[] for _ in weights] + for e in self.toLocalIterator(): + r = random.random() + for i, lbub in enumerate(zip(boundaries[:-1], boundaries[1:])): + if r >= lbub[0] and r < lbub[1]: + lists[i].append(e) + + return [self.context.parallelize(l) for l in lists] + def reduce(self, f): """ :param f: diff --git a/tests/test_rdd_unit.py b/tests/test_rdd_unit.py index d7b0b103d..20c403a67 100644 --- a/tests/test_rdd_unit.py +++ b/tests/test_rdd_unit.py @@ -223,6 +223,13 @@ def test_pipe(): assert b'hello\n' in piped +def test_randomSplit(): + rdd = Context().parallelize(range(500)) + rdd1, rdd2 = rdd.randomSplit([2, 3], seed=42) + print(rdd1.count(), rdd2.count()) + assert rdd1.count() == 199 and rdd2.count() == 301 + + def test_reduce(): rdd = Context().parallelize([0, 4, 7, 4, 10]) assert rdd.reduce(lambda a, b: a+b) == 25 @@ -334,4 +341,4 @@ def test_zipWithUniqueIndex(): if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG) - test_first_empty_partitions() + test_randomSplit()