add support for field assignments #227

Merged — 3 commits, Oct 25, 2024. Changes from all commits:
127 changes: 96 additions & 31 deletions h5pyd/_hl/dataset.py
@@ -297,36 +297,35 @@ def make_new_dset(
return dset_id


-class AstypeWrapper(object):
-    """Wrapper to convert data on reading from a dataset."""

+class AstypeWrapper:
+    """Wrapper to convert data on reading from a dataset.
+    """
def __init__(self, dset, dtype):
self._dset = dset
self._dtype = numpy.dtype(dtype)

def __getitem__(self, args):
return self._dset.__getitem__(args, new_dtype=self._dtype)

-    def __enter__(self):
-        # pylint: disable=protected-access
-        print(
-            "Using astype() as a context manager is deprecated. "
-            "Slice the returned object instead, like: ds.astype(np.int32)[:10]"
-        )
-        self._dset._local.astype = self._dtype
-        return self
-
-    def __exit__(self, *args):
-        # pylint: disable=protected-access
-        self._dset._local.astype = None

def __len__(self):
"""Get the length of the underlying dataset
""" Get the length of the underlying dataset

>>> length = len(dataset.astype('f8'))
"""
return len(self._dset)

+    def __array__(self, dtype=None, copy=True):
+        if copy is False:
+            raise ValueError(
+                f"AstypeWrapper.__array__ received {copy=} "
+                f"but memory allocation cannot be avoided on read"
+            )
+
+        data = self[:]
+        if dtype is not None:
+            return data.astype(dtype, copy=False)
+        return data

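For context, a short sketch of what the new __array__ hook enables (the file and dataset names here are hypothetical): slicing the wrapper converts on read, and numpy.asarray can now consume the wrapper directly, while a no-copy request is rejected because a server read always allocates fresh memory.

import numpy as np
import h5pyd

with h5pyd.File("/home/test/example.h5", "r") as f:  # hypothetical domain
    dset = f["data"]                        # stored e.g. as int32
    head = dset.astype("f8")[:10]           # read-time conversion to float64
    full = np.asarray(dset.astype("f8"))    # __array__ reads the whole dataset
    # Under NumPy 2, np.asarray(dset.astype("f8"), copy=False) forwards
    # copy=False to __array__, which raises ValueError here by design.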

class AsStrWrapper:
"""Wrapper to decode strings on reading the dataset"""
@@ -361,6 +360,43 @@ def __len__(self):
return len(self._dset)


+class FieldsWrapper:
+    """Wrapper to extract named fields from a dataset with a struct dtype"""
+    extract_field = None
+
+    def __init__(self, dset, prior_dtype, names):
+        self._dset = dset
+        if isinstance(names, str):
+            self.extract_field = names
+            names = [names]
+        self.read_dtype = readtime_dtype(prior_dtype, names)
+
+    def __array__(self, dtype=None, copy=True):
+        if copy is False:
+            raise ValueError(
+                f"FieldsWrapper.__array__ received {copy=} "
+                f"but memory allocation cannot be avoided on read"
+            )
+        data = self[:]
+        if dtype is not None:
+            return data.astype(dtype, copy=False)
+        else:
+            return data
+
+    def __getitem__(self, args):
+        data = self._dset.__getitem__(args, new_dtype=self.read_dtype)
+        if self.extract_field is not None:
+            data = data[self.extract_field]
+        return data
+
+    def __len__(self):
+        """ Get the length of the underlying dataset
+
+        >>> length = len(dataset.fields(['x', 'y']))
+        """
+        return len(self._dset)

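A brief sketch of how the wrapper resolves dtypes (dataset and field names illustrative): a single field name sets extract_field, so reads come back as a plain array of that field's dtype, while a list of names keeps a compound dtype restricted to those fields.

# assuming dset has dtype [('x', '<f8'), ('y', '<f8'), ('z', '<f8')]
xs = dset.fields('x')[:]           # plain float64 array (field extracted)
xy = dset.fields(['x', 'y'])[:]    # compound dtype with only 'x' and 'y'
assert len(dset.fields(['x', 'y'])) == len(dset)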

class ChunkIterator(object):
"""
Class to iterate through list of chunks of a given dataset
@@ -486,6 +522,19 @@ def asstr(self, encoding=None, errors="strict"):

return AsStrWrapper(self, encoding, errors=errors)

+    def fields(self, names, *, _prior_dtype=None):
+        """Get a wrapper to read a subset of fields from a compound data type:
+
+        >>> coords_2d = dataset.fields(['x', 'y'])[:]
+
+        If names is a string, a single field is extracted, and the resulting
+        arrays will have that dtype. Otherwise, it should be an iterable,
+        and the read data will have a compound dtype.
+        """
+        if _prior_dtype is None:
+            _prior_dtype = self.dtype
+        return FieldsWrapper(self, _prior_dtype, names)

@property
def dims(self):
from .dims import DimensionManager
@@ -890,7 +939,7 @@ def __getitem__(self, args, new_dtype=None):
* Boolean "mask" array indexing
"""
if new_dtype is not None:
-            self.log.warning("new_dtype is not supported")
+            self.log.debug(f"getitem.new_dtype: {new_dtype}")
args = args if isinstance(args, tuple) else (args,)
self.log.debug("dataset.__getitem__")
for arg in args:
@@ -906,8 +955,20 @@ def __getitem__(self, args, new_dtype=None):

# Sort field indices from the rest of the args.
names = tuple(x for x in args if isinstance(x, str))
-        args = tuple(x for x in args if not isinstance(x, str))
+        if names:
+            self.log.debug(f"names: {names}")
+            # Read a subset of the fields in this structured dtype
+            if len(names) == 1:
+                names = names[0]  # Read with simpler dtype of this field
+            args = tuple(x for x in args if not isinstance(x, str))
+            return self.fields(names, _prior_dtype=new_dtype)[args]
+
+        if new_dtype is None:
+            new_dtype = self.dtype
+        else:
+            self.log.debug(f"new_dtype: {new_dtype}")
+
+        """
new_dtype = getattr(self._local, "astype", None)
if new_dtype is not None:
new_dtype = readtime_dtype(new_dtype, names)
@@ -916,6 +977,7 @@ def __getitem__(self, args, new_dtype=None):
# discards the array information at the top level.
new_dtype = readtime_dtype(self.dtype, names)
self.log.debug(f"new_dtype: {new_dtype}")
"""
if new_dtype.kind == "S" and check_dtype(ref=self.dtype):
new_dtype = special_dtype(ref=Reference)

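With this change, indexing with field names routes through fields(); a hedged sketch of the user-visible behavior (names illustrative):

# assuming dset has dtype [('x', '<f8'), ('y', '<f8')]
x_head = dset['x', :10]    # one name: plain float64 result
xy_all = dset['x', 'y']    # two names: compound result with both fields
# equivalent to dset.fields('x')[:10] and dset.fields(['x', 'y'])[()]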
@@ -1015,14 +1077,14 @@ def __getitem__(self, args, new_dtype=None):

self.log.debug(f"dataset shape: {self._shape}")
self.log.debug(f"mshape: {mshape}")
self.log.debug(f"single_element: {single_element}")

        # Perform the actual read
rsp = None
req = "/datasets/" + self.id.uuid + "/value"
params = {}

-        if len(names) > 0:
-            params["fields"] = ":".join(names)
+        if mtype.names != self.dtype.names:
+            params["fields"] = ":".join(mtype.names)

if self.id._http_conn.mode == "r" and self.id._http_conn.cache_on:
# enables lambda to be used on server
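For a two-field read, the selection above is sent to the server as a colon-separated fields parameter on the value request; roughly (dataset UUID hypothetical):

# GET /datasets/<uuid>/value?fields=x:y
# i.e. mtype.names == ('x', 'y')  ->  params["fields"] = "x:y"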
@@ -1152,7 +1214,6 @@ def __getitem__(self, args, new_dtype=None):
# got binary response
# TBD - check expected number of bytes
self.log.info(f"binary response, {len(rsp)} bytes")
-                # arr1d = numpy.frombuffer(rsp, dtype=mtype)
arr1d = bytesToArray(rsp, mtype, page_mshape)
page_arr = numpy.reshape(arr1d, page_mshape)
else:
@@ -1328,7 +1389,7 @@ def __setitem__(self, args, val):

# get the val dtype if we're passed a numpy array
try:
msg = f"val dtype: {val.dtype}, shape: {val.shape} metadata: {val.dtype.metadata}"
msg = f"val dtype: {val.dtype}, shape: {val.shape} kind: {val.dtype.kind} metadata: {val.dtype.metadata}"
self.log.debug(msg)
if numpy.prod(val.shape) == 0:
self.log.info("no elements in numpy array, skipping write")
@@ -1360,6 +1421,7 @@
# For h5pyd, do extra check and convert type on client side for efficiency
vlen_base_class = check_dtype(vlen=self.dtype)
if vlen_base_class is not None and vlen_base_class not in (bytes, str):
self.log.debug(f"asarray to base_class: {vlen_base_class}")
try:
# Attempt to directly convert the input array of vlen data to its base class
val = numpy.asarray(val, dtype=vlen_base_class)
@@ -1417,6 +1479,7 @@ def __setitem__(self, args, val):
# TBD: Do we need something like the following in the above if condition:
# (self.dtype.str != val.dtype.str)
# for cases where the val is a numpy array but different type than self?

if len(names) == 1 and self.dtype.fields is not None:
# Single field selected for write, from a non-array source
if not names[0] in self.dtype.fields:
@@ -1427,9 +1490,12 @@ def __setitem__(self, args, val):
dtype = self.dtype
cast_compound = False

-            val = numpy.asarray(val, dtype=dtype, order="C")
+            self.log.debug(f"asarray dtype: {dtype}, cast_compound: {cast_compound}")
+            val = numpy.asarray(val, dtype=dtype.base, order="C")
if cast_compound:
-                val = val.astype(numpy.dtype([(names[0], dtype)]))
+                # val = val.astype(numpy.dtype([(names[0], dtype)]))
+                val = val.view(numpy.dtype([(names[0], dtype)]))
+                val = val.reshape(val.shape[:len(val.shape) - len(dtype.shape)])

elif isinstance(val, numpy.ndarray):
# convert array if needed
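The cast_compound branch above is the path single-field assignment exercises; a hedged sketch of the intended usage, mirroring the dtype used by test_assign later in this diff ('f' is assumed to be an open, writable file):

dt = numpy.dtype([('weight', (numpy.float64, 3)), ('endpoint_type', numpy.uint8)])
dset = f.create_dataset('v', (10,), dtype=dt)
# Writing one field: the value is converted with dtype.base, i.e. (float64, 3),
# then viewed as the one-field compound type and reshaped for the write.
dset['weight'] = numpy.full((10, 3), 0.5)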
@@ -1447,17 +1513,16 @@ def __setitem__(self, args, val):

# Check for array dtype compatibility and convert
mshape = None
"""
# TBD..
self.log.debug(f"self.dtype.subdtype: {self.dtype.subdtype}")
if self.dtype.subdtype is not None:
shp = self.dtype.subdtype[1] # type shape
valshp = val.shape[-len(shp):]
if valshp != shp: # Last dimension has to match
raise TypeError(f"When writing to array types,\
last N dimensions have to match (got {valshp}, but should be {shp})")
-            mtype = h5t.py_create(numpy.dtype((val.dtype, shp)))
-            mshape = val.shape[0:len(val.shape)-len(shp)]
-        """
+            mtype = numpy.dtype((val.dtype, shp))
+            self.log.debug(f"mtype for subdtype: {mtype}")
+            mshape = val.shape[0:len(val.shape) - len(shp)]

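A short check of the shape rule enforced here (shapes illustrative): for an array type such as (float64, 3), the value's trailing dimensions must match the type shape, and mshape drops them:

shp = (3,)                     # self.dtype.subdtype[1] for a (f8, 3) type
val = numpy.ones((10, 3))
assert val.shape[-len(shp):] == shp   # a (10, 4) value would raise TypeError
mshape = val.shape[0:len(val.shape) - len(shp)]
assert mshape == (10,)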
# Check for field selection
if len(names) != 0:
11 changes: 8 additions & 3 deletions test/hl/test_complex_numbers.py
@@ -43,9 +43,14 @@ def test_complex_dset(self):
val = dset[0]

self.assertEqual(val.shape, ())
-        self.assertEqual(val.dtype.kind, 'c')
-        self.assertEqual(val.real, 1.0)
-        self.assertEqual(val.imag, 0.)
+        if config.get('use_h5py'):
+            self.assertEqual(val.dtype.kind, 'c')
+            self.assertEqual(val.real, 1.0)
+            self.assertEqual(val.imag, 0.)
+        else:
+            self.assertEqual(val.dtype.kind, 'V')
+            self.assertEqual(val['r'], 1.0)
+            self.assertEqual(val['i'], 0.)

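For reference, the behavior the new assertions capture (dataset handle illustrative): h5py returns a native complex scalar, while h5pyd surfaces the compound representation the server stores:

val = dset[0]
# h5py:  val.dtype.kind == 'c'  ->  use val.real / val.imag
# h5pyd: val.dtype.kind == 'V'  ->  use val['r'] / val['i']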
def test_complex_attr(self):
"""Read and wrtie complex numbers in attributes"""
12 changes: 0 additions & 12 deletions test/hl/test_dataset.py
@@ -1396,13 +1396,7 @@ def test_rt(self):
self.assertTrue(np.all(outdata == testdata))
self.assertEqual(outdata.dtype, testdata.dtype)

-    @ut.expectedFailure
def test_assign(self):
-        # Expected failure on HSDS; skip with h5py
-        if config.get('use_h5py'):
-            self.assertTrue(False)
-
-        # TBD: field assignment not working
dt = np.dtype([('weight', (np.float64, 3)),
('endpoint_type', np.uint8), ])

@@ -1419,13 +1413,7 @@ def test_assign(self):
self.assertTrue(np.all(outdata == testdata))
self.assertEqual(outdata.dtype, testdata.dtype)

-    @ut.expectedFailure
def test_fields(self):
-        # Expected failure on HSDS; skip with h5py
-        if config.get('use_h5py'):
-            self.assertTrue(False)
-
-        # TBD: field assignment not working
dt = np.dtype([
('x', np.float64),
('y', np.float64),
1 change: 0 additions & 1 deletion test/hl/test_dataset_multi.py
@@ -36,7 +36,6 @@ def test_multi_read_scalar_dataspaces(self):
"""
filename = self.getFileName("multi_read_scalar_dataspaces")
print("filename:", filename)
print(f"numpy version: {np.version.version}")
f = h5py.File(filename, 'w')
shape = ()
count = 3
4 changes: 1 addition & 3 deletions test/hl/test_datatype.py
@@ -122,8 +122,7 @@ def test_read(self):
np.testing.assert_array_equal(outdata, testdata[key])
self.assertEqual(outdata.dtype, testdata[key].dtype)

"""
TBD
@ut.expectedFailure
def test_nested_compound_vlen(self):
dt_inner = np.dtype([('a', h5py.vlen_dtype(np.int32)),
('b', h5py.vlen_dtype(np.int32))])
@@ -147,7 +146,6 @@ def test_nested_compound_vlen(self):
# Specifying check_alignment=False because vlen fields have 8 bytes of padding
# because the vlen datatype in hdf5 occupies 16 bytes
self.assertArrayEqual(out, data, check_alignment=False)
"""


if __name__ == '__main__':
2 changes: 0 additions & 2 deletions test/hl/test_folder.py
@@ -29,8 +29,6 @@ def test_list(self):
# Folders not supported for h5py
return

-        # loglevel = logging.DEBUG
-        # logging.basicConfig( format='%(asctime)s %(message)s', level=loglevel)
test_domain = self.getFileName("folder_test")

filepath = self.getPathFromDomain(test_domain)