diff --git a/test/test_images_io.py b/test/test_images_io.py index 815ef2c1..660f5516 100644 --- a/test/test_images_io.py +++ b/test/test_images_io.py @@ -114,6 +114,16 @@ def test_from_tif_multi_planes(eng): assert [x.sum() for x in data.toarray()] == [1140006, 1119161, 1098917] +def test_from_tif_multi_planes_discard_extra(eng): + path = os.path.join(resources, 'multilayer_tif', 'dotdotdot_lzw.tif') + data = fromtif(path, nplanes=2, engine=eng, discard_extra=True) + assert data.shape[0] == 1 + assert data.shape[1] == 2 + with pytest.raises(BaseException) as error_msg: + data = fromtif(path, nplanes=2, engine=eng, discard_extra=False) + assert 'nplanes' in str(error_msg.value) + + def test_from_tif_multi_planes_many(eng): path = os.path.join(resources, 'multilayer_tif', 'dotdotdot_lzw*.tif') data = fromtif(path, nplanes=3, engine=eng) diff --git a/thunder/images/readers.py b/thunder/images/readers.py index e7ad3e57..92e07da1 100644 --- a/thunder/images/readers.py +++ b/thunder/images/readers.py @@ -1,4 +1,5 @@ import itertools +import logging from io import BytesIO from numpy import frombuffer, prod, random, asarray, expand_dims @@ -283,8 +284,8 @@ def frombinary(path, shape=None, dtype=None, ext='bin', start=None, stop=None, r raise ValueError("Last dimension '%d' must be divisible by nplanes '%d'" % (shape[-1], nplanes)) - def getarray(idxAndBuf): - idx, buf = idxAndBuf + def getarray(idx_buffer_filename): + idx, buf, _ = idx_buffer_filename ary = frombuffer(buf, dtype=dtype, count=int(prod(shape))).reshape(shape, order=order) if nplanes is None: yield (idx,), ary @@ -294,17 +295,17 @@ def getarray(idxAndBuf): if shape[-1] % nplanes: npoints += 1 timepoint = 0 - lastPlane = 0 - curPlane = 1 - while curPlane < ary.shape[-1]: - if curPlane % nplanes == 0: - slices = [slice(None)] * (ary.ndim - 1) + [slice(lastPlane, curPlane)] + last_plane = 0 + current_plane = 1 + while current_plane < ary.shape[-1]: + if current_plane % nplanes == 0: + slices = [slice(None)] * 
(ary.ndim - 1) + [slice(last_plane, current_plane)] yield idx*npoints + timepoint, ary[slices].squeeze() timepoint += 1 - lastPlane = curPlane - curPlane += 1 + last_plane = current_plane + current_plane += 1 # yield remaining planes - slices = [slice(None)] * (ary.ndim - 1) + [slice(lastPlane, ary.shape[-1])] + slices = [slice(None)] * (ary.ndim - 1) + [slice(last_plane, ary.shape[-1])] yield (idx*npoints + timepoint,), ary[slices].squeeze() recount = False if nplanes is None else True @@ -315,7 +316,7 @@ def getarray(idxAndBuf): dims=newdims, dtype=dtype, labels=labels, recount=recount, engine=engine, credentials=credentials) -def fromtif(path, ext='tif', start=None, stop=None, recursive=False, nplanes=None, npartitions=None, labels=None, engine=None, credentials=None): +def fromtif(path, ext='tif', start=None, stop=None, recursive=False, nplanes=None, npartitions=None, labels=None, engine=None, credentials=None, discard_extra=False): """ Loads images from single or multi-page TIF files. @@ -346,20 +347,31 @@ def fromtif(path, ext='tif', start=None, stop=None, recursive=False, nplanes=Non labels : array, optional, default = None Labels for records. If provided, should be one-dimensional. + + discard_extra : boolean, optional, default = False + If True and nplanes doesn't evenly divide the number of pages in a multi-page tiff, the remainder will + be discarded and a warning will be shown. 
If False, it will raise an error """ import skimage.external.tifffile as tifffile if nplanes is not None and nplanes <= 0: raise ValueError('nplanes must be positive if passed, got %d' % nplanes) - def getarray(idxAndBuf): - idx, buf = idxAndBuf + def getarray(idx_buffer_filename): + idx, buf, fname = idx_buffer_filename fbuf = BytesIO(buf) tfh = tifffile.TiffFile(fbuf) ary = tfh.asarray() pageCount = ary.shape[0] if nplanes is not None: - values = [ary[i:(i+nplanes)] for i in range(0, ary.shape[0], nplanes)] + extra = pageCount % nplanes + if extra: + if discard_extra: + pageCount = pageCount - extra + logging.getLogger('thunder').warn('Ignored %d pages in file %s' % (extra, fname)) + else: + raise ValueError("nplanes '%d' does not evenly divide '%d'" % (nplanes, pageCount)) + values = [ary[i:(i+nplanes)] for i in range(0, pageCount, nplanes)] else: values = [ary] tfh.close() @@ -367,8 +379,6 @@ def getarray(idxAndBuf): if ary.ndim == 3: values = [val.squeeze() for val in values] - if nplanes and (pageCount % nplanes): - raise ValueError("nplanes '%d' does not evenly divide '%d'" % (nplanes, pageCount)) nvals = len(values) keys = [(idx*nvals + timepoint,) for timepoint in range(nvals)] return zip(keys, values) @@ -408,8 +418,8 @@ def frompng(path, ext='png', start=None, stop=None, recursive=False, npartitions """ from scipy.misc import imread - def getarray(idxAndBuf): - idx, buf = idxAndBuf + def getarray(idx_buffer_filename): + idx, buf, _ = idx_buffer_filename fbuf = BytesIO(buf) yield (idx,), imread(fbuf) diff --git a/thunder/readers.py b/thunder/readers.py index 552f008e..8034c7ab 100644 --- a/thunder/readers.py +++ b/thunder/readers.py @@ -149,9 +149,9 @@ def read(self, path, ext=None, start=None, stop=None, recursive=False, npartitio if spark and isinstance(self.engine, spark): npartitions = min(npartitions, nfiles) if npartitions else nfiles rdd = self.engine.parallelize(enumerate(files), npartitions) - return rdd.map(lambda kv: (kv[0], readlocal(kv[1]))) 
+ return rdd.map(lambda kv: (kv[0], readlocal(kv[1]), kv[1])) else: - return [(k, readlocal(v)) for k, v in enumerate(files)] + return [(k, readlocal(v), v) for k, v in enumerate(files)] class LocalFileReader(object):