Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix pagination bug for H5Image #218

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions h5pyd/_hl/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ def is_hdf5(domain, **kwargs):


class H5Image(io.RawIOBase):
""" file-like-object class that treats bytes of an HSDS dataset as an HDF5 file image
Can be used as a subsitute for a file path in h5py.File(filepath). E.g.:
f = h5py.File(H5Image("hdf5:/myhsds_domain")) """

def __init__(self, domain_path, h5path="h5image", logger=None):
""" verify dataset can be accessed and set logger if supplied """
self._cursor = 0
if domain_path.startswith("hdf5::/"):
self._domain_path = domain_path
Expand All @@ -62,6 +67,7 @@ def __init__(self, domain_path, h5path="h5image", logger=None):
self._logger.info(f"domain {self._domain_path} opened")

def __repr__(self):
""" Just rturn the domain path"""
return f'<{self._domain_path}>'

def readable(self):
Expand All @@ -82,6 +88,7 @@ def tell(self):
return self._cursor

def seek(self, offset, whence=io.SEEK_SET):
""" set the seek pointer """
if whence == io.SEEK_SET:
if self._logger:
self._logger.debug(f"SEEK_SET({offset})")
Expand All @@ -101,6 +108,8 @@ def seek(self, offset, whence=io.SEEK_SET):
return self._cursor

def _get_page(self, page_number):
""" Return bytes for the given page.
Read a page from the HSDS dataset if not already in the cache """
if self._page_cache[page_number] is None:
if self._logger:
self._logger.info(f"reading page {page_number} from server")
Expand All @@ -112,6 +121,7 @@ def _get_page(self, page_number):
return self._page_cache[page_number]

def read(self, size=-1):
""" Read size bytes from the cursor """
start = self._cursor
if size < 0 or self._cursor + size >= self.size:
stop = self.size
Expand All @@ -125,17 +135,22 @@ def read(self, size=-1):

buffer = bytearray(stop - start)
offset = start
while offset - start < size:
while offset < stop:
page_number = offset // self.page_size
page_bytes = self._get_page(page_number)
num_bytes = ((offset + 1) + self.page_size // self.page_size) + self.page_size
if offset + num_bytes - start > size:
num_bytes = start + size - offset
page_start = offset % self.page_size
page_stop = page_start + num_bytes
n = offset % self.page_size
if stop // self.page_size > page_number:
# just read to the end of the page
m = self.page_size
else:
# remaing bytes don't cross page boundry
m = n + (stop - offset)

num_bytes = m - n
buffer_start = offset - start
buffer_stop = buffer_start + num_bytes
buffer[buffer_start:buffer_stop] = page_bytes[page_start:page_stop]
buffer[buffer_start:buffer_stop] = page_bytes[n:m]

offset += num_bytes

if self._logger:
Expand Down