From 0248b69a735b68ce64cbd3641919c02eb2110986 Mon Sep 17 00:00:00 2001 From: industrial-sloth Date: Mon, 2 Feb 2015 15:12:47 -0500 Subject: [PATCH 1/9] add get and getAll methods on Data --- python/test/test_images.py | 24 ++++++++++++++++++++++- python/test/test_series.py | 29 +++++++++++++++++++++++++++- python/thunder/rdds/data.py | 38 +++++++++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 2 deletions(-) diff --git a/python/test/test_images.py b/python/test/test_images.py index c5dcf8be..959fa564 100644 --- a/python/test/test_images.py +++ b/python/test/test_images.py @@ -4,7 +4,7 @@ from numpy import allclose, arange, array, array_equal, prod, squeeze, zeros from numpy import dtype as dtypeFunc import itertools -from nose.tools import assert_equals, assert_raises, assert_true +from nose.tools import assert_equals, assert_is_none, assert_raises, assert_true import unittest from thunder.rdds.fileio.imagesloader import ImagesLoader @@ -485,6 +485,28 @@ def test_min(self): assert_true(array_equal(reduce(minimum, arys), minVal)) +class TestImagesGetters(PySparkTestCase): + def setUp(self): + super(TestImagesGetters, self).setUp() + self.ary1 = array([[1, 2], [3, 4]], dtype='int16') + self.ary2 = array([[5, 6], [7, 8]], dtype='int16') + self.images = ImagesLoader(self.sc).fromArrays([self.ary1, self.ary2]) + + def test_getMissing(self): + assert_is_none(self.images.get(-1)) + + def test_get(self): + assert_true(array_equal(self.ary2, self.images.get(1))) + + def test_getAll(self): + vals = self.images.getAll([0, -1, 1, 0]) + assert_equals(4, len(vals)) + assert_true(array_equal(self.ary1, vals[0])) + assert_is_none(vals[1]) + assert_true(array_equal(self.ary2, vals[2])) + assert_true(array_equal(self.ary1, vals[3])) + + class TestImagesUsingOutputDir(PySparkTestCaseWithOutputDir): @staticmethod diff --git a/python/test/test_series.py b/python/test/test_series.py index 2c01ee23..2bb6d666 100644 --- a/python/test/test_series.py +++ 
b/python/test/test_series.py @@ -1,6 +1,6 @@ from numpy import allclose, amax, arange, array, array_equal from numpy import dtype as dtypeFunc -from nose.tools import assert_equals, assert_true +from nose.tools import assert_equals, assert_is_none, assert_true from thunder.rdds.series import Series from test_utils import * @@ -287,3 +287,30 @@ def test_maxProject(self): assert_true(array_equal(amax(ary.T, 0), project0)) assert_true(array_equal(amax(ary.T, 1), project1)) + + +class TestSeriesGetters(PySparkTestCase): + def setUp(self): + super(TestSeriesGetters, self).setUp() + self.dataLocal = [ + ((0, 0), array([1.0, 2.0, 3.0], dtype='float32')), + ((0, 1), array([2.0, 2.0, 4.0], dtype='float32')), + ((1, 0), array([4.0, 2.0, 1.0], dtype='float32')), + ((1, 1), array([3.0, 1.0, 1.0], dtype='float32')) + ] + self.series = Series(self.sc.parallelize(self.dataLocal), dtype='float32', dims=(2, 2), index=[0, 1, 2]) + + def test_getMissing(self): + assert_is_none(self.series.get(-1)) + + def test_get(self): + expected = self.dataLocal[1][1] + assert_true(array_equal(expected, self.series.get((0, 1)))) + + def test_getAll(self): + vals = self.series.getAll([(0, 0), (17, 256), (1, 0), (0, 0)]) + assert_equals(4, len(vals)) + assert_true(array_equal(self.dataLocal[0][1], vals[0])) + assert_is_none(vals[1]) + assert_true(array_equal(self.dataLocal[2][1], vals[2])) + assert_true(array_equal(self.dataLocal[0][1], vals[3])) \ No newline at end of file diff --git a/python/thunder/rdds/data.py b/python/thunder/rdds/data.py index ae7da0c3..ee105735 100644 --- a/python/thunder/rdds/data.py +++ b/python/thunder/rdds/data.py @@ -111,6 +111,44 @@ def take(self, *args, **kwargs): """ return self.rdd.take(*args, **kwargs) + def get(self, key): + """Returns a single value matching the passed key, or None if no matching keys found + + If multiple records are found with keys matching the passed key, a sequence of all matching + values will be returned. 
(This is not expected as a normal occurrence, but could happen with + some user-created rdds.) + """ + filteredVals = self.rdd.filter(lambda (k, v): k == key).values().collect() + if len(filteredVals) == 1: + return filteredVals[0] + elif not filteredVals: + return None + else: + return filteredVals + + def getAll(self, keys): + """Returns a sequence of values corresponding to the passed sequence of keys. + + The return value will be a sequence equal in length to the passed keys, with each + value in the returned sequence corresponding to the key at the same position in the passed + keys sequence. If no value is found for a given key, the corresponding sequence element will be None. + If multiple values are found, the corresponding sequence element will be a sequence containing all + matching values. + """ + keySet = frozenset(keys) + filteredRecs = self.rdd.filter(lambda (k, _): k in keySet).collect() + sortingDict = {} + for k, v in filteredRecs: + sortingDict.setdefault(k, []).append(v) + retVals = [] + for k in keys: + vals = sortingDict.get(k) + if vals is not None: + if len(vals) == 1: + vals = vals[0] + retVals.append(vals) + return retVals + def values(self): """ Return values, ignoring keys From bc44aa9e1344e8aba9d5ca0c3cd639ce9d44c7f2 Mon Sep 17 00:00:00 2001 From: industrial-sloth Date: Tue, 3 Feb 2015 12:24:39 -0500 Subject: [PATCH 2/9] add getRange method to Data --- python/test/test_images.py | 29 +++++++++++++++++++ python/test/test_series.py | 47 ++++++++++++++++++++++++++++-- python/thunder/rdds/data.py | 58 +++++++++++++++++++++++++++++++++++++ 3 files changed, 132 insertions(+), 2 deletions(-) diff --git a/python/test/test_images.py b/python/test/test_images.py index 959fa564..c10d2fb5 100644 --- a/python/test/test_images.py +++ b/python/test/test_images.py @@ -506,6 +506,35 @@ def test_getAll(self): assert_true(array_equal(self.ary2, vals[2])) assert_true(array_equal(self.ary1, vals[3])) + def test_getRanges(self): + vals = 
self.images.getRange(slice(None)) + assert_equals(2, len(vals)) + assert_equals(0, vals[0][0]) + assert_equals(1, vals[1][0]) + assert_true(array_equal(self.ary1, vals[0][1])) + assert_true(array_equal(self.ary2, vals[1][1])) + + vals = self.images.getRange(slice(0, 1)) + assert_equals(1, len(vals)) + assert_equals(0, vals[0][0]) + assert_true(array_equal(self.ary1, vals[0][1])) + + vals = self.images.getRange(slice(1)) + assert_equals(1, len(vals)) + assert_equals(0, vals[0][0]) + assert_true(array_equal(self.ary1, vals[0][1])) + + vals = self.images.getRange(slice(1, 2)) + assert_equals(1, len(vals)) + assert_equals(1, vals[0][0]) + assert_true(array_equal(self.ary2, vals[0][1])) + + vals = self.images.getRange(slice(2, 3)) + assert_equals(0, len(vals)) + + # raise exception if 'step' specified: + assert_raises(ValueError, self.images.getRange, slice(1, 2, 2)) + class TestImagesUsingOutputDir(PySparkTestCaseWithOutputDir): diff --git a/python/test/test_series.py b/python/test/test_series.py index 2bb6d666..0cb7f09a 100644 --- a/python/test/test_series.py +++ b/python/test/test_series.py @@ -1,6 +1,6 @@ from numpy import allclose, amax, arange, array, array_equal from numpy import dtype as dtypeFunc -from nose.tools import assert_equals, assert_is_none, assert_true +from nose.tools import assert_equals, assert_is_none, assert_raises, assert_true from thunder.rdds.series import Series from test_utils import * @@ -313,4 +313,47 @@ def test_getAll(self): assert_true(array_equal(self.dataLocal[0][1], vals[0])) assert_is_none(vals[1]) assert_true(array_equal(self.dataLocal[2][1], vals[2])) - assert_true(array_equal(self.dataLocal[0][1], vals[3])) \ No newline at end of file + assert_true(array_equal(self.dataLocal[0][1], vals[3])) + + def test_getRanges(self): + vals = self.series.getRange([slice(2), slice(2)]) + assert_equals(4, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + assert_equals(self.dataLocal[1][0], vals[1][0]) + 
assert_equals(self.dataLocal[2][0], vals[2][0]) + assert_equals(self.dataLocal[3][0], vals[3][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) + assert_true(array_equal(self.dataLocal[2][1], vals[2][1])) + assert_true(array_equal(self.dataLocal[3][1], vals[3][1])) + + vals = self.series.getRange([slice(2), slice(1)]) + assert_equals(2, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + assert_equals(self.dataLocal[2][0], vals[1][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[2][1], vals[1][1])) + + vals = self.series.getRange([slice(None), slice(1, 2)]) + assert_equals(2, len(vals)) + assert_equals(self.dataLocal[1][0], vals[0][0]) + assert_equals(self.dataLocal[3][0], vals[1][0]) + assert_true(array_equal(self.dataLocal[1][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[3][1], vals[1][1])) + + vals = self.series.getRange([slice(None), slice(None)]) + assert_equals(4, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + assert_equals(self.dataLocal[1][0], vals[1][0]) + assert_equals(self.dataLocal[2][0], vals[2][0]) + assert_equals(self.dataLocal[3][0], vals[3][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) + assert_true(array_equal(self.dataLocal[2][1], vals[2][1])) + assert_true(array_equal(self.dataLocal[3][1], vals[3][1])) + + vals = self.series.getRange([slice(2, 3), slice(None)]) + assert_equals(0, len(vals)) + + # raise exception if 'step' specified: + assert_raises(ValueError, self.series.getRange, [slice(0, 4, 2), slice(2, 3)]) \ No newline at end of file diff --git a/python/thunder/rdds/data.py b/python/thunder/rdds/data.py index ee105735..8194b2c9 100644 --- a/python/thunder/rdds/data.py +++ b/python/thunder/rdds/data.py @@ -149,6 +149,64 @@ def getAll(self, keys): retVals.append(vals) return retVals + def 
getRange(self, sliceOrSlices): + """Returns key/value pairs that fall within a range given by the passed slice or slices. + + The return values will be a sorted list of key/value pairs of all records in the underlying + RDD for which the key falls within the range given by the passed slice selectors. Note that + this may be very large, and could potentially exhaust the available memory on the driver. + + The cardinality of the passed slice or sequence of slices must match that of the keys of + this RDD's records. For singleton keys, a single slice (or slice sequence of length one) + should be passed. For tuple keys, a sequence of multiple slices (as many as the cardinality + of the keys) should be passed. + + Passed slices should not have a `step` attribute defined; this is not supported and a + ValueError will be raised if a step attribute is passed. + + Parameters + ---------- + sliceOrSlices: slice object or sequence of slices + The passed slice or slices should be of the same cardinality as the keys of the underlying rdd. + + Returns + ------- + sorted sequence of key/value pairs + """ + # None is less than everything except itself + def singleSlicePredicate(kv): + key, _ = kv + if sliceOrSlices.stop is None: + return key >= sliceOrSlices.start + return sliceOrSlices.stop > key >= sliceOrSlices.start + + def multiSlicesPredicate(kv): + key, _ = kv + for slise, subkey in zip(sliceOrSlices, key): + if slise.stop is None: + if subkey < slise.start: + return False + elif not (slise.stop > subkey >= slise.start): + return False + return True + + if not hasattr(sliceOrSlices, '__len__'): + # make my func the... 
+ pFunc = singleSlicePredicate + if sliceOrSlices.step is not None: + raise ValueError("'step' slice attribute is not supported in getRange, got step: %d" % + sliceOrSlices.step) + else: + pFunc = multiSlicesPredicate + for slise in sliceOrSlices: + if slise.step is not None: + raise ValueError("'step' slice attribute is not supported in getRange, got step: %d" % + slise.step) + + filteredRecs = self.rdd.filter(pFunc).collect() + # default sort of tuples is by first item, which happens to be what we want + return sorted(filteredRecs) + def values(self): """ Return values, ignoring keys From 2f82156334652ae665e6743264306ce123ccae01 Mon Sep 17 00:00:00 2001 From: industrial-sloth Date: Tue, 3 Feb 2015 12:53:44 -0500 Subject: [PATCH 3/9] add brackets getter on Data --- python/test/test_images.py | 35 +++++++++++++++++++++++++++++++++++ python/test/test_series.py | 35 ++++++++++++++++++++++++++++++++++- python/thunder/rdds/data.py | 13 +++++++++++++ 3 files changed, 82 insertions(+), 1 deletion(-) diff --git a/python/test/test_images.py b/python/test/test_images.py index c10d2fb5..e6c8adcc 100644 --- a/python/test/test_images.py +++ b/python/test/test_images.py @@ -535,6 +535,41 @@ def test_getRanges(self): # raise exception if 'step' specified: assert_raises(ValueError, self.images.getRange, slice(1, 2, 2)) + def test_brackets(self): + vals = self.images[1] + assert_true(array_equal(self.ary2, vals)) + + vals = self.images[0:1] + assert_equals(1, len(vals)) + assert_equals(0, vals[0][0]) + assert_true(array_equal(self.ary1, vals[0][1])) + + vals = self.images[:] + assert_equals(2, len(vals)) + assert_equals(0, vals[0][0]) + assert_equals(1, vals[1][0]) + assert_true(array_equal(self.ary1, vals[0][1])) + assert_true(array_equal(self.ary2, vals[1][1])) + + vals = self.images[1:4] + assert_equals(1, len(vals)) + assert_equals(1, vals[0][0]) + assert_true(array_equal(self.ary2, vals[0][1])) + + vals = self.images[1:] + assert_equals(1, len(vals)) + assert_equals(1, 
vals[0][0]) + assert_true(array_equal(self.ary2, vals[0][1])) + + vals = self.images[:1] + assert_equals(1, len(vals)) + assert_equals(0, vals[0][0]) + assert_true(array_equal(self.ary1, vals[0][1])) + + assert_raises(KeyError, self.images.__getitem__, 2) # equiv: self.images[2] + + assert_raises(IndexError, self.images.__getitem__, slice(2, 3)) # equiv: self.images[2:3] + class TestImagesUsingOutputDir(PySparkTestCaseWithOutputDir): diff --git a/python/test/test_series.py b/python/test/test_series.py index 0cb7f09a..7e532307 100644 --- a/python/test/test_series.py +++ b/python/test/test_series.py @@ -356,4 +356,37 @@ def test_getRanges(self): assert_equals(0, len(vals)) # raise exception if 'step' specified: - assert_raises(ValueError, self.series.getRange, [slice(0, 4, 2), slice(2, 3)]) \ No newline at end of file + assert_raises(ValueError, self.series.getRange, [slice(0, 4, 2), slice(2, 3)]) + + def test_brackets(self): + vals = self.series[(1, 0)] + assert_true(array_equal(self.dataLocal[2][1], vals)) + + vals = self.series[:4, :1] + assert_equals(2, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + assert_equals(self.dataLocal[2][0], vals[1][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[2][1], vals[1][1])) + + vals = self.series[:, 1:2] + assert_equals(2, len(vals)) + assert_equals(self.dataLocal[1][0], vals[0][0]) + assert_equals(self.dataLocal[3][0], vals[1][0]) + assert_true(array_equal(self.dataLocal[1][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[3][1], vals[1][1])) + + vals = self.series[:, :] + assert_equals(4, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + assert_equals(self.dataLocal[1][0], vals[1][0]) + assert_equals(self.dataLocal[2][0], vals[2][0]) + assert_equals(self.dataLocal[3][0], vals[3][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) + 
assert_true(array_equal(self.dataLocal[2][1], vals[2][1])) + assert_true(array_equal(self.dataLocal[3][1], vals[3][1])) + + assert_raises(KeyError, self.series.__getitem__, (25, 17)) # equiv: self.series[(25, 17)] + + assert_raises(IndexError, self.series.__getitem__, [slice(2, 3), slice(None)]) # series[2:3,:] \ No newline at end of file diff --git a/python/thunder/rdds/data.py b/python/thunder/rdds/data.py index 8194b2c9..96d4f49d 100644 --- a/python/thunder/rdds/data.py +++ b/python/thunder/rdds/data.py @@ -207,6 +207,19 @@ def multiSlicesPredicate(kv): # default sort of tuples is by first item, which happens to be what we want return sorted(filteredRecs) + def __getitem__(self, item): + # should raise exception here when no matching items found + # see object.__getitem__ in https://docs.python.org/2/reference/datamodel.html + if isinstance(item, slice) or (hasattr(item, "__len__") and isinstance(item[0], slice)): + retVals = self.getRange(item) + if not retVals: + raise IndexError("No keys found for slice(s): '%s'" % str(item)) + else: + retVals = self.get(item) + if retVals is None: + raise KeyError("No key found matching '%s'" % str(item)) + return retVals + def values(self): """ Return values, ignoring keys From 00dbadc693728201fc133fc68493ae0786c69447 Mon Sep 17 00:00:00 2001 From: industrial-sloth Date: Thu, 5 Feb 2015 14:12:47 -0500 Subject: [PATCH 4/9] support mix of slices and indices in bracket get and getRange --- python/test/test_series.py | 39 +++++++++++++++++++++++++++++++++++++ python/thunder/rdds/data.py | 34 ++++++++++++++++++++++---------- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/python/test/test_series.py b/python/test/test_series.py index 7e532307..f442f5b8 100644 --- a/python/test/test_series.py +++ b/python/test/test_series.py @@ -352,6 +352,18 @@ def test_getRanges(self): assert_true(array_equal(self.dataLocal[2][1], vals[2][1])) assert_true(array_equal(self.dataLocal[3][1], vals[3][1])) + vals = 
self.series.getRange([0, slice(None)]) + assert_equals(2, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + assert_equals(self.dataLocal[1][0], vals[1][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) + + vals = self.series.getRange([0, 1]) + assert_equals(1, len(vals)) + assert_equals(self.dataLocal[1][0], vals[0][0]) + assert_true(array_equal(self.dataLocal[1][1], vals[0][1])) + vals = self.series.getRange([slice(2, 3), slice(None)]) assert_equals(0, len(vals)) @@ -359,9 +371,21 @@ def test_getRanges(self): assert_raises(ValueError, self.series.getRange, [slice(0, 4, 2), slice(2, 3)]) def test_brackets(self): + # returns just value; calls `get` vals = self.series[(1, 0)] assert_true(array_equal(self.dataLocal[2][1], vals)) + # tuple isn't needed; returns just value, calls `get` + vals = self.series[0, 1] + assert_true(array_equal(self.dataLocal[1][1], vals)) + + # if slices are passed, calls `getRange`, returns keys and values + vals = self.series[0:1, 1:2] + assert_equals(1, len(vals)) + assert_equals(self.dataLocal[1][0], vals[0][0]) + assert_true(array_equal(self.dataLocal[1][1], vals[0][1])) + + # if slice extends out of bounds, return only the elements that are in bounds vals = self.series[:4, :1] assert_equals(2, len(vals)) assert_equals(self.dataLocal[0][0], vals[0][0]) @@ -369,6 +393,7 @@ def test_brackets(self): assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) assert_true(array_equal(self.dataLocal[2][1], vals[1][1])) + # empty slice works vals = self.series[:, 1:2] assert_equals(2, len(vals)) assert_equals(self.dataLocal[1][0], vals[0][0]) @@ -376,6 +401,7 @@ def test_brackets(self): assert_true(array_equal(self.dataLocal[1][1], vals[0][1])) assert_true(array_equal(self.dataLocal[3][1], vals[1][1])) + # multiple empty slices work vals = self.series[:, :] assert_equals(4, len(vals)) assert_equals(self.dataLocal[0][0], vals[0][0]) @@ -387,6 +413,19 @@ def 
test_brackets(self): assert_true(array_equal(self.dataLocal[2][1], vals[2][1])) assert_true(array_equal(self.dataLocal[3][1], vals[3][1])) + # mixing slices and individual indicies works: + vals = self.series[0, :] + assert_equals(2, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + assert_equals(self.dataLocal[1][0], vals[1][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) + + # trying to getitem a key that doesn't exist raises KeyError + # this differs from `get` behavior but is consistent with python dict + # see object.__getitem__ in https://docs.python.org/2/reference/datamodel.html assert_raises(KeyError, self.series.__getitem__, (25, 17)) # equiv: self.series[(25, 17)] + # passing a range that is completely out of bounds throws IndexError + # note that if a range is only partly out of bounds, it will return what elements the slice does include assert_raises(IndexError, self.series.__getitem__, [slice(2, 3), slice(None)]) # series[2:3,:] \ No newline at end of file diff --git a/python/thunder/rdds/data.py b/python/thunder/rdds/data.py index 96d4f49d..28465756 100644 --- a/python/thunder/rdds/data.py +++ b/python/thunder/rdds/data.py @@ -176,30 +176,37 @@ def getRange(self, sliceOrSlices): # None is less than everything except itself def singleSlicePredicate(kv): key, _ = kv - if sliceOrSlices.stop is None: - return key >= sliceOrSlices.start - return sliceOrSlices.stop > key >= sliceOrSlices.start + if isinstance(sliceOrSlices, slice): + if sliceOrSlices.stop is None: + return key >= sliceOrSlices.start + return sliceOrSlices.stop > key >= sliceOrSlices.start + else: # apparently this isn't a slice + return key == sliceOrSlices def multiSlicesPredicate(kv): key, _ = kv for slise, subkey in zip(sliceOrSlices, key): - if slise.stop is None: - if subkey < slise.start: + if isinstance(slise, slice): + if slise.stop is None: + if subkey < slise.start: + return False + elif 
not (slise.stop > subkey >= slise.start): + return False + else: # not a slice + if subkey != slise: return False - elif not (slise.stop > subkey >= slise.start): - return False return True if not hasattr(sliceOrSlices, '__len__'): # make my func the... pFunc = singleSlicePredicate - if sliceOrSlices.step is not None: + if hasattr(sliceOrSlices, 'step') and sliceOrSlices.step is not None: raise ValueError("'step' slice attribute is not supported in getRange, got step: %d" % sliceOrSlices.step) else: pFunc = multiSlicesPredicate for slise in sliceOrSlices: - if slise.step is not None: + if hasattr(slise, 'step') and slise.step is not None: raise ValueError("'step' slice attribute is not supported in getRange, got step: %d" % slise.step) @@ -210,7 +217,14 @@ def multiSlicesPredicate(kv): def __getitem__(self, item): # should raise exception here when no matching items found # see object.__getitem__ in https://docs.python.org/2/reference/datamodel.html - if isinstance(item, slice) or (hasattr(item, "__len__") and isinstance(item[0], slice)): + isRangeQuery = False + if isinstance(item, slice): + isRangeQuery = True + elif hasattr(item, '__iter__'): + if any([isinstance(slise, slice) for slise in item]): + isRangeQuery = True + + if isRangeQuery: retVals = self.getRange(item) if not retVals: raise IndexError("No keys found for slice(s): '%s'" % str(item)) From 61ea4b084499a1a03816008c6cfb6b9ecfb483fc Mon Sep 17 00:00:00 2001 From: industrial-sloth Date: Fri, 6 Feb 2015 18:09:55 -0500 Subject: [PATCH 5/9] explicate the pFunc Apparently there are those in the world who do not immediately recognize a George Clinton / Parliament reference. 
--- python/thunder/rdds/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/thunder/rdds/data.py b/python/thunder/rdds/data.py index 28465756..465e6586 100644 --- a/python/thunder/rdds/data.py +++ b/python/thunder/rdds/data.py @@ -198,7 +198,7 @@ def multiSlicesPredicate(kv): return True if not hasattr(sliceOrSlices, '__len__'): - # make my func the... + # make my func the pFunc; http://en.wikipedia.org/wiki/P._Funk_%28Wants_to_Get_Funked_Up%29 pFunc = singleSlicePredicate if hasattr(sliceOrSlices, 'step') and sliceOrSlices.step is not None: raise ValueError("'step' slice attribute is not supported in getRange, got step: %d" % From 5cce6441f51dcf023214d3be60b58253edb46cd6 Mon Sep 17 00:00:00 2001 From: industrial-sloth Date: Fri, 6 Feb 2015 18:17:47 -0500 Subject: [PATCH 6/9] move get method tests into new test_data.py file --- python/test/test_data.py | 238 +++++++++++++++++++++++++++++++++++++ python/test/test_images.py | 86 -------------- python/test/test_series.py | 142 ---------------------- 3 files changed, 238 insertions(+), 228 deletions(-) create mode 100644 python/test/test_data.py diff --git a/python/test/test_data.py b/python/test/test_data.py new file mode 100644 index 00000000..f72cbf55 --- /dev/null +++ b/python/test/test_data.py @@ -0,0 +1,238 @@ +from nose.tools import assert_equals, assert_is_none, assert_raises, assert_true +from numpy import array, array_equal + +from thunder.rdds.data import Data +from test_utils import PySparkTestCase + + +class TestImagesGetters(PySparkTestCase): + """Test `get` and related methods on an Images-like Data object + """ + def setUp(self): + super(TestImagesGetters, self).setUp() + self.ary1 = array([[1, 2], [3, 4]], dtype='int16') + self.ary2 = array([[5, 6], [7, 8]], dtype='int16') + rdd = self.sc.parallelize([(0, self.ary1), (1, self.ary2)]) + self.images = Data(rdd, dtype='int16') + + def test_getMissing(self): + assert_is_none(self.images.get(-1)) + + def test_get(self): + 
assert_true(array_equal(self.ary2, self.images.get(1))) + + def test_getAll(self): + vals = self.images.getAll([0, -1, 1, 0]) + assert_equals(4, len(vals)) + assert_true(array_equal(self.ary1, vals[0])) + assert_is_none(vals[1]) + assert_true(array_equal(self.ary2, vals[2])) + assert_true(array_equal(self.ary1, vals[3])) + + def test_getRanges(self): + vals = self.images.getRange(slice(None)) + assert_equals(2, len(vals)) + assert_equals(0, vals[0][0]) + assert_equals(1, vals[1][0]) + assert_true(array_equal(self.ary1, vals[0][1])) + assert_true(array_equal(self.ary2, vals[1][1])) + + vals = self.images.getRange(slice(0, 1)) + assert_equals(1, len(vals)) + assert_equals(0, vals[0][0]) + assert_true(array_equal(self.ary1, vals[0][1])) + + vals = self.images.getRange(slice(1)) + assert_equals(1, len(vals)) + assert_equals(0, vals[0][0]) + assert_true(array_equal(self.ary1, vals[0][1])) + + vals = self.images.getRange(slice(1, 2)) + assert_equals(1, len(vals)) + assert_equals(1, vals[0][0]) + assert_true(array_equal(self.ary2, vals[0][1])) + + vals = self.images.getRange(slice(2, 3)) + assert_equals(0, len(vals)) + + # raise exception if 'step' specified: + assert_raises(ValueError, self.images.getRange, slice(1, 2, 2)) + + def test_brackets(self): + vals = self.images[1] + assert_true(array_equal(self.ary2, vals)) + + vals = self.images[0:1] + assert_equals(1, len(vals)) + assert_equals(0, vals[0][0]) + assert_true(array_equal(self.ary1, vals[0][1])) + + vals = self.images[:] + assert_equals(2, len(vals)) + assert_equals(0, vals[0][0]) + assert_equals(1, vals[1][0]) + assert_true(array_equal(self.ary1, vals[0][1])) + assert_true(array_equal(self.ary2, vals[1][1])) + + vals = self.images[1:4] + assert_equals(1, len(vals)) + assert_equals(1, vals[0][0]) + assert_true(array_equal(self.ary2, vals[0][1])) + + vals = self.images[1:] + assert_equals(1, len(vals)) + assert_equals(1, vals[0][0]) + assert_true(array_equal(self.ary2, vals[0][1])) + + vals = self.images[:1] + 
assert_equals(1, len(vals)) + assert_equals(0, vals[0][0]) + assert_true(array_equal(self.ary1, vals[0][1])) + + assert_raises(KeyError, self.images.__getitem__, 2) # equiv: self.images[2] + + assert_raises(IndexError, self.images.__getitem__, slice(2, 3)) # equiv: self.images[2:3] + + +class TestSeriesGetters(PySparkTestCase): + """Test `get` and related methods on a Series-like Data object + """ + def setUp(self): + super(TestSeriesGetters, self).setUp() + self.dataLocal = [ + ((0, 0), array([1.0, 2.0, 3.0], dtype='float32')), + ((0, 1), array([2.0, 2.0, 4.0], dtype='float32')), + ((1, 0), array([4.0, 2.0, 1.0], dtype='float32')), + ((1, 1), array([3.0, 1.0, 1.0], dtype='float32')) + ] + self.series = Data(self.sc.parallelize(self.dataLocal), dtype='float32') + + def test_getMissing(self): + assert_is_none(self.series.get(-1)) + + def test_get(self): + expected = self.dataLocal[1][1] + assert_true(array_equal(expected, self.series.get((0, 1)))) + + def test_getAll(self): + vals = self.series.getAll([(0, 0), (17, 256), (1, 0), (0, 0)]) + assert_equals(4, len(vals)) + assert_true(array_equal(self.dataLocal[0][1], vals[0])) + assert_is_none(vals[1]) + assert_true(array_equal(self.dataLocal[2][1], vals[2])) + assert_true(array_equal(self.dataLocal[0][1], vals[3])) + + def test_getRanges(self): + vals = self.series.getRange([slice(2), slice(2)]) + assert_equals(4, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + assert_equals(self.dataLocal[1][0], vals[1][0]) + assert_equals(self.dataLocal[2][0], vals[2][0]) + assert_equals(self.dataLocal[3][0], vals[3][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) + assert_true(array_equal(self.dataLocal[2][1], vals[2][1])) + assert_true(array_equal(self.dataLocal[3][1], vals[3][1])) + + vals = self.series.getRange([slice(2), slice(1)]) + assert_equals(2, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + 
assert_equals(self.dataLocal[2][0], vals[1][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[2][1], vals[1][1])) + + vals = self.series.getRange([slice(None), slice(1, 2)]) + assert_equals(2, len(vals)) + assert_equals(self.dataLocal[1][0], vals[0][0]) + assert_equals(self.dataLocal[3][0], vals[1][0]) + assert_true(array_equal(self.dataLocal[1][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[3][1], vals[1][1])) + + vals = self.series.getRange([slice(None), slice(None)]) + assert_equals(4, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + assert_equals(self.dataLocal[1][0], vals[1][0]) + assert_equals(self.dataLocal[2][0], vals[2][0]) + assert_equals(self.dataLocal[3][0], vals[3][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) + assert_true(array_equal(self.dataLocal[2][1], vals[2][1])) + assert_true(array_equal(self.dataLocal[3][1], vals[3][1])) + + vals = self.series.getRange([0, slice(None)]) + assert_equals(2, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + assert_equals(self.dataLocal[1][0], vals[1][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) + + vals = self.series.getRange([0, 1]) + assert_equals(1, len(vals)) + assert_equals(self.dataLocal[1][0], vals[0][0]) + assert_true(array_equal(self.dataLocal[1][1], vals[0][1])) + + vals = self.series.getRange([slice(2, 3), slice(None)]) + assert_equals(0, len(vals)) + + # raise exception if 'step' specified: + assert_raises(ValueError, self.series.getRange, [slice(0, 4, 2), slice(2, 3)]) + + def test_brackets(self): + # returns just value; calls `get` + vals = self.series[(1, 0)] + assert_true(array_equal(self.dataLocal[2][1], vals)) + + # tuple isn't needed; returns just value, calls `get` + vals = self.series[0, 1] + assert_true(array_equal(self.dataLocal[1][1], 
vals)) + + # if slices are passed, calls `getRange`, returns keys and values + vals = self.series[0:1, 1:2] + assert_equals(1, len(vals)) + assert_equals(self.dataLocal[1][0], vals[0][0]) + assert_true(array_equal(self.dataLocal[1][1], vals[0][1])) + + # if slice extends out of bounds, return only the elements that are in bounds + vals = self.series[:4, :1] + assert_equals(2, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + assert_equals(self.dataLocal[2][0], vals[1][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[2][1], vals[1][1])) + + # empty slice works + vals = self.series[:, 1:2] + assert_equals(2, len(vals)) + assert_equals(self.dataLocal[1][0], vals[0][0]) + assert_equals(self.dataLocal[3][0], vals[1][0]) + assert_true(array_equal(self.dataLocal[1][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[3][1], vals[1][1])) + + # multiple empty slices work + vals = self.series[:, :] + assert_equals(4, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + assert_equals(self.dataLocal[1][0], vals[1][0]) + assert_equals(self.dataLocal[2][0], vals[2][0]) + assert_equals(self.dataLocal[3][0], vals[3][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) + assert_true(array_equal(self.dataLocal[2][1], vals[2][1])) + assert_true(array_equal(self.dataLocal[3][1], vals[3][1])) + + # mixing slices and individual indices works: + vals = self.series[0, :] + assert_equals(2, len(vals)) + assert_equals(self.dataLocal[0][0], vals[0][0]) + assert_equals(self.dataLocal[1][0], vals[1][0]) + assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) + assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) + + # trying to getitem a key that doesn't exist raises KeyError + # this differs from `get` behavior but is consistent with python dict + # see object.__getitem__ in https://docs.python.org/2/reference/datamodel.html 
+ assert_raises(KeyError, self.series.__getitem__, (25, 17)) # equiv: self.series[(25, 17)] + + # passing a range that is completely out of bounds throws IndexError + # note that if a range is only partly out of bounds, it will return what elements the slice does include + assert_raises(IndexError, self.series.__getitem__, [slice(2, 3), slice(None)]) # series[2:3,:] diff --git a/python/test/test_images.py b/python/test/test_images.py index e6c8adcc..7da6cece 100644 --- a/python/test/test_images.py +++ b/python/test/test_images.py @@ -485,92 +485,6 @@ def test_min(self): assert_true(array_equal(reduce(minimum, arys), minVal)) -class TestImagesGetters(PySparkTestCase): - def setUp(self): - super(TestImagesGetters, self).setUp() - self.ary1 = array([[1, 2], [3, 4]], dtype='int16') - self.ary2 = array([[5, 6], [7, 8]], dtype='int16') - self.images = ImagesLoader(self.sc).fromArrays([self.ary1, self.ary2]) - - def test_getMissing(self): - assert_is_none(self.images.get(-1)) - - def test_get(self): - assert_true(array_equal(self.ary2, self.images.get(1))) - - def test_getAll(self): - vals = self.images.getAll([0, -1, 1, 0]) - assert_equals(4, len(vals)) - assert_true(array_equal(self.ary1, vals[0])) - assert_is_none(vals[1]) - assert_true(array_equal(self.ary2, vals[2])) - assert_true(array_equal(self.ary1, vals[3])) - - def test_getRanges(self): - vals = self.images.getRange(slice(None)) - assert_equals(2, len(vals)) - assert_equals(0, vals[0][0]) - assert_equals(1, vals[1][0]) - assert_true(array_equal(self.ary1, vals[0][1])) - assert_true(array_equal(self.ary2, vals[1][1])) - - vals = self.images.getRange(slice(0, 1)) - assert_equals(1, len(vals)) - assert_equals(0, vals[0][0]) - assert_true(array_equal(self.ary1, vals[0][1])) - - vals = self.images.getRange(slice(1)) - assert_equals(1, len(vals)) - assert_equals(0, vals[0][0]) - assert_true(array_equal(self.ary1, vals[0][1])) - - vals = self.images.getRange(slice(1, 2)) - assert_equals(1, len(vals)) - 
assert_equals(1, vals[0][0]) - assert_true(array_equal(self.ary2, vals[0][1])) - - vals = self.images.getRange(slice(2, 3)) - assert_equals(0, len(vals)) - - # raise exception if 'step' specified: - assert_raises(ValueError, self.images.getRange, slice(1, 2, 2)) - - def test_brackets(self): - vals = self.images[1] - assert_true(array_equal(self.ary2, vals)) - - vals = self.images[0:1] - assert_equals(1, len(vals)) - assert_equals(0, vals[0][0]) - assert_true(array_equal(self.ary1, vals[0][1])) - - vals = self.images[:] - assert_equals(2, len(vals)) - assert_equals(0, vals[0][0]) - assert_equals(1, vals[1][0]) - assert_true(array_equal(self.ary1, vals[0][1])) - assert_true(array_equal(self.ary2, vals[1][1])) - - vals = self.images[1:4] - assert_equals(1, len(vals)) - assert_equals(1, vals[0][0]) - assert_true(array_equal(self.ary2, vals[0][1])) - - vals = self.images[1:] - assert_equals(1, len(vals)) - assert_equals(1, vals[0][0]) - assert_true(array_equal(self.ary2, vals[0][1])) - - vals = self.images[:1] - assert_equals(1, len(vals)) - assert_equals(0, vals[0][0]) - assert_true(array_equal(self.ary1, vals[0][1])) - - assert_raises(KeyError, self.images.__getitem__, 2) # equiv: self.images[2] - - assert_raises(IndexError, self.images.__getitem__, slice(2, 3)) # equiv: self.images[2:3] - - class TestImagesUsingOutputDir(PySparkTestCaseWithOutputDir): @staticmethod diff --git a/python/test/test_series.py b/python/test/test_series.py index d16cdd19..b3f1ba2f 100644 --- a/python/test/test_series.py +++ b/python/test/test_series.py @@ -305,145 +305,3 @@ def setIndex(data, idx): assert_raises(ValueError, setIndex, data, 5) assert_raises(ValueError, setIndex, data, [1, 2]) - - -class TestSeriesGetters(PySparkTestCase): - def setUp(self): - super(TestSeriesGetters, self).setUp() - self.dataLocal = [ - ((0, 0), array([1.0, 2.0, 3.0], dtype='float32')), - ((0, 1), array([2.0, 2.0, 4.0], dtype='float32')), - ((1, 0), array([4.0, 2.0, 1.0], dtype='float32')), - ((1, 1), 
array([3.0, 1.0, 1.0], dtype='float32')) - ] - self.series = Series(self.sc.parallelize(self.dataLocal), dtype='float32', dims=(2, 2), index=[0, 1, 2]) - - def test_getMissing(self): - assert_is_none(self.series.get(-1)) - - def test_get(self): - expected = self.dataLocal[1][1] - assert_true(array_equal(expected, self.series.get((0, 1)))) - - def test_getAll(self): - vals = self.series.getAll([(0, 0), (17, 256), (1, 0), (0, 0)]) - assert_equals(4, len(vals)) - assert_true(array_equal(self.dataLocal[0][1], vals[0])) - assert_is_none(vals[1]) - assert_true(array_equal(self.dataLocal[2][1], vals[2])) - assert_true(array_equal(self.dataLocal[0][1], vals[3])) - - def test_getRanges(self): - vals = self.series.getRange([slice(2), slice(2)]) - assert_equals(4, len(vals)) - assert_equals(self.dataLocal[0][0], vals[0][0]) - assert_equals(self.dataLocal[1][0], vals[1][0]) - assert_equals(self.dataLocal[2][0], vals[2][0]) - assert_equals(self.dataLocal[3][0], vals[3][0]) - assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) - assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) - assert_true(array_equal(self.dataLocal[2][1], vals[2][1])) - assert_true(array_equal(self.dataLocal[3][1], vals[3][1])) - - vals = self.series.getRange([slice(2), slice(1)]) - assert_equals(2, len(vals)) - assert_equals(self.dataLocal[0][0], vals[0][0]) - assert_equals(self.dataLocal[2][0], vals[1][0]) - assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) - assert_true(array_equal(self.dataLocal[2][1], vals[1][1])) - - vals = self.series.getRange([slice(None), slice(1, 2)]) - assert_equals(2, len(vals)) - assert_equals(self.dataLocal[1][0], vals[0][0]) - assert_equals(self.dataLocal[3][0], vals[1][0]) - assert_true(array_equal(self.dataLocal[1][1], vals[0][1])) - assert_true(array_equal(self.dataLocal[3][1], vals[1][1])) - - vals = self.series.getRange([slice(None), slice(None)]) - assert_equals(4, len(vals)) - assert_equals(self.dataLocal[0][0], vals[0][0]) - 
assert_equals(self.dataLocal[1][0], vals[1][0]) - assert_equals(self.dataLocal[2][0], vals[2][0]) - assert_equals(self.dataLocal[3][0], vals[3][0]) - assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) - assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) - assert_true(array_equal(self.dataLocal[2][1], vals[2][1])) - assert_true(array_equal(self.dataLocal[3][1], vals[3][1])) - - vals = self.series.getRange([0, slice(None)]) - assert_equals(2, len(vals)) - assert_equals(self.dataLocal[0][0], vals[0][0]) - assert_equals(self.dataLocal[1][0], vals[1][0]) - assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) - assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) - - vals = self.series.getRange([0, 1]) - assert_equals(1, len(vals)) - assert_equals(self.dataLocal[1][0], vals[0][0]) - assert_true(array_equal(self.dataLocal[1][1], vals[0][1])) - - vals = self.series.getRange([slice(2, 3), slice(None)]) - assert_equals(0, len(vals)) - - # raise exception if 'step' specified: - assert_raises(ValueError, self.series.getRange, [slice(0, 4, 2), slice(2, 3)]) - - def test_brackets(self): - # returns just value; calls `get` - vals = self.series[(1, 0)] - assert_true(array_equal(self.dataLocal[2][1], vals)) - - # tuple isn't needed; returns just value, calls `get` - vals = self.series[0, 1] - assert_true(array_equal(self.dataLocal[1][1], vals)) - - # if slices are passed, calls `getRange`, returns keys and values - vals = self.series[0:1, 1:2] - assert_equals(1, len(vals)) - assert_equals(self.dataLocal[1][0], vals[0][0]) - assert_true(array_equal(self.dataLocal[1][1], vals[0][1])) - - # if slice extends out of bounds, return only the elements that are in bounds - vals = self.series[:4, :1] - assert_equals(2, len(vals)) - assert_equals(self.dataLocal[0][0], vals[0][0]) - assert_equals(self.dataLocal[2][0], vals[1][0]) - assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) - assert_true(array_equal(self.dataLocal[2][1], vals[1][1])) - - # empty 
slice works - vals = self.series[:, 1:2] - assert_equals(2, len(vals)) - assert_equals(self.dataLocal[1][0], vals[0][0]) - assert_equals(self.dataLocal[3][0], vals[1][0]) - assert_true(array_equal(self.dataLocal[1][1], vals[0][1])) - assert_true(array_equal(self.dataLocal[3][1], vals[1][1])) - - # multiple empty slices work - vals = self.series[:, :] - assert_equals(4, len(vals)) - assert_equals(self.dataLocal[0][0], vals[0][0]) - assert_equals(self.dataLocal[1][0], vals[1][0]) - assert_equals(self.dataLocal[2][0], vals[2][0]) - assert_equals(self.dataLocal[3][0], vals[3][0]) - assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) - assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) - assert_true(array_equal(self.dataLocal[2][1], vals[2][1])) - assert_true(array_equal(self.dataLocal[3][1], vals[3][1])) - - # mixing slices and individual indicies works: - vals = self.series[0, :] - assert_equals(2, len(vals)) - assert_equals(self.dataLocal[0][0], vals[0][0]) - assert_equals(self.dataLocal[1][0], vals[1][0]) - assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) - assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) - - # trying to getitem a key that doesn't exist raises KeyError - # this differs from `get` behavior but is consistent with python dict - # see object.__getitem__ in https://docs.python.org/2/reference/datamodel.html - assert_raises(KeyError, self.series.__getitem__, (25, 17)) # equiv: self.series[(25, 17)] - - # passing a range that is completely out of bounds throws IndexError - # note that if a range is only partly out of bounds, it will return what elements the slice does include - assert_raises(IndexError, self.series.__getitem__, [slice(2, 3), slice(None)]) # series[2:3,:] From 263910260fab1ea55d2b76f769bbc1e55516601e Mon Sep 17 00:00:00 2001 From: industrial-sloth Date: Fri, 6 Feb 2015 18:19:16 -0500 Subject: [PATCH 7/9] rename getAll to getMany --- python/test/test_data.py | 8 ++++---- python/thunder/rdds/data.py | 2 +- 2 
files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/test/test_data.py b/python/test/test_data.py index f72cbf55..dba77389 100644 --- a/python/test/test_data.py +++ b/python/test/test_data.py @@ -21,8 +21,8 @@ def test_getMissing(self): def test_get(self): assert_true(array_equal(self.ary2, self.images.get(1))) - def test_getAll(self): - vals = self.images.getAll([0, -1, 1, 0]) + def test_getMany(self): + vals = self.images.getMany([0, -1, 1, 0]) assert_equals(4, len(vals)) assert_true(array_equal(self.ary1, vals[0])) assert_is_none(vals[1]) @@ -114,8 +114,8 @@ def test_get(self): expected = self.dataLocal[1][1] assert_true(array_equal(expected, self.series.get((0, 1)))) - def test_getAll(self): - vals = self.series.getAll([(0, 0), (17, 256), (1, 0), (0, 0)]) + def test_getMany(self): + vals = self.series.getMany([(0, 0), (17, 256), (1, 0), (0, 0)]) assert_equals(4, len(vals)) assert_true(array_equal(self.dataLocal[0][1], vals[0])) assert_is_none(vals[1]) diff --git a/python/thunder/rdds/data.py b/python/thunder/rdds/data.py index 465e6586..1ff82fc9 100644 --- a/python/thunder/rdds/data.py +++ b/python/thunder/rdds/data.py @@ -126,7 +126,7 @@ def get(self, key): else: return filteredVals - def getAll(self, keys): + def getMany(self, keys): """Returns a sequence of values corresponding to the passed sequence of keys. The return value will be a sequence equal in length to the passed keys, with each From d5afd12b4a17c714135c76a85d6d02d0810531b2 Mon Sep 17 00:00:00 2001 From: industrial-sloth Date: Fri, 6 Feb 2015 18:49:49 -0500 Subject: [PATCH 8/9] add key type checking to Data get methods Check that requested key is of appropriate arity for RDD keys, based on a call to rdd.first() to get example key. 
--- python/test/test_data.py | 24 +++++++++++++++++++++++- python/thunder/rdds/data.py | 26 +++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/python/test/test_data.py b/python/test/test_data.py index dba77389..6d54deea 100644 --- a/python/test/test_data.py +++ b/python/test/test_data.py @@ -21,6 +21,9 @@ def test_getMissing(self): def test_get(self): assert_true(array_equal(self.ary2, self.images.get(1))) + # keys are integers, ask for sequence + assert_raises(ValueError, self.images.get, (1, 2)) + def test_getMany(self): vals = self.images.getMany([0, -1, 1, 0]) assert_equals(4, len(vals)) @@ -29,6 +32,10 @@ def test_getMany(self): assert_true(array_equal(self.ary2, vals[2])) assert_true(array_equal(self.ary1, vals[3])) + # keys are integers, ask for sequences: + assert_raises(ValueError, self.images.get, [(0, 0)]) + assert_raises(ValueError, self.images.get, [0, (0, 0), 1, 0]) + def test_getRanges(self): vals = self.images.getRange(slice(None)) assert_equals(2, len(vals)) @@ -55,6 +62,9 @@ def test_getRanges(self): vals = self.images.getRange(slice(2, 3)) assert_equals(0, len(vals)) + # keys are integers, ask for sequence + assert_raises(ValueError, self.images.getRange, [slice(1), slice(1)]) + # raise exception if 'step' specified: assert_raises(ValueError, self.images.getRange, slice(1, 2, 2)) @@ -108,12 +118,15 @@ def setUp(self): self.series = Data(self.sc.parallelize(self.dataLocal), dtype='float32') def test_getMissing(self): - assert_is_none(self.series.get(-1)) + assert_is_none(self.series.get((-1, -1))) def test_get(self): expected = self.dataLocal[1][1] assert_true(array_equal(expected, self.series.get((0, 1)))) + assert_raises(ValueError, self.series.get, 1) # keys are sequences, ask for integer + assert_raises(ValueError, self.series.get, (1, 2, 3)) # key length mismatch + def test_getMany(self): vals = self.series.getMany([(0, 0), (17, 256), (1, 0), (0, 0)]) assert_equals(4, len(vals)) @@ -122,6 +135,9 @@ def 
test_getMany(self): assert_true(array_equal(self.dataLocal[2][1], vals[2])) assert_true(array_equal(self.dataLocal[0][1], vals[3])) + assert_raises(ValueError, self.series.getMany, [1]) # keys are sequences, ask for integer + assert_raises(ValueError, self.series.getMany, [(0, 0), 1, (1, 0), (0, 0)]) # asking for integer again + def test_getRanges(self): vals = self.series.getRange([slice(2), slice(2)]) assert_equals(4, len(vals)) @@ -174,6 +190,12 @@ def test_getRanges(self): vals = self.series.getRange([slice(2, 3), slice(None)]) assert_equals(0, len(vals)) + # keys are sequences, ask for single slice + assert_raises(ValueError, self.series.getRange, slice(2, 3)) + + # ask for wrong number of slices + assert_raises(ValueError, self.series.getRange, [slice(2, 3), slice(2, 3), slice(2, 3)]) + # raise exception if 'step' specified: assert_raises(ValueError, self.series.getRange, [slice(0, 4, 2), slice(2, 3)]) diff --git a/python/thunder/rdds/data.py b/python/thunder/rdds/data.py index 1ff82fc9..9b13a3a4 100644 --- a/python/thunder/rdds/data.py +++ b/python/thunder/rdds/data.py @@ -111,6 +111,23 @@ def take(self, *args, **kwargs): """ return self.rdd.take(*args, **kwargs) + @staticmethod + def __getKeyTypeCheck(actualKey, keySpec): + if hasattr(actualKey, "__iter__"): + try: + specLen = len(keySpec) if hasattr(keySpec, "__len__") else \ + reduce(lambda x, y: x + y, [1 for item in keySpec], initial=0) + if specLen != len(actualKey): + raise ValueError("Length of key specifier '%s' does not match length of first key '%s'" % + (str(keySpec), str(actualKey))) + except TypeError: + raise ValueError("Key specifier '%s' appears not to be a sequence type, but actual keys are " % + str(keySpec) + "sequences (first key: '%s')" % str(actualKey)) + else: + if hasattr(keySpec, "__iter__"): + raise ValueError("Key specifier '%s' appears to be a sequence type, " % str(keySpec) + + "but actual keys are not (first key: '%s')" % str(actualKey)) + def get(self, key): """Returns a 
single value matching the passed key, or None if no matching keys found @@ -118,6 +135,8 @@ def get(self, key): values will be returned. (This is not expected as a normal occurance, but could happen with some user-created rdds.) """ + firstKey = self.first()[0] + Data.__getKeyTypeCheck(firstKey, key) filteredVals = self.rdd.filter(lambda (k, v): k == key).values().collect() if len(filteredVals) == 1: return filteredVals[0] @@ -135,6 +154,9 @@ def getMany(self, keys): If multiple values are found, the corresponding sequence element will be a sequence containing all matching values. """ + firstKey = self.first()[0] + for key in keys: + Data.__getKeyTypeCheck(firstKey, key) keySet = frozenset(keys) filteredRecs = self.rdd.filter(lambda (k, _): k in keySet).collect() sortingDict = {} @@ -197,7 +219,9 @@ def multiSlicesPredicate(kv): return False return True - if not hasattr(sliceOrSlices, '__len__'): + firstKey = self.first()[0] + Data.__getKeyTypeCheck(firstKey, sliceOrSlices) + if not hasattr(sliceOrSlices, '__iter__'): # make my func the pFunc; http://en.wikipedia.org/wiki/P._Funk_%28Wants_to_Get_Funked_Up%29 pFunc = singleSlicePredicate if hasattr(sliceOrSlices, 'step') and sliceOrSlices.step is not None: From 89998722023a10486d8fecb91c6b2776cb1dfe2e Mon Sep 17 00:00:00 2001 From: industrial-sloth Date: Fri, 6 Feb 2015 18:55:56 -0500 Subject: [PATCH 9/9] change Data [] behavior to return None or [] for missing keys Previously threw exceptions. This differs from python dict or sequence behavior but is now consistent with get and getRange. 
--- python/test/test_data.py | 15 ++++++--------- python/thunder/rdds/data.py | 9 ++------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/python/test/test_data.py b/python/test/test_data.py index 6d54deea..34ff4cdf 100644 --- a/python/test/test_data.py +++ b/python/test/test_data.py @@ -99,9 +99,9 @@ def test_brackets(self): assert_equals(0, vals[0][0]) assert_true(array_equal(self.ary1, vals[0][1])) - assert_raises(KeyError, self.images.__getitem__, 2) # equiv: self.images[2] + assert_is_none(self.images[2]) - assert_raises(IndexError, self.images.__getitem__, slice(2, 3)) # equiv: self.images[2:3] + assert_equals([], self.images[2:3]) class TestSeriesGetters(PySparkTestCase): @@ -250,11 +250,8 @@ def test_brackets(self): assert_true(array_equal(self.dataLocal[0][1], vals[0][1])) assert_true(array_equal(self.dataLocal[1][1], vals[1][1])) - # trying to getitem a key that doesn't exist raises KeyError - # this differs from `get` behavior but is consistent with python dict - # see object.__getitem__ in https://docs.python.org/2/reference/datamodel.html - assert_raises(KeyError, self.series.__getitem__, (25, 17)) # equiv: self.series[(25, 17)] + # trying to getitem a key that doesn't exist returns None + assert_is_none(self.series[(25, 17)]) - # passing a range that is completely out of bounds throws IndexError - # note that if a range is only partly out of bounds, it will return what elements the slice does include - assert_raises(IndexError, self.series.__getitem__, [slice(2, 3), slice(None)]) # series[2:3,:] + # passing a range that is completely out of bounds returns [] + assert_equals([], self.series[2:3, :]) diff --git a/python/thunder/rdds/data.py b/python/thunder/rdds/data.py index 9b13a3a4..d8df148d 100644 --- a/python/thunder/rdds/data.py +++ b/python/thunder/rdds/data.py @@ -249,14 +249,9 @@ def __getitem__(self, item): isRangeQuery = True if isRangeQuery: - retVals = self.getRange(item) - if not retVals: - raise IndexError("No keys found 
for slice(s): '%s'" % str(item)) + return self.getRange(item) else: - retVals = self.get(item) - if retVals is None: - raise KeyError("No key found matching '%s'" % str(item)) - return retVals + return self.get(item) def values(self): """ Return values, ignoring keys