
Commit

add support for field assignments (#227)
* add support for field assignments

* fix flake8 error

* remove debug log statement
jreadey authored Oct 25, 2024
1 parent f28b949 commit c30fb5c
Showing 6 changed files with 105 additions and 52 deletions.
127 changes: 96 additions & 31 deletions h5pyd/_hl/dataset.py
@@ -297,36 +297,35 @@ def make_new_dset(
return dset_id


class AstypeWrapper(object):
"""Wrapper to convert data on reading from a dataset."""

class AstypeWrapper:
"""Wrapper to convert data on reading from a dataset.
"""
def __init__(self, dset, dtype):
self._dset = dset
self._dtype = numpy.dtype(dtype)

def __getitem__(self, args):
return self._dset.__getitem__(args, new_dtype=self._dtype)

def __enter__(self):
# pylint: disable=protected-access
print(
"Using astype() as a context manager is deprecated. "
"Slice the returned object instead, like: ds.astype(np.int32)[:10]"
)
self._dset._local.astype = self._dtype
return self

def __exit__(self, *args):
# pylint: disable=protected-access
self._dset._local.astype = None

def __len__(self):
"""Get the length of the underlying dataset
""" Get the length of the underlying dataset
>>> length = len(dataset.astype('f8'))
"""
return len(self._dset)

def __array__(self, dtype=None, copy=True):
if copy is False:
raise ValueError(
f"AstypeWrapper.__array__ received {copy=} "
f"but memory allocation cannot be avoided on read"
)

data = self[:]
if dtype is not None:
return data.astype(dtype, copy=False)
return data
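
For orientation, a minimal sketch of how the reworked wrapper is meant to be used after this change; the domain path and dataset name below are hypothetical:

```python
import numpy as np
import h5pyd

# hypothetical HSDS domain and dataset name
with h5pyd.File("/home/test/sample.h5", "r") as f:
    dset = f["data"]                       # e.g. stored as int32

    part = dset.astype("f8")[0:10]         # slice the wrapper: read + convert
    whole = np.asarray(dset.astype("f8"))  # __array__: full read as float64
    n = len(dset.astype("f8"))             # __len__: same as len(dset)
```

The context-manager form (`with dset.astype('f8'):`) handled by the removed `__enter__`/`__exit__` pair is dropped here, which mirrors how recent h5py versions treat `astype()` as a sliceable wrapper.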


class AsStrWrapper:
"""Wrapper to decode strings on reading the dataset"""
@@ -361,6 +360,43 @@ def __len__(self):
return len(self._dset)


class FieldsWrapper:
"""Wrapper to extract named fields from a dataset with a struct dtype"""
extract_field = None

def __init__(self, dset, prior_dtype, names):
self._dset = dset
if isinstance(names, str):
self.extract_field = names
names = [names]
self.read_dtype = readtime_dtype(prior_dtype, names)

def __array__(self, dtype=None, copy=True):
if copy is False:
raise ValueError(
f"FieldsWrapper.__array__ received {copy=} "
f"but memory allocation cannot be avoided on read"
)
data = self[:]
if dtype is not None:
return data.astype(dtype, copy=False)
else:
return data

def __getitem__(self, args):
data = self._dset.__getitem__(args, new_dtype=self.read_dtype)
if self.extract_field is not None:
data = data[self.extract_field]
return data

def __len__(self):
""" Get the length of the underlying dataset
>>> length = len(dataset.fields(['x', 'y']))
"""
return len(self._dset)
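
The new wrapper follows the same protocol as AstypeWrapper for field reads; a small sketch, assuming `dset` is an open h5pyd dataset whose compound dtype has 'x' and 'y' fields:

```python
import numpy as np

# assumes dset is an open h5pyd Dataset with fields 'x' and 'y'
wrapper = dset.fields(['x', 'y'])   # no data is read yet

n = len(wrapper)                    # __len__: same as len(dset)
arr = np.asarray(wrapper)           # __array__: full read with an ('x', 'y') compound dtype
row = wrapper[0]                    # __getitem__: single element
```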


class ChunkIterator(object):
"""
Class to iterate through list of chunks of a given dataset
@@ -486,6 +522,19 @@ def asstr(self, encoding=None, errors="strict"):

return AsStrWrapper(self, encoding, errors=errors)

def fields(self, names, *, _prior_dtype=None):
"""Get a wrapper to read a subset of fields from a compound data type:
>>> 2d_coords = dataset.fields(['x', 'y'])[:]
If names is a string, a single field is extracted, and the resulting
arrays will have that dtype. Otherwise, it should be an iterable,
and the read data will have a compound dtype.
"""
if _prior_dtype is None:
_prior_dtype = self.dtype
return FieldsWrapper(self, _prior_dtype, names)
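
A short sketch of the string-versus-list distinction the docstring describes; `dset` and the field names are illustrative:

```python
coords = dset.fields(['x', 'y'])[:]   # compound result with fields 'x' and 'y'
xs = dset.fields('x')[:]              # plain array with the dtype of field 'x'

print(coords.dtype.names)             # ('x', 'y')
print(xs.dtype)                       # e.g. float64, no field names
```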

@property
def dims(self):
from .dims import DimensionManager
@@ -890,7 +939,7 @@ def __getitem__(self, args, new_dtype=None):
* Boolean "mask" array indexing
"""
if new_dtype is not None:
self.log.warning("new_dtype is not supported")
self.log.debug(f"getitem.new_dtype: {new_dtype}")
args = args if isinstance(args, tuple) else (args,)
self.log.debug("dataset.__getitem__")
for arg in args:
@@ -906,8 +955,20 @@ def __getitem__(self, args, new_dtype=None):

# Sort field indices from the rest of the args.
names = tuple(x for x in args if isinstance(x, str))
args = tuple(x for x in args if not isinstance(x, str))
if names:
self.log.debug(f"names: {names}")
# Read a subset of the fields in this structured dtype
if len(names) == 1:
names = names[0] # Read with simpler dtype of this field
args = tuple(x for x in args if not isinstance(x, str))
return self.fields(names, _prior_dtype=new_dtype)[args]

if new_dtype is None:
new_dtype = self.dtype
else:
self.log.debug(f"new_dtype: {new_dtype}")
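
With this routing in place, field names mixed into a selection are peeled off and delegated to fields(); roughly (field names are illustrative):

```python
a = dset[0:10, 'x']              # string args are split out of the selection
b = dset.fields('x')[0:10]       # equivalent read

c = dset[0:10, 'x', 'y']         # multiple names keep a compound dtype,
                                 # same as dset.fields(['x', 'y'])[0:10]
```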

"""
new_dtype = getattr(self._local, "astype", None)
if new_dtype is not None:
new_dtype = readtime_dtype(new_dtype, names)
@@ -916,6 +977,7 @@ def __getitem__(self, args, new_dtype=None):
# discards the array information at the top level.
new_dtype = readtime_dtype(self.dtype, names)
self.log.debug(f"new_dtype: {new_dtype}")
"""
if new_dtype.kind == "S" and check_dtype(ref=self.dtype):
new_dtype = special_dtype(ref=Reference)

@@ -1015,14 +1077,14 @@ def __getitem__(self, args, new_dtype=None):

self.log.debug(f"dataset shape: {self._shape}")
self.log.debug(f"mshape: {mshape}")
self.log.debug(f"single_element: {single_element}")

# Perform the actual read
rsp = None
req = "/datasets/" + self.id.uuid + "/value"
params = {}

if len(names) > 0:
params["fields"] = ":".join(names)
if mtype.names != self.dtype.names:
params["fields"] = ":".join(mtype.names)

if self.id._http_conn.mode == "r" and self.id._http_conn.cache_on:
# enables lambda to be used on server
@@ -1152,7 +1214,6 @@ def __getitem__(self, args, new_dtype=None):
# got binary response
# TBD - check expected number of bytes
self.log.info(f"binary response, {len(rsp)} bytes")
# arr1d = numpy.frombuffer(rsp, dtype=mtype)
arr1d = bytesToArray(rsp, mtype, page_mshape)
page_arr = numpy.reshape(arr1d, page_mshape)
else:
@@ -1328,7 +1389,7 @@ def __setitem__(self, args, val):

# get the val dtype if we're passed a numpy array
try:
msg = f"val dtype: {val.dtype}, shape: {val.shape} metadata: {val.dtype.metadata}"
msg = f"val dtype: {val.dtype}, shape: {val.shape} kind: {val.dtype.kind} metadata: {val.dtype.metadata}"
self.log.debug(msg)
if numpy.prod(val.shape) == 0:
self.log.info("no elements in numpy array, skipping write")
@@ -1360,6 +1421,7 @@ def __setitem__(self, args, val):
# For h5pyd, do extra check and convert type on client side for efficiency
vlen_base_class = check_dtype(vlen=self.dtype)
if vlen_base_class is not None and vlen_base_class not in (bytes, str):
self.log.debug(f"asarray to base_class: {vlen_base_class}")
try:
# Attempt to directly convert the input array of vlen data to its base class
val = numpy.asarray(val, dtype=vlen_base_class)
@@ -1417,6 +1479,7 @@ def __setitem__(self, args, val):
# TBD: Do we need something like the following in the above if condition:
# (self.dtype.str != val.dtype.str)
# for cases where the val is a numpy array but different type than self?

if len(names) == 1 and self.dtype.fields is not None:
# Single field selected for write, from a non-array source
if not names[0] in self.dtype.fields:
@@ -1427,9 +1490,12 @@ def __setitem__(self, args, val):
dtype = self.dtype
cast_compound = False

val = numpy.asarray(val, dtype=dtype, order="C")
self.log.debug(f"asarray dtype: {dtype}, cast_compound: {cast_compound}")
val = numpy.asarray(val, dtype=dtype.base, order="C")
if cast_compound:
val = val.astype(numpy.dtype([(names[0], dtype)]))
# val = val.astype(numpy.dtype([(names[0], dtype)]))
val = val.view(numpy.dtype([(names[0], dtype)]))
val = val.reshape(val.shape[:len(val.shape) - len(dtype.shape)])

elif isinstance(val, numpy.ndarray):
# convert array if needed
@@ -1447,17 +1513,16 @@ def __setitem__(self, args, val):

# Check for array dtype compatibility and convert
mshape = None
"""
# TBD..
self.log.debug(f"self.dtype.subdtype: {self.dtype.subdtype}")
if self.dtype.subdtype is not None:
shp = self.dtype.subdtype[1] # type shape
valshp = val.shape[-len(shp):]
if valshp != shp: # Last dimension has to match
raise TypeError(f"When writing to array types,\
last N dimensions have to match (got {valshp}, but should be {shp})")
mtype = h5t.py_create(numpy.dtype((val.dtype, shp)))
mshape = val.shape[0:len(val.shape)-len(shp)]
"""
mtype = numpy.dtype((val.dtype, shp))
self.log.debug(f"mtype for subdtype: {mtype}")
mshape = val.shape[0:len(val.shape) - len(shp)]

# Check for field selection
if len(names) != 0:
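To see what the view/reshape pair in the single-field write branch above is doing, here is a plain-numpy walk-through using the ('weight', (float64, 3)) field from the tests below:

```python
import numpy as np

field_dtype = np.dtype((np.float64, (3,)))   # dtype of the 'weight' field
val = np.ones((5, 3))                        # values for 5 rows of that field

val = np.asarray(val, dtype=field_dtype.base, order="C")           # (5, 3) float64
val = val.view(np.dtype([('weight', field_dtype)]))                # (5, 1), itemsize 24
val = val.reshape(val.shape[:val.ndim - len(field_dtype.shape)])   # (5,)

print(val.dtype)   # [('weight', '<f8', (3,))]
print(val.shape)   # (5,)
```

Viewing rather than casting reinterprets the existing buffer as a one-field compound, so each row's (3,) values land in the 'weight' field without an element-wise copy.
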
11 changes: 8 additions & 3 deletions test/hl/test_complex_numbers.py
@@ -43,9 +43,14 @@ def test_complex_dset(self):
val = dset[0]

self.assertEqual(val.shape, ())
self.assertEqual(val.dtype.kind, 'c')
self.assertEqual(val.real, 1.0)
self.assertEqual(val.imag, 0.)
if config.get('use_h5py'):
self.assertEqual(val.dtype.kind, 'c')
self.assertEqual(val.real, 1.0)
self.assertEqual(val.imag, 0.)
else:
self.assertEqual(val.dtype.kind, 'V')
self.assertEqual(val['r'], 1.0)
self.assertEqual(val['i'], 0.)

def test_complex_attr(self):
"""Read and wrtie complex numbers in attributes"""
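The branch added here reflects that, without native complex support on the server, h5pyd currently returns complex data as a compound with 'r' and 'i' fields; a sketch of handling both cases (`dset` is an open complex-valued dataset, name illustrative):

```python
val = dset[0]                       # one element of a complex-valued dataset
if val.dtype.kind == 'V':           # h5pyd/HSDS path: ('r', 'i') compound
    z = complex(val['r'], val['i'])
else:                               # h5py path: native complex scalar
    z = complex(val)
```
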
12 changes: 0 additions & 12 deletions test/hl/test_dataset.py
@@ -1396,13 +1396,7 @@ def test_rt(self):
self.assertTrue(np.all(outdata == testdata))
self.assertEqual(outdata.dtype, testdata.dtype)

@ut.expectedFailure
def test_assign(self):
# Expected failure on HSDS; skip with h5py
if config.get('use_h5py'):
self.assertTrue(False)

# TBD: field assignment not working
dt = np.dtype([('weight', (np.float64, 3)),
('endpoint_type', np.uint8), ])

@@ -1419,13 +1413,7 @@ def test_fields(self):
self.assertTrue(np.all(outdata == testdata))
self.assertEqual(outdata.dtype, testdata.dtype)

@ut.expectedFailure
def test_fields(self):
# Expected failure on HSDS; skip with h5py
if config.get('use_h5py'):
self.assertTrue(False)

# TBD: field assignment not working
dt = np.dtype([
('x', np.float64),
('y', np.float64),
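With the expectedFailure markers dropped, a condensed sketch of the kind of single-field write these tests now exercise; the domain path is a placeholder:

```python
import numpy as np
import h5pyd

dt = np.dtype([('weight', (np.float64, 3)),
               ('endpoint_type', np.uint8)])

with h5pyd.File("/home/test/assign.h5", "w") as f:
    dset = f.create_dataset("x", (5,), dtype=dt)
    dset['weight'] = np.ones((5, 3))                       # single-field write
    dset['endpoint_type'] = np.arange(5, dtype=np.uint8)   # another field
    weights = dset['weight']                               # read one field back, shape (5, 3)
```
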
1 change: 0 additions & 1 deletion test/hl/test_dataset_multi.py
@@ -36,7 +36,6 @@ def test_multi_read_scalar_dataspaces(self):
"""
filename = self.getFileName("multi_read_scalar_dataspaces")
print("filename:", filename)
print(f"numpy version: {np.version.version}")
f = h5py.File(filename, 'w')
shape = ()
count = 3
4 changes: 1 addition & 3 deletions test/hl/test_datatype.py
@@ -122,8 +122,7 @@ def test_read(self):
np.testing.assert_array_equal(outdata, testdata[key])
self.assertEqual(outdata.dtype, testdata[key].dtype)

"""
TBD
@ut.expectedFailure
def test_nested_compound_vlen(self):
dt_inner = np.dtype([('a', h5py.vlen_dtype(np.int32)),
('b', h5py.vlen_dtype(np.int32))])
@@ -147,7 +146,6 @@ def test_nested_compound_vlen(self):
# Specifying check_alignment=False because vlen fields have 8 bytes of padding
# because the vlen datatype in hdf5 occupies 16 bytes
self.assertArrayEqual(out, data, check_alignment=False)
"""


if __name__ == '__main__':
2 changes: 0 additions & 2 deletions test/hl/test_folder.py
@@ -29,8 +29,6 @@ def test_list(self):
# Folders not supported for h5py
return

# loglevel = logging.DEBUG
# logging.basicConfig( format='%(asctime)s %(message)s', level=loglevel)
test_domain = self.getFileName("folder_test")

filepath = self.getPathFromDomain(test_domain)
