Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plate labels fix #207

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 65 additions & 62 deletions ome_zarr/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
self.specs.append(PlateLabels(self))
elif Plate.matches(zarr):
self.specs.append(Plate(self))
# self.add(zarr, plate_labels=True)
self.add(zarr, plate_labels=True)
if Well.matches(zarr):
self.specs.append(Well(self))

Expand Down Expand Up @@ -465,18 +465,13 @@ def matches(zarr: ZarrLocation) -> bool:
def __init__(self, node: Node) -> None:
super().__init__(node)
LOGGER.debug(f"Plate created with ZarrLocation fmt:{ self.zarr.fmt}")
self.get_pyramid_lazy(node)

def get_pyramid_lazy(self, node: Node) -> None:
"""
Return a pyramid of dask data, where the highest resolution is the
stitched full-resolution images.
"""
self.first_field = "0"
self.plate_data = self.lookup("plate", {})

LOGGER.info("plate_data: %s", self.plate_data)
self.rows = self.plate_data.get("rows")
self.columns = self.plate_data.get("columns")
self.first_field = "0"
self.row_names = [row["name"] for row in self.rows]
self.col_names = [col["name"] for col in self.columns]

Expand All @@ -486,40 +481,50 @@ def get_pyramid_lazy(self, node: Node) -> None:
self.row_count = len(self.rows)
self.column_count = len(self.columns)

# Get the first well...
well_zarr = self.zarr.create(self.well_paths[0])
well_node = Node(well_zarr, node)
well_spec: Optional[Well] = well_node.first(Well)
if well_spec is None:
raise Exception("could not find first well")
self.numpy_type = well_spec.numpy_type
img_path = self.get_image_path(self.well_paths[0])
if not img_path:
# E.g. PlateLabels subclass has no Labels
return
image_zarr = self.zarr.create(img_path)
# Create a Node for image, with no 'root'
self.first_well_image = Node(image_zarr, [])

LOGGER.debug(f"img_pyramid_shapes: {well_spec.img_pyramid_shapes}")
self.get_pyramid_lazy(node)

self.axes = well_spec.img_metadata["axes"]
def get_pyramid_lazy(self, node: Node) -> None:
"""
Return a pyramid of dask data, where the highest resolution is the
stitched full-resolution images.
"""

# Use the first well for dtype and shapes
img_data = self.first_well_image.data
img_pyramid_shapes = [d.shape for d in img_data]
level = 0
self.numpy_type = img_data[level].dtype

LOGGER.debug(f"img_pyramid_shapes: {img_pyramid_shapes}")

# Create a dask pyramid for the plate
pyramid = []
for level, tile_shape in enumerate(well_spec.img_pyramid_shapes):
for level, tile_shape in enumerate(img_pyramid_shapes):
lazy_plate = self.get_stitched_grid(level, tile_shape)
pyramid.append(lazy_plate)

# Set the node.data to be pyramid view of the plate
node.data = pyramid
# Use the first image's metadata for viewing the whole Plate
node.metadata = well_spec.img_metadata
node.metadata = self.first_well_image.metadata

# "metadata" dict gets added to each 'plate' layer in napari
node.metadata.update({"metadata": {"plate": self.plate_data}})

def get_numpy_type(self, image_node: Node) -> np.dtype:
return image_node.data[0].dtype
def get_image_path(self, well_path: str) -> Optional[str]:
    """
    Build the path to the image of the first field within a Well.

    :param well_path: path of the Well, e.g. ``A/1``
    :return: ``<well_path>/<first_field>/`` including the trailing slash.
        This implementation always returns a string, although the
        signature allows ``None``.
    """
    field = self.first_field
    return f"{well_path}/{field}/"

def get_tile_path(self, level: int, row: int, col: int) -> str:
return (
f"{self.row_names[row]}/"
f"{self.col_names[col]}/{self.first_field}/{level}"
)
well_path = f"{self.row_names[row]}/{self.col_names[col]}"
return f"{self.get_image_path(well_path)}{level}/"

def get_stitched_grid(self, level: int, tile_shape: tuple) -> da.core.Array:
LOGGER.debug(f"get_stitched_grid() level: {level}, tile_shape: {tile_shape}")
Expand Down Expand Up @@ -550,53 +555,51 @@ def get_tile(tile_name: str) -> np.ndarray:
lazy_reader(tile_name), shape=tile_shape, dtype=self.numpy_type
)
lazy_row.append(lazy_tile)
lazy_rows.append(da.concatenate(lazy_row, axis=len(self.axes) - 1))
return da.concatenate(lazy_rows, axis=len(self.axes) - 2)
lazy_rows.append(da.concatenate(lazy_row, axis=len(tile_shape) - 1))
return da.concatenate(lazy_rows, axis=len(tile_shape) - 2)


class PlateLabels(Plate):
def get_tile_path(self, level: int, row: int, col: int) -> str: # pragma: no cover
"""251.zarr/A/1/0/labels/0/3/"""
path = (
f"{self.row_names[row]}/{self.col_names[col]}/"
f"{self.first_field}/labels/0/{level}"
)
return path

def get_pyramid_lazy(self, node: Node) -> None: # pragma: no cover
super().get_pyramid_lazy(node)
# pyramid data may be multi-channel, but we only have 1 labels channel
# TODO: when PlateLabels are re-enabled, update the logic to handle
# 0.4 axes (list of dictionaries)
if "c" in self.axes:
c_index = self.axes.index("c")
idx = [slice(None)] * len(self.axes)
idx[c_index] = slice(0, 1)
node.data[0] = node.data[0][tuple(idx)]
def __init__(self, node: Node) -> None:
# cache well/image/labels/.zattrs for first field of each well. Key is e.g. A/1
self.well_labels_zattrs: Dict[str, Dict] = {}
super().__init__(node)

# remove image metadata
node.metadata = {}
# node.metadata = {}

# combine 'properties' from each image
# from https://github.com/ome/ome-zarr-py/pull/61/
properties: Dict[int, Dict[str, Any]] = {}
for row in self.row_names:
for col in self.col_names:
path = f"{row}/{col}/{self.first_field}/labels/0/.zattrs"
labels_json = self.zarr.get_json(path).get("image-label", {})
# NB: assume that 'label_val' is unique across all images
props_list = labels_json.get("properties", [])
if props_list:
for props in props_list:
label_val = props["label-value"]
properties[label_val] = dict(props)
del properties[label_val]["label-value"]
for well_path in self.well_paths:
path = self.get_image_path(well_path)
if not path:
continue
labels_json = self.zarr.get_json(path + ".zattrs").get("image-label", {})
# NB: assume that 'label_val' is unique across all images
props_list = labels_json.get("properties", [])
if props_list:
for props in props_list:
label_val = props["label-value"]
properties[label_val] = dict(props)
del properties[label_val]["label-value"]
node.metadata["properties"] = properties

def get_numpy_type(self, image_node: Node) -> np.dtype: # pragma: no cover
# FIXME - don't assume Well A1 is valid
path = self.get_tile_path(0, 0, 0)
label_zarr = self.zarr.load(path)
return label_zarr.dtype
def get_image_path(self, well_path: str) -> Optional[str]:
    """
    Return the path to the first labels image of a Well's first field.

    The attributes read from ``<well_path>/<first_field>/labels/`` are
    cached in ``self.well_labels_zattrs`` under ``well_path`` (e.g. ``A/1``),
    so a later call for the same well reuses the cached attributes.

    :param well_path: path of the Well, e.g. ``A/1``
    :return: path with trailing slash, e.g. ``A/1/0/labels/my_cells/``,
        built from the first entry of the ``labels`` list; ``None`` when
        that list is missing, null or empty.
    """
    labels_attrs = self.well_labels_zattrs.get(well_path)
    if labels_attrs is None:
        # if not cached, load...
        path = f"{well_path}/{self.first_field}/labels/"
        LOGGER.info("loading labels/.zattrs: %s.zattrs", path)
        first_field_labels = self.zarr.create(path)
        # loads labels/.zattrs when new ZarrLocation is created
        labels_attrs = first_field_labels.root_attrs
        self.well_labels_zattrs[well_path] = labels_attrs
    # Truthiness (rather than len()) also tolerates a null "labels" entry
    label_paths = labels_attrs.get("labels") or []
    if not label_paths:
        return None
    return f"{well_path}/{self.first_field}/labels/{label_paths[0]}/"


class Reader:
Expand Down
6 changes: 2 additions & 4 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def test_minimal_plate(self):

reader = Reader(parse_url(str(self.path)))
nodes = list(reader())
# currently reading plate labels disabled. Only 1 node
assert len(nodes) == 1
assert len(nodes) == 2
assert len(nodes[0].specs) == 1
assert isinstance(nodes[0].specs[0], Plate)
# assert len(nodes[1].specs) == 1
Expand All @@ -73,8 +72,7 @@ def test_multiwells_plate(self):

reader = Reader(parse_url(str(self.path)))
nodes = list(reader())
# currently reading plate labels disabled. Only 1 node
assert len(nodes) == 1
assert len(nodes) == 2
assert len(nodes[0].specs) == 1
assert isinstance(nodes[0].specs[0], Plate)
# assert len(nodes[1].specs) == 1
Expand Down