Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plate labels fix #207

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 103 additions & 68 deletions ome_zarr/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import math
import os
from abc import ABC
from typing import Any, Dict, Iterator, List, Optional, Type, Union, cast, overload

Expand All @@ -26,7 +27,7 @@ def __init__(
zarr: ZarrLocation,
root: Union["Node", "Reader", List[ZarrLocation]],
visibility: bool = True,
plate_labels: bool = False,
# plate_labels: bool = False,
):
self.zarr = zarr
self.root = root
Expand All @@ -53,11 +54,11 @@ def __init__(
self.specs.append(Multiscales(self))
if OMERO.matches(zarr):
self.specs.append(OMERO(self))
if plate_labels:
# if plate_labels:
if PlateLabels.matches(zarr):
self.specs.append(PlateLabels(self))
elif Plate.matches(zarr):
self.specs.append(Plate(self))
# self.add(zarr, plate_labels=True)
if Well.matches(zarr):
self.specs.append(Well(self))

Expand Down Expand Up @@ -136,7 +137,7 @@ def add(
visibility = self.visible

self.seen.append(zarr)
node = Node(zarr, self, visibility=visibility, plate_labels=plate_labels)
node = Node(zarr, self, visibility=visibility)
if prepend:
self.pre_nodes.append(node)
else:
Expand Down Expand Up @@ -474,19 +475,18 @@ def matches(zarr: ZarrLocation) -> bool:

def __init__(self, node: Node) -> None:
super().__init__(node)

LOGGER.debug("Plate created with ZarrLocation fmt: %s", self.zarr.fmt)
self.get_pyramid_lazy(node)

def get_pyramid_lazy(self, node: Node) -> None:
"""
Return a pyramid of dask data, where the highest resolution is the
stitched full-resolution images.
"""
self.plate_data = self.lookup("plate", {})
self.first_field = "0"
# For Plate, plate_zarr is same as self.zarr, but for PlateLabels
# (node at /plate.zarr/labels) this is the parent at /plate.zarr node.
self.plate_zarr = self.get_plate_zarr()
self.plate_data = self.plate_zarr.root_attrs.get("plate", {})

LOGGER.info("plate_data: %s", self.plate_data)
self.rows = self.plate_data.get("rows")
self.columns = self.plate_data.get("columns")
self.first_field = "0"
self.row_names = [row["name"] for row in self.rows]
self.col_names = [col["name"] for col in self.columns]

Expand All @@ -496,40 +496,59 @@ def get_pyramid_lazy(self, node: Node) -> None:
self.row_count = len(self.rows)
self.column_count = len(self.columns)

# Get the first well...
well_zarr = self.zarr.create(self.well_paths[0])
well_node = Node(well_zarr, node)
well_spec: Optional[Well] = well_node.first(Well)
if well_spec is None:
raise Exception("Could not find first well")
self.numpy_type = well_spec.numpy_type
img_path = self.get_image_path(self.well_paths[0])
if not img_path:
# E.g. PlateLabels subclass has no Labels
return
image_zarr = self.plate_zarr.create(img_path)
# Create a Node for image, with no 'root'
self.first_well_image = Node(image_zarr, [])

self.get_pyramid_lazy(node)

LOGGER.debug("img_pyramid_shapes: %s", well_spec.img_pyramid_shapes)
# Load possible node data IF this is a Plate
if Plate.matches(self.zarr):
child_zarr = self.zarr.create("labels")
# This is a 'virtual' path to plate.zarr/labels
node.add(child_zarr)

self.axes = well_spec.img_metadata["axes"]
def get_plate_zarr(self) -> ZarrLocation:
return self.zarr

def get_pyramid_lazy(self, node: Node) -> None:
"""
Return a pyramid of dask data, where the highest resolution is the
stitched full-resolution images.
"""

# Use the first well for dtype and shapes
img_data = self.first_well_image.data
img_pyramid_shapes = [d.shape for d in img_data]
level = 0

LOGGER.debug("img_pyramid_shapes: %s", img_pyramid_shapes)

# Create a dask pyramid for the plate
pyramid = []
for level, tile_shape in enumerate(well_spec.img_pyramid_shapes):
for level, tile_shape in enumerate(img_pyramid_shapes):
self.numpy_type = img_data[level].dtype
lazy_plate = self.get_stitched_grid(level, tile_shape)
pyramid.append(lazy_plate)

# Set the node.data to be pyramid view of the plate
node.data = pyramid
# Use the first image's metadata for viewing the whole Plate
node.metadata = well_spec.img_metadata
node.metadata = self.first_well_image.metadata

# "metadata" dict gets added to each 'plate' layer in napari
node.metadata.update({"metadata": {"plate": self.plate_data}})

def get_numpy_type(self, image_node: Node) -> np.dtype:
return image_node.data[0].dtype
def get_image_path(self, well_path: str) -> Optional[str]:
return f"{well_path}/{self.first_field}/"

def get_tile_path(self, level: int, row: int, col: int) -> str:
return (
f"{self.row_names[row]}/"
f"{self.col_names[col]}/{self.first_field}/{level}"
)
well_path = f"{self.row_names[row]}/{self.col_names[col]}"
return f"{self.get_image_path(well_path)}{level}/"

def get_stitched_grid(self, level: int, tile_shape: tuple) -> da.core.Array:
LOGGER.debug("get_stitched_grid() level: %s, tile_shape: %s", level, tile_shape)
Expand All @@ -541,9 +560,9 @@ def get_tile(tile_name: str) -> np.ndarray:
LOGGER.debug("LOADING tile... %s with shape: %s", path, tile_shape)

try:
data = self.zarr.load(path)
data = self.plate_zarr.load(path)
except ValueError:
LOGGER.exception("Failed to load %s", path)
LOGGER.error("Failed to load %s", path)
data = np.zeros(tile_shape, dtype=self.numpy_type)
return data

Expand All @@ -559,53 +578,69 @@ def get_tile(tile_name: str) -> np.ndarray:
lazy_reader(tile_name), shape=tile_shape, dtype=self.numpy_type
)
lazy_row.append(lazy_tile)
lazy_rows.append(da.concatenate(lazy_row, axis=len(self.axes) - 1))
return da.concatenate(lazy_rows, axis=len(self.axes) - 2)
lazy_rows.append(da.concatenate(lazy_row, axis=len(tile_shape) - 1))
return da.concatenate(lazy_rows, axis=len(tile_shape) - 2)


class PlateLabels(Plate):
def get_tile_path(self, level: int, row: int, col: int) -> str: # pragma: no cover
"""251.zarr/A/1/0/labels/0/3/"""
path = (
f"{self.row_names[row]}/{self.col_names[col]}/"
f"{self.first_field}/labels/0/{level}"
)
return path

def get_pyramid_lazy(self, node: Node) -> None: # pragma: no cover
super().get_pyramid_lazy(node)
# pyramid data may be multi-channel, but we only have 1 labels channel
# TODO: when PlateLabels are re-enabled, update the logic to handle
# 0.4 axes (list of dictionaries)
if "c" in self.axes:
c_index = self.axes.index("c")
idx = [slice(None)] * len(self.axes)
idx[c_index] = slice(0, 1)
node.data[0] = node.data[0][tuple(idx)]
@staticmethod
def matches(zarr: ZarrLocation) -> bool:
# If the path ends in plate/labels...
if not zarr.path.endswith("labels"):
return False

# and the parent is a plate
parent_path = os.path.dirname(zarr.path)
parent = zarr.create(parent_path)
return "plate" in parent.root_attrs

def __init__(self, node: Node) -> None:
# cache well/image/labels/.zattrs for first field of each well. Key is e.g. A/1
self.well_labels_zattrs: Dict[str, Dict] = {}
super().__init__(node)

# remove image metadata
node.metadata = {}
# node.metadata = {}

# combine 'properties' from each image
# from https://github.com/ome/ome-zarr-py/pull/61/
properties: Dict[int, Dict[str, Any]] = {}
for row in self.row_names:
for col in self.col_names:
path = f"{row}/{col}/{self.first_field}/labels/0/.zattrs"
labels_json = self.zarr.get_json(path).get("image-label", {})
# NB: assume that 'label_val' is unique across all images
props_list = labels_json.get("properties", [])
if props_list:
for props in props_list:
label_val = props["label-value"]
properties[label_val] = dict(props)
del properties[label_val]["label-value"]
for well_path in self.well_paths:
path = self.get_image_path(well_path)
if not path:
continue
labels_json = self.zarr.get_json(path + ".zattrs").get("image-label", {})
# NB: assume that 'label_val' is unique across all images
props_list = labels_json.get("properties", [])
if props_list:
for props in props_list:
label_val = props["label-value"]
properties[label_val] = dict(props)
del properties[label_val]["label-value"]
node.metadata["properties"] = properties

def get_numpy_type(self, image_node: Node) -> np.dtype: # pragma: no cover
# FIXME - don't assume Well A1 is valid
path = self.get_tile_path(0, 0, 0)
label_zarr = self.zarr.load(path)
return label_zarr.dtype
def get_plate_zarr(self) -> ZarrLocation:
# lookup parent plate, remove the /labels
parent_path = os.path.dirname(self.zarr.path)
return self.zarr.create(parent_path)

def get_image_path(self, well_path: str) -> Optional[str]:
"""Returns path to .zattr for Well labels, e.g. /A/1/0/labels/my_cells/"""
labels_attrs = self.well_labels_zattrs.get(well_path)
if labels_attrs is None:
# if not cached, load...
path = f"{well_path}/{self.first_field}/labels/"
LOGGER.info("loading labels/.zattrs: %s.zattrs", path)
plate_zarr = self.get_plate_zarr()
first_field_labels = plate_zarr.create(path)
# loads labels/.zattrs when new ZarrLocation is created
labels_attrs = first_field_labels.root_attrs
self.well_labels_zattrs[well_path] = labels_attrs
label_paths = labels_attrs.get("labels", [])
LOGGER.debug("label_paths: %s", label_paths)
if len(label_paths) > 0:
return f"{well_path}/{self.first_field}/labels/{label_paths[0]}/"
return None


class Reader:
Expand Down
28 changes: 19 additions & 9 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@

from ome_zarr.data import create_zarr
from ome_zarr.io import parse_url
from ome_zarr.reader import Node, Plate, Reader
from ome_zarr.writer import write_image, write_plate_metadata, write_well_metadata
from ome_zarr.reader import Node, Plate, PlateLabels, Reader
from ome_zarr.writer import (
write_image,
write_labels,
write_plate_metadata,
write_well_metadata,
)


class TestReader:
Expand Down Expand Up @@ -68,8 +73,7 @@ def test_minimal_plate(self):

reader = Reader(parse_url(str(self.path)))
nodes = list(reader())
# currently reading plate labels disabled. Only 1 node
assert len(nodes) == 1
assert len(nodes) == 2
assert len(nodes[0].specs) == 1
assert isinstance(nodes[0].specs[0], Plate)
# assert len(nodes[1].specs) == 1
Expand All @@ -87,13 +91,19 @@ def test_multiwells_plate(self):
write_well_metadata(well, ["0", "1", "2"])
for field in range(3):
image = well.require_group(str(field))
write_image(zeros((1, 1, 1, 256, 256)), image)
write_image(zeros((256, 256)), image)

write_labels(zeros((256, 256)), image, name="test_labels")

reader = Reader(parse_url(str(self.path)))
nodes = list(reader())
# currently reading plate labels disabled. Only 1 node
assert len(nodes) == 1
assert len(nodes) == 2
assert len(nodes[0].specs) == 1
assert isinstance(nodes[0].specs[0], Plate)
# assert len(nodes[1].specs) == 1
# assert isinstance(nodes[1].specs[0], PlateLabels)
assert len(nodes[1].specs) == 1
assert isinstance(nodes[1].specs[0], PlateLabels)
# plate shape is the single image * grid dimensions
plate_shape = (256 * len(row_names), 256 * len(col_names))
# check largest data for image and labels
assert nodes[0].data[0].shape == plate_shape
assert nodes[1].data[0].shape == plate_shape