Skip to content

Commit

Permalink
Cache (#180)
Browse files Browse the repository at this point in the history
* Cache tiles
  • Loading branch information
erikogabrielsson authored Nov 22, 2024
1 parent 269e8f8 commit 73f7759
Show file tree
Hide file tree
Showing 13 changed files with 245 additions and 53 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Cache for compressed and decompressed tiles.

### Changed

- Default to placing the image at the middle of the slide if no `TotalPixelMatrixOriginSequence` is set when producing DICOM metadata.
Expand Down
4 changes: 2 additions & 2 deletions tests/testdata/region_definitions.json
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@
"levels": 3,
"label": "6634ad9dbe8c8f074266e21ef8eb6c12",
"overview": "3d792eb3441c58c5d881cb0cf21a397e",
"level_transfer_syntax": " 1.2.840.10008.1.2.4.91",
"level_transfer_syntax": "1.2.840.10008.1.2.4.91",
"label_transfer_syntax": "1.2.840.10008.1.2.1",
"overview_transfer_syntax": "1.2.840.10008.1.2.1",
"tiled": "sparse",
Expand Down Expand Up @@ -552,7 +552,7 @@
"label": "5a9a991e350f4fe29af08b7e3bcc56df",
"overview": "c3d8f41772dd26a19121f4f87d9f520c",
"tiled": "sparse",
"level_transfer_syntax": " 1.2.840.10008.1.2.4.91",
"level_transfer_syntax": "1.2.840.10008.1.2.4.91",
"label_transfer_syntax": "1.2.840.10008.1.2.1",
"overview_transfer_syntax": "1.2.840.10008.1.2.1",
"read_region": [
Expand Down
115 changes: 115 additions & 0 deletions wsidicom/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright 2024 SECTRA AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterator
from threading import Lock
from typing import Callable, Dict, Generic, Iterable, Optional, Sequence, Tuple, TypeVar

from PIL.Image import Image

CacheKeyType = TypeVar("CacheKeyType")
CacheItemType = TypeVar("CacheItemType")


class LRU(Generic[CacheKeyType, CacheItemType]):
    """Least-recently-used cache backed by an insertion-ordered dict.

    The first key in the dict is always the least recently used entry and the
    last key the most recently used one. Every public operation holds an
    internal lock while it touches the dict.
    """

    # Marker for "key not present", so that a stored value of None is not
    # mistaken for a cache miss.
    _MISSING = object()

    def __init__(self, maxsize: int):
        """
        Create an empty cache.

        Parameters
        ----------
        maxsize: int
            Maximum number of items to keep in the cache.
        """
        self._lock = Lock()
        self._cache: Dict[CacheKeyType, CacheItemType] = {}
        self._maxsize = maxsize

    @property
    def maxsize(self) -> int:
        """Maximum number of items kept in the cache."""
        return self._maxsize

    def get(self, key: CacheKeyType) -> Optional[CacheItemType]:
        """
        Return the item stored for key and mark it as most recently used.

        Parameters
        ----------
        key: CacheKeyType
            Key to look up.

        Returns
        -------
        Optional[CacheItemType]
            The stored item, or None if the key is not in the cache.
        """
        with self._lock:
            item = self._cache.pop(key, self._MISSING)
            if item is self._MISSING:
                return None
            # Re-inserting moves the key to the end (most recently used).
            self._cache[key] = item
            return item

    def put(self, key: CacheKeyType, value: CacheItemType) -> None:
        """
        Store value for key as the most recently used item.

        Parameters
        ----------
        key: CacheKeyType
            Key to store the value under.
        value: CacheItemType
            Value to store.
        """
        with self._lock:
            # Assigning to an existing key keeps its old position in the dict,
            # so remove it first to make the update count as a use.
            self._cache.pop(key, None)
            self._cache[key] = value
            self._evict_down_to(self._maxsize)

    def clear(self) -> None:
        """Remove all items from the cache."""
        with self._lock:
            self._cache.clear()

    def resize(self, maxsize: int) -> None:
        """
        Change the maximum size, evicting least recently used items if needed.

        Parameters
        ----------
        maxsize: int
            New maximum number of items to keep in the cache.
        """
        with self._lock:
            self._maxsize = maxsize
            self._evict_down_to(maxsize)

    def _evict_down_to(self, limit: int) -> None:
        """Evict least recently used items until at most limit items remain.

        Must be called with the lock held. A negative limit empties the cache.
        """
        limit = max(limit, 0)
        while len(self._cache) > limit:
            del self._cache[next(iter(self._cache))]


class FrameCache(Generic[CacheItemType]):
    """Cache for tile frames, keyed by image data id and frame index.

    Frames are kept in an `LRU` cache. When the cache size is below 1 the
    cache is bypassed and every request goes straight to the supplied getter.
    """

    def __init__(self, size: int):
        """
        Create a frame cache.

        Parameters
        ----------
        size: int
            Maximum number of frames to keep. A size below 1 disables caching.
        """
        self._size = size
        self._lru_cache = LRU[Tuple[int, int], CacheItemType](size)

    def get_tile_frame(
        self,
        image_data_id: int,
        frame_index: int,
        frame_getter: Callable[[int], CacheItemType],
    ) -> CacheItemType:
        """
        Return a single frame, from the cache if present.

        Parameters
        ----------
        image_data_id: int
            Identifier of the image data the frame belongs to.
        frame_index: int
            Index of the frame.
        frame_getter: Callable[[int], CacheItemType]
            Called with the frame index to produce the frame on a cache miss.

        Returns
        -------
        CacheItemType
            The requested frame.
        """
        if self._lru_cache.maxsize < 1:
            return frame_getter(frame_index)
        key = (image_data_id, frame_index)
        frame = self._lru_cache.get(key)
        if frame is None:
            frame = frame_getter(frame_index)
            self._lru_cache.put(key, frame)
        return frame

    def get_tile_frames(
        self,
        image_data_id: int,
        frame_indices: Sequence[int],
        frames_getter: Callable[[Iterable[int]], Iterator[CacheItemType]],
    ) -> Iterator[CacheItemType]:
        """
        Yield frames in the order of frame_indices, using the cache if possible.

        Parameters
        ----------
        image_data_id: int
            Identifier of the image data the frames belong to.
        frame_indices: Sequence[int]
            Indices of the frames to get.
        frames_getter: Callable[[Iterable[int]], Iterator[CacheItemType]]
            Called once with the indices that were not found in the cache.
            Must yield one frame per given index, in the same order.

        Returns
        -------
        Iterator[CacheItemType]
            The requested frames, one per entry in frame_indices.
        """
        if self._lru_cache.maxsize < 1:
            # This method is a generator, so the frames must be yielded; a
            # plain `return frames_getter(...)` would produce no frames at all.
            yield from frames_getter(frame_indices)
            return
        cached_frames: Dict[int, CacheItemType] = {}
        for frame_index in frame_indices:
            cached = self._lru_cache.get((image_data_id, frame_index))
            if cached is not None:
                cached_frames[frame_index] = cached
        # Only the misses are requested, in the order they will be consumed.
        fetched_frames = frames_getter(
            frame_index
            for frame_index in frame_indices
            if frame_index not in cached_frames
        )
        for frame_index in frame_indices:
            frame = cached_frames.get(frame_index)
            if frame is None:
                frame = next(fetched_frames)
                self._lru_cache.put((image_data_id, frame_index), frame)
            yield frame

    def clear(self) -> None:
        """Remove all frames from the cache."""
        self._lru_cache.clear()

    def resize(self, size: int) -> None:
        """
        Change the maximum number of cached frames.

        Parameters
        ----------
        size: int
            New maximum number of frames to keep.
        """
        self._size = size
        self._lru_cache.resize(size)


EncodedFrameCache = FrameCache[bytes]
DecodedFrameCache = FrameCache[Image]
20 changes: 20 additions & 0 deletions wsidicom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def __init__(self) -> None:
self._strict_specimen_identifier_check = True
self._ignore_specimen_preparation_step_on_validation_error = True
self._truncate_long_dicom_strings = False
self._decoded_frame_cache_size = 1000
self._encoded_frame_cache_size = 1000

@property
def strict_uid_check(self) -> bool:
Expand Down Expand Up @@ -134,6 +136,24 @@ def truncate_long_dicom_strings_on_validation_error(self) -> bool:
def truncate_long_dicom_strings_on_validation_error(self, value: bool) -> None:
self._truncate_long_dicom_strings = value

@property
def decoded_frame_cache_size(self) -> int:
    """Size of the decoded frame cache.

    Number of decoded frames a `DecodedFrameCache` created from this
    setting may hold. Set to 1000 in `__init__`. `FrameCache` checks for a
    size below 1 and skips the cache in that case.
    """
    return self._decoded_frame_cache_size

@decoded_frame_cache_size.setter
def decoded_frame_cache_size(self, value: int) -> None:
    # The value is stored as given; no validation is done here.
    # NOTE(review): caches appear to read this setting when they are
    # constructed, so existing caches are presumably not resized - confirm.
    self._decoded_frame_cache_size = value

@property
def encoded_frame_cache_size(self) -> int:
    """Size of the encoded frame cache.

    Number of encoded frames an `EncodedFrameCache` created from this
    setting may hold. Set to 1000 in `__init__`. `FrameCache` checks for a
    size below 1 and skips the cache in that case.
    """
    return self._encoded_frame_cache_size

@encoded_frame_cache_size.setter
def encoded_frame_cache_size(self, value: int) -> None:
    # The value is stored as given; no validation is done here.
    # NOTE(review): caches appear to read this setting when they are
    # constructed, so existing caches are presumably not resized - confirm.
    self._encoded_frame_cache_size = value


settings = Settings()
"""Global settings variable."""
13 changes: 11 additions & 2 deletions wsidicom/file/wsidicom_file_image_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pydicom.uid import UID
from upath import UPath

from wsidicom.cache import DecodedFrameCache, EncodedFrameCache
from wsidicom.codec import Codec
from wsidicom.errors import WsiDicomNotFoundError
from wsidicom.file.io import WsiDicomReader
Expand All @@ -32,7 +33,10 @@ class WsiDicomFileImageData(WsiDicomImageData):
"""

def __init__(
self, readers: Union[WsiDicomReader, Sequence[WsiDicomReader]]
self,
readers: Union[WsiDicomReader, Sequence[WsiDicomReader]],
decoded_frame_cache: DecodedFrameCache,
encoded_frame_cache: EncodedFrameCache,
) -> None:
"""
Create WsiDicomFileImageData from frame data from readers.
Expand All @@ -59,7 +63,12 @@ def __init__(
dataset.tile_size,
dataset.photometric_interpretation,
)
super().__init__([file.dataset for file in self._readers.values()], codec)
super().__init__(
[file.dataset for file in self._readers.values()],
codec,
decoded_frame_cache,
encoded_frame_cache,
)

def __repr__(self) -> str:
    """Return the class name followed by the result of `self._readers.values()`."""
    return f"{type(self).__name__}({self._readers.values()})"
Expand Down
12 changes: 7 additions & 5 deletions wsidicom/file/wsidicom_file_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, streams: Iterable[WsiDicomIO]) -> None:
streams: Iterable[WsiDicomIO]
Opened streams to read from.
"""
super().__init__()
self._levels: List[WsiDicomReader] = []
self._labels: List[WsiDicomReader] = []
self._overviews: List[WsiDicomReader] = []
Expand Down Expand Up @@ -244,9 +245,8 @@ def _get_base_dataset(files: Iterable[WsiDicomReader]) -> WsiDataset:
)
)

@classmethod
def _create_instances(
cls,
self,
files: Iterable[WsiDicomReader],
series_uids: SlideUids,
series_tile_size: Optional[Size] = None,
Expand All @@ -271,12 +271,14 @@ def _create_instances(
Iterable[WsiInstance]
Iterable of created instances.
"""
filtered_files = cls._filter_files(files, series_uids, series_tile_size)
files_grouped_by_instance = cls._group_files(filtered_files)
filtered_files = self._filter_files(files, series_uids, series_tile_size)
files_grouped_by_instance = self._group_files(filtered_files)
return (
WsiInstance(
[file.dataset for file in instance_files],
WsiDicomFileImageData(instance_files),
WsiDicomFileImageData(
instance_files, self._decoded_frame_cache, self._encoded_frame_cache
),
)
for instance_files in files_grouped_by_instance.values()
)
Expand Down
9 changes: 8 additions & 1 deletion wsidicom/file/wsidicom_file_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
from pydicom.valuerep import MAX_VALUE_LEN
from upath import UPath

from wsidicom.cache import DecodedFrameCache, EncodedFrameCache
from wsidicom.codec import Encoder
from wsidicom.codec import Settings as EncoderSettings
from wsidicom.config import settings
from wsidicom.file.io import (
OffsetTableType,
WsiDicomReader,
Expand Down Expand Up @@ -101,6 +103,8 @@ def __init__(
self._filepaths: List[UPath] = []
self._opened_files: List[WsiDicomReader] = []
self._file_options = file_options
self._decoded_frame_cache = DecodedFrameCache(settings.decoded_frame_cache_size)
self._encoded_frame_cache = EncodedFrameCache(settings.encoded_frame_cache_size)
super().__init__(
uid_generator,
workers,
Expand Down Expand Up @@ -273,7 +277,10 @@ def _open_files(self, filepaths: Iterable[UPath]) -> List[WsiInstance]:
self._opened_files.extend(readers)
return [
WsiInstance(
[reader.dataset for reader in readers], WsiDicomFileImageData(readers)
[reader.dataset for reader in readers],
WsiDicomFileImageData(
readers, self._decoded_frame_cache, self._encoded_frame_cache
),
)
]

Expand Down
30 changes: 4 additions & 26 deletions wsidicom/instance/image_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,28 +250,6 @@ def blank_encoded_tile(self) -> bytes:
def encoder(self) -> Encoder:
return self._encoder

def get_decoded_tiles(
self, tiles: Iterable[Point], z: float, path: str
) -> Iterator[Image]:
"""
Return Pillow images for tiles.
Parameters
----------
tiles: Iterable[Point]
Tiles to get.
z: float
Z coordinate.
path: str
Optical path.
Returns
-------
Iterator[Image]
Tiles as Images.
"""
return self._get_decoded_tiles(tiles, z, path)

def get_encoded_tiles(
self,
tiles: Iterable[Point],
Expand Down Expand Up @@ -392,7 +370,7 @@ def get_tile(self, tile_point: Point, z: float, path: str) -> Image:
def get_tiles(self, tiles: Iterable[Point], z: float, path: str) -> Iterator[Image]:
return (
self._crop_tile(tile_point, tile)
for tile_point, tile in zip(tiles, self.get_decoded_tiles(tiles, z, path))
for tile_point, tile in zip(tiles, self._get_decoded_tiles(tiles, z, path))
)

def get_scaled_encoded_tile(
Expand Down Expand Up @@ -446,7 +424,7 @@ def get_encoded_tile(self, tile: Point, z: float, path: str) -> bytes:
if cropped_tile_region.size == self.tile_size:
return self._get_encoded_tile(tile, z, path)
image = self._get_decoded_tile(tile, z, path)
image.crop(box=cropped_tile_region.box_from_origin)
image = image.crop(box=cropped_tile_region.box_from_origin)
return self.encoder.encode(image)

def stitch_tiles(self, region: Region, path: str, z: float, threads: int) -> Image:
Expand Down Expand Up @@ -579,7 +557,7 @@ def _crop_tile(self, tile_point: Point, tile: Image) -> Image:
# Check if tile is an edge tile that should be cropped
if tile_crop.size != self.tile_size:
return tile.crop(box=tile_crop.box)
return tile
return tile.copy()

def _paste_tiles(
self,
Expand Down Expand Up @@ -616,7 +594,7 @@ def _paste_tiles(
def thread_paste(tile_points: Iterable[Point]) -> None:
tile_points = list(tile_points)
for tile_point, tile in zip(
tile_points, self.get_decoded_tiles(tile_points, z, path)
tile_points, self._get_decoded_tiles(tile_points, z, path)
):
paste_method(image, tile_point, tile)

Expand Down
Loading

0 comments on commit 73f7759

Please sign in to comment.