diff --git a/iohub/daxi.py b/iohub/daxi.py index 5659f3e0..8861d454 100644 --- a/iohub/daxi.py +++ b/iohub/daxi.py @@ -48,14 +48,22 @@ def __init__( self._missing_value = missing_value self._dtype = np.uint16 - self._channels = self._metadata[self._CHANNELS_KEY] + self._wavelengths = self._metadata[self._CHANNELS_KEY] shape_dict = self._metadata[self._SHAPE_KEY] + self._channels = [ + f"v{v}_c{wl}" + for v in range(shape_dict["V"]) + for wl in self._wavelengths + ] + self._raw_shape = tuple(shape_dict[k] for k in self._SHAPE_IDX) - shape_dict["Z"] //= len(self._channels) - shape_dict["V"] *= len(self._channels) + shape_dict["Z"] //= len(self._wavelengths) + shape_dict["V"] *= len(self._wavelengths) + # y and x are flipped in comparison to DaXi file format + shape_dict["Y"], shape_dict["X"] = shape_dict["X"], shape_dict["Y"] self._shape = tuple(shape_dict[k] for k in self._SHAPE_IDX) @@ -70,7 +78,7 @@ def __init__( shape=self._shape[2:], dtype=self._dtype, ) - for c in range(len(self._channels)) + for c in range(len(self._wavelengths)) ] ) for v in range(self._raw_shape[1]) @@ -107,12 +115,14 @@ def _load_volume(self, t: int, v: int, c: int) -> np.ndarray: self._shape[2:], self._missing_value, dtype=self._dtype ) - return np.memmap( + arr = np.memmap( self._volume_path(t, v), dtype=self._dtype, shape=self._raw_shape[2:], mode="r", - )[c :: len(self._channels)] + )[c :: len(self._wavelengths)] + # inverting y and x and flipping new x to match original DaXi format + return np.flip(arr.transpose((0, 2, 1)), axis=-1) def _volume_path(self, t: int, v: int) -> Path: """ @@ -190,6 +200,12 @@ def scale(self) -> list[float]: 0.439, ] + @staticmethod + def is_valid_path(path: "StrOrBytesPath") -> bool: + """Check if a path is a valid DaXi dataset.""" + path = Path(path) + return path.exists() and (path / "metadata.yaml").exists() + def create_mock_daxi_dataset(path: "StrOrBytesPath") -> None: """ diff --git a/iohub/ngff.py b/iohub/ngff.py index 19b9d29a..8244269f 100644 --- a/iohub/ngff.py +++ b/iohub/ngff.py @@ -940,7 +940,7 @@ def initialize_pyramid(self, levels: int) -> None: for tr in transforms: if tr.type == "scale": for i in range(len(tr.scale))[-3:]: - tr.scale[i] /= factor + tr.scale[i] *= factor self.create_zeros( name=str(level),