From 73f775993d7d5ee19d83d1c1bf59dae0a6fc494b Mon Sep 17 00:00:00 2001 From: Erik O Gabrielsson <83275777+erikogabrielsson@users.noreply.github.com> Date: Fri, 22 Nov 2024 22:24:33 +0100 Subject: [PATCH] Cache (#180) * Cache tiles --- CHANGELOG.md | 4 + tests/testdata/region_definitions.json | 4 +- wsidicom/cache.py | 115 ++++++++++++++++++++++ wsidicom/config.py | 20 ++++ wsidicom/file/wsidicom_file_image_data.py | 13 ++- wsidicom/file/wsidicom_file_source.py | 12 ++- wsidicom/file/wsidicom_file_target.py | 9 +- wsidicom/instance/image_data.py | 30 +----- wsidicom/instance/wsidicom_image_data.py | 26 ++++- wsidicom/source.py | 16 +++ wsidicom/web/wsidicom_web_image_data.py | 31 ++++-- wsidicom/web/wsidicom_web_source.py | 10 +- wsidicom/wsidicom.py | 8 ++ 13 files changed, 245 insertions(+), 53 deletions(-) create mode 100644 wsidicom/cache.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d0b903ff..0674460f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Cache for compressed and decompressed tiles. + ### Changed - Default to placing the image at the middle of the slide if no `TotalPixelMatrixOriginSequence` is set when producing DICOM metadata . diff --git a/tests/testdata/region_definitions.json b/tests/testdata/region_definitions.json index 54d80f7d..a6a599b1 100644 --- a/tests/testdata/region_definitions.json +++ b/tests/testdata/region_definitions.json @@ -467,7 +467,7 @@ "levels": 3, "label": "6634ad9dbe8c8f074266e21ef8eb6c12", "overview": "3d792eb3441c58c5d881cb0cf21a397e", - "level_transfer_syntax": " 1.2.840.10008.1.2.4.91", + "level_transfer_syntax": "1.2.840.10008.1.2.4.91", "label_transfer_syntax": "1.2.840.10008.1.2.1", "overview_transfer_syntax": "1.2.840.10008.1.2.1", "tiled": "sparse", @@ -552,7 +552,7 @@ "label": "5a9a991e350f4fe29af08b7e3bcc56df", "overview": "c3d8f41772dd26a19121f4f87d9f520c", "tiled": "sparse", - "level_transfer_syntax": " 1.2.840.10008.1.2.4.91", + "level_transfer_syntax": "1.2.840.10008.1.2.4.91", "label_transfer_syntax": "1.2.840.10008.1.2.1", "overview_transfer_syntax": "1.2.840.10008.1.2.1", "read_region": [ diff --git a/wsidicom/cache.py b/wsidicom/cache.py new file mode 100644 index 00000000..15ba9b0c --- /dev/null +++ b/wsidicom/cache.py @@ -0,0 +1,115 @@ +# Copyright 2024 SECTRA AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Iterator +from threading import Lock +from typing import Callable, Dict, Generic, Iterable, Optional, Sequence, Tuple, TypeVar + +from PIL.Image import Image + +CacheKeyType = TypeVar("CacheKeyType") +CacheItemType = TypeVar("CacheItemType") + + +class LRU(Generic[CacheKeyType, CacheItemType]): + def __init__(self, maxsize: int): + self._lock = Lock() + self._cache: Dict[CacheKeyType, CacheItemType] = {} + self._maxsize = maxsize + + @property + def maxsize(self) -> int: + return self._maxsize + + def get(self, key: CacheKeyType) -> Optional[CacheItemType]: + with self._lock: + item = self._cache.pop(key, None) + if item is not None: + self._cache[key] = item + return item + + def put(self, key: CacheKeyType, value: CacheItemType) -> None: + with self._lock: + self._cache[key] = value + if len(self._cache) > self._maxsize: + self._cache.pop(next(iter(self._cache))) + + def clear(self) -> None: + with self._lock: + self._cache.clear() + + def resize(self, maxsize: int) -> None: + with self._lock: + self._maxsize = maxsize + if len(self._cache) > maxsize: + for _ in range(len(self._cache) - maxsize): + self._cache.pop(next(iter(self._cache))) + + +class FrameCache(Generic[CacheItemType]): + def __init__(self, size: int): + self._size = size + self._lru_cache = LRU[Tuple[int, int], CacheItemType](size) + + def get_tile_frame( + self, + image_data_id: int, + frame_index: int, + frame_getter: Callable[[int], CacheItemType], + ) -> CacheItemType: + if self._lru_cache.maxsize < 1: + return frame_getter(frame_index) + frame = self._lru_cache.get((image_data_id, frame_index)) + if frame is None: + frame = frame_getter(frame_index) + self._lru_cache.put((image_data_id, frame_index), frame) + return frame + + def get_tile_frames( + self, + image_data_id: int, + frame_indices: Sequence[int], + frames_getter: Callable[[Iterable[int]], Iterator[CacheItemType]], + ) -> Iterator[CacheItemType]: + if self._lru_cache.maxsize < 1: + return frames_getter(frame_indices) + cached_frames = { + frame_index: frame + for (frame_index, frame) in ( + (frame_index, self._lru_cache.get((image_data_id, frame_index))) + for frame_index in frame_indices + ) + if frame is not None + } + fetched_frames = frames_getter( + frame_index + for frame_index in frame_indices + if frame_index not in cached_frames + ) + for frame_index in frame_indices: + frame = cached_frames.get(frame_index) + if frame is None: + frame = next(fetched_frames) + self._lru_cache.put((image_data_id, frame_index), frame) + yield frame + + def clear(self) -> None: + self._lru_cache.clear() + + def resize(self, size: int) -> None: + self._lru_cache.resize(size) + + +EncodedFrameCache = FrameCache[bytes] +DecodedFrameCache = FrameCache[Image] diff --git a/wsidicom/config.py b/wsidicom/config.py index fc3401c8..ae5b7f0d 100644 --- a/wsidicom/config.py +++ b/wsidicom/config.py @@ -33,6 +33,8 @@ def __init__(self) -> None: self._strict_specimen_identifier_check = True self._ignore_specimen_preparation_step_on_validation_error = True self._truncate_long_dicom_strings = False + self._decoded_frame_cache_size = 1000 + self._encoded_frame_cache_size = 1000 @property def strict_uid_check(self) -> bool: @@ -134,6 +136,24 @@ def truncate_long_dicom_strings_on_validation_error(self) -> bool: def truncate_long_dicom_strings_on_validation_error(self, value: bool) -> None: self._truncate_long_dicom_strings = value + @property + def decoded_frame_cache_size(self) -> int: + """Size of the decoded frame cache.""" + return self._decoded_frame_cache_size + + @decoded_frame_cache_size.setter + def decoded_frame_cache_size(self, value: int) -> None: + self._decoded_frame_cache_size = value + + @property + def encoded_frame_cache_size(self) -> int: + """Size of the encoded frame cache.""" + return self._encoded_frame_cache_size + + @encoded_frame_cache_size.setter + def encoded_frame_cache_size(self, value: int) -> None: + self._encoded_frame_cache_size = value + settings = Settings() """Global settings variable.""" diff --git a/wsidicom/file/wsidicom_file_image_data.py b/wsidicom/file/wsidicom_file_image_data.py index 94889e3b..3c2eff95 100644 --- a/wsidicom/file/wsidicom_file_image_data.py +++ b/wsidicom/file/wsidicom_file_image_data.py @@ -18,6 +18,7 @@ from pydicom.uid import UID from upath import UPath +from wsidicom.cache import DecodedFrameCache, EncodedFrameCache from wsidicom.codec import Codec from wsidicom.errors import WsiDicomNotFoundError from wsidicom.file.io import WsiDicomReader @@ -32,7 +33,10 @@ class WsiDicomFileImageData(WsiDicomImageData): """ def __init__( - self, readers: Union[WsiDicomReader, Sequence[WsiDicomReader]] + self, + readers: Union[WsiDicomReader, Sequence[WsiDicomReader]], + decoded_frame_cache: DecodedFrameCache, + encoded_frame_cache: EncodedFrameCache, ) -> None: """ Create WsiDicomFileImageData from frame data from readers. @@ -59,7 +63,12 @@ def __init__( dataset.tile_size, dataset.photometric_interpretation, ) - super().__init__([file.dataset for file in self._readers.values()], codec) + super().__init__( + [file.dataset for file in self._readers.values()], + codec, + decoded_frame_cache, + encoded_frame_cache, + ) def __repr__(self) -> str: return f"{type(self).__name__}({self._readers.values()})" diff --git a/wsidicom/file/wsidicom_file_source.py b/wsidicom/file/wsidicom_file_source.py index 0df6c977..5fb49b10 100644 --- a/wsidicom/file/wsidicom_file_source.py +++ b/wsidicom/file/wsidicom_file_source.py @@ -57,6 +57,7 @@ def __init__(self, streams: Iterable[WsiDicomIO]) -> None: streams: Iterable[WsiDicomIO] Opened streams to read from. """ + super().__init__() self._levels: List[WsiDicomReader] = [] self._labels: List[WsiDicomReader] = [] self._overviews: List[WsiDicomReader] = [] @@ -244,9 +245,8 @@ def _get_base_dataset(files: Iterable[WsiDicomReader]) -> WsiDataset: ) ) - @classmethod def _create_instances( - cls, + self, files: Iterable[WsiDicomReader], series_uids: SlideUids, series_tile_size: Optional[Size] = None, @@ -271,12 +271,14 @@ def _create_instances( Iterable[WsiInstancece] Iterable of created instances. """ - filtered_files = cls._filter_files(files, series_uids, series_tile_size) - files_grouped_by_instance = cls._group_files(filtered_files) + filtered_files = self._filter_files(files, series_uids, series_tile_size) + files_grouped_by_instance = self._group_files(filtered_files) return ( WsiInstance( [file.dataset for file in instance_files], - WsiDicomFileImageData(instance_files), + WsiDicomFileImageData( + instance_files, self._decoded_frame_cache, self._encoded_frame_cache + ), ) for instance_files in files_grouped_by_instance.values() ) diff --git a/wsidicom/file/wsidicom_file_target.py b/wsidicom/file/wsidicom_file_target.py index 052af60d..5884009a 100644 --- a/wsidicom/file/wsidicom_file_target.py +++ b/wsidicom/file/wsidicom_file_target.py @@ -33,8 +33,10 @@ from pydicom.valuerep import MAX_VALUE_LEN from upath import UPath +from wsidicom.cache import DecodedFrameCache, EncodedFrameCache from wsidicom.codec import Encoder from wsidicom.codec import Settings as EncoderSettings +from wsidicom.config import settings from wsidicom.file.io import ( OffsetTableType, WsiDicomReader, @@ -101,6 +103,8 @@ def __init__( self._filepaths: List[UPath] = [] self._opened_files: List[WsiDicomReader] = [] self._file_options = file_options + self._decoded_frame_cache = DecodedFrameCache(settings.decoded_frame_cache_size) + self._encoded_frame_cache = EncodedFrameCache(settings.encoded_frame_cache_size) super().__init__( uid_generator, workers, @@ -273,7 +277,10 @@ def _open_files(self, filepaths: Iterable[UPath]) -> List[WsiInstance]: self._opened_files.extend(readers) return [ WsiInstance( - [reader.dataset for reader in readers], WsiDicomFileImageData(readers) + [reader.dataset for reader in readers], + WsiDicomFileImageData( + readers, self._decoded_frame_cache, self._encoded_frame_cache + ), ) ] diff --git a/wsidicom/instance/image_data.py b/wsidicom/instance/image_data.py index a11c2cb3..ac769cd7 100644 --- a/wsidicom/instance/image_data.py +++ b/wsidicom/instance/image_data.py @@ -250,28 +250,6 @@ def blank_encoded_tile(self) -> bytes: def encoder(self) -> Encoder: return self._encoder - def get_decoded_tiles( - self, tiles: Iterable[Point], z: float, path: str - ) -> Iterator[Image]: - """ - Return Pillow images for tiles. - - Parameters - ---------- - tiles: Iterable[Point] - Tiles to get. - z: float - Z coordinate. - path: str - Optical path. - - Returns - ------- - Iterator[Image] - Tiles as Images. - """ - return self._get_decoded_tiles(tiles, z, path) - def get_encoded_tiles( self, tiles: Iterable[Point], @@ -392,7 +370,7 @@ def get_tile(self, tile_point: Point, z: float, path: str) -> Image: def get_tiles(self, tiles: Iterable[Point], z: float, path: str) -> Iterator[Image]: return ( self._crop_tile(tile_point, tile) - for tile_point, tile in zip(tiles, self.get_decoded_tiles(tiles, z, path)) + for tile_point, tile in zip(tiles, self._get_decoded_tiles(tiles, z, path)) ) def get_scaled_encoded_tile( @@ -446,7 +424,7 @@ def get_encoded_tile(self, tile: Point, z: float, path: str) -> bytes: if cropped_tile_region.size == self.tile_size: return self._get_encoded_tile(tile, z, path) image = self._get_decoded_tile(tile, z, path) - image.crop(box=cropped_tile_region.box_from_origin) + image = image.crop(box=cropped_tile_region.box_from_origin) return self.encoder.encode(image) def stitch_tiles(self, region: Region, path: str, z: float, threads: int) -> Image: @@ -579,7 +557,7 @@ def _crop_tile(self, tile_point: Point, tile: Image) -> Image: # Check if tile is an edge tile that should be cropped if tile_crop.size != self.tile_size: return tile.crop(box=tile_crop.box) - return tile + return tile.copy() def _paste_tiles( self, @@ -616,7 +594,7 @@ def _paste_tiles( def thread_paste(tile_points: Iterable[Point]) -> None: tile_points = list(tile_points) for tile_point, tile in zip( - tile_points, self.get_decoded_tiles(tile_points, z, path) + tile_points, self._get_decoded_tiles(tile_points, z, path) ): paste_method(image, tile_point, tile) diff --git a/wsidicom/instance/wsidicom_image_data.py b/wsidicom/instance/wsidicom_image_data.py index 34e5b4f5..3112ed4f 100644 --- a/wsidicom/instance/wsidicom_image_data.py +++ b/wsidicom/instance/wsidicom_image_data.py @@ -18,6 +18,7 @@ from PIL.Image import Image +from wsidicom.cache import DecodedFrameCache, EncodedFrameCache from wsidicom.codec import Codec, Decoder from wsidicom.errors import WsiDicomOutOfBoundsError from wsidicom.geometry import Point, Region, Size, SizeMm @@ -26,14 +27,22 @@ from wsidicom.instance.tile_index.full_tile_index import FullTileIndex from wsidicom.instance.tile_index.sparse_tile_index import SparseTileIndex from wsidicom.instance.tile_index.tile_index import TileIndex -from wsidicom.metadata.schema.dicom.image import ImageCoordinateSystemDicomSchema from wsidicom.metadata.image import ImageCoordinateSystem +from wsidicom.metadata.schema.dicom.image import ImageCoordinateSystemDicomSchema class WsiDicomImageData(ImageData, metaclass=ABCMeta): - def __init__(self, datasets: Sequence[WsiDataset], codec: Codec): + def __init__( + self, + datasets: Sequence[WsiDataset], + codec: Codec, + decoded_frame_cache: DecodedFrameCache, + encoded_frame_cache: EncodedFrameCache, + ): self._datasets = datasets self._decoder = codec.decoder + self._decoded_frame_cache = decoded_frame_cache + self._encoded_frame_cache = encoded_frame_cache super().__init__(codec.encoder) @abstractmethod @@ -55,6 +64,9 @@ def _get_tile_frame(self, frame_index: int) -> bytes: def _get_tile_frames(self, frame_indices: Sequence[int]) -> Iterator[bytes]: return (self._get_tile_frame(frame_index) for frame_index in frame_indices) + def _get_decoded_tile_frame(self, frame_index: int) -> Image: + return self.decoder.decode(self._get_tile_frame(frame_index)) + @cached_property def tiles(self) -> TileIndex: """Return tile index for image.""" @@ -141,7 +153,10 @@ def _get_encoded_tile(self, tile: Point, z: float, path: str) -> bytes: frame_index = self._get_frame_index(tile, z, path) if frame_index == -1: return self.blank_encoded_tile - return self._get_tile_frame(frame_index) + + return self._encoded_frame_cache.get_tile_frame( + id(self), frame_index, self._get_tile_frame + ) def _get_decoded_tile(self, tile: Point, z: float, path: str) -> Image: """ @@ -164,8 +179,9 @@ def _get_decoded_tile(self, tile: Point, z: float, path: str) -> Image: frame_index = self._get_frame_index(tile, z, path) if frame_index == -1: return self.blank_tile - frame = self._get_tile_frame(frame_index) - return self.decoder.decode(frame) + return self._decoded_frame_cache.get_tile_frame( + id(self), frame_index, self._get_decoded_tile_frame + ) def _get_frame_index(self, tile: Point, z: float, path: str) -> int: """ diff --git a/wsidicom/source.py b/wsidicom/source.py index d08c7707..c879f508 100644 --- a/wsidicom/source.py +++ b/wsidicom/source.py @@ -15,6 +15,8 @@ from abc import ABCMeta, abstractmethod from typing import Iterable +from wsidicom.cache import DecodedFrameCache, EncodedFrameCache +from wsidicom.config import settings from wsidicom.graphical_annotations import AnnotationInstance from wsidicom.instance import WsiDataset, WsiInstance @@ -30,6 +32,10 @@ class Source(metaclass=ABCMeta): instances. """ + def __init__(self): + self._decoded_frame_cache = DecodedFrameCache(settings.decoded_frame_cache_size) + self._encoded_frame_cache = EncodedFrameCache(settings.encoded_frame_cache_size) + def __enter__(self): return self @@ -70,3 +76,13 @@ def annotation_instances(self) -> Iterable[AnnotationInstance]: def close(self) -> None: """Close any opened resources (such as files).""" raise NotImplementedError() + + def clear_cache(self) -> None: + """Clear the frame caches.""" + self._decoded_frame_cache.clear() + self._encoded_frame_cache.clear() + + def resize_cache(self, size: int) -> None: + """Clear the frame caches.""" + self._decoded_frame_cache.resize(size) + self._encoded_frame_cache.resize(size) diff --git a/wsidicom/web/wsidicom_web_image_data.py b/wsidicom/web/wsidicom_web_image_data.py index 16c9c9bb..d1b114ee 100644 --- a/wsidicom/web/wsidicom_web_image_data.py +++ b/wsidicom/web/wsidicom_web_image_data.py @@ -17,6 +17,7 @@ from PIL.Image import Image from pydicom.uid import UID +from wsidicom.cache import DecodedFrameCache, EncodedFrameCache from wsidicom.codec import Codec from wsidicom.geometry import Point from wsidicom.instance import WsiDataset, WsiDicomImageData @@ -35,6 +36,8 @@ def __init__( client: WsiDicomWebClient, dataset: WsiDataset, transfer_syntax: UID, + decoded_frame_cache: DecodedFrameCache, + encoded_frame_cache: EncodedFrameCache, ): """ Create WsiDicomWebImageData from dataset and provided client. @@ -61,7 +64,7 @@ def __init__( dataset.tile_size, dataset.photometric_interpretation, ) - super().__init__([dataset], codec) + super().__init__([dataset], codec, decoded_frame_cache, encoded_frame_cache) @property def transfer_syntax(self) -> UID: @@ -89,15 +92,17 @@ def _get_decoded_tiles( Tiles as Images. """ frame_indices = [self._get_frame_index(tile, z, path) for tile in tiles] - - frames = self._get_tile_frames( - frame_index for frame_index in frame_indices if frame_index != -1 + frames = self._decoded_frame_cache.get_tile_frames( + id(self), + [frame_index for frame_index in frame_indices if frame_index != -1], + self._get_decoded_tile_frames, ) + for frame_index in frame_indices: if frame_index == -1: yield self.blank_tile - frame = next(frames) - yield self.decoder.decode(frame) + else: + yield next(frames) def _get_encoded_tiles( self, tiles: Iterable[Point], z: float, path: str @@ -120,14 +125,16 @@ def _get_encoded_tiles( Tiles as Images. """ frame_indices = [self._get_frame_index(tile, z, path) for tile in tiles] - - frames = self._get_tile_frames( - frame_index for frame_index in frame_indices if frame_index != -1 + frames = self._encoded_frame_cache.get_tile_frames( + id(self), + [frame_index for frame_index in frame_indices if frame_index != -1], + self._get_tile_frames, ) for frame_index in frame_indices: if frame_index == -1: yield self.blank_encoded_tile - return next(frames) + else: + yield next(frames) def _get_tile_frame(self, frame_index: int) -> bytes: # First frame for DICOM web is 1. @@ -150,3 +157,7 @@ def _get_tile_frames(self, frame_indices: Iterable[int]) -> Iterator[bytes]: [frame_index + 1 for frame_index in frame_indices], self._transfer_syntax, ) + + def _get_decoded_tile_frames(self, frame_indices: Iterable[int]) -> Iterator[Image]: + for frame in self._get_tile_frames(frame_indices): + yield self.decoder.decode(frame) diff --git a/wsidicom/web/wsidicom_web_source.py b/wsidicom/web/wsidicom_web_source.py index a0d848cf..5a41ab10 100644 --- a/wsidicom/web/wsidicom_web_source.py +++ b/wsidicom/web/wsidicom_web_source.py @@ -80,7 +80,7 @@ def __init__( UID("1.2.840.10008.1.2.4.50") for JPEGBaseline8Bit. """ - + super().__init__() self._level_instances: List[WsiInstance] = [] self._label_instances: List[WsiInstance] = [] self._overview_instances: List[WsiInstance] = [] @@ -114,7 +114,13 @@ def create_instance( detected_transfer_syntaxes_by_image_type[dataset.image_type].add( transfer_syntax ) - image_data = WsiDicomWebImageData(client, dataset, transfer_syntax) + image_data = WsiDicomWebImageData( + client, + dataset, + transfer_syntax, + self._decoded_frame_cache, + self._encoded_frame_cache, + ) return WsiInstance(dataset, image_data) instance_uids = ( diff --git a/wsidicom/wsidicom.py b/wsidicom/wsidicom.py index 6e1f5222..322da25b 100644 --- a/wsidicom/wsidicom.py +++ b/wsidicom/wsidicom.py @@ -815,6 +815,14 @@ def close(self) -> None: if self._source_owned: self._source.close() + def clear_cache(self): + """Clear cache of encoded and decoded tiles.""" + self._source.clear_cache() + + def resize_cache(self, size: int): + """Resize cache of encoded and decoded tiles.""" + self._source.resize_cache(size) + def _validate_collection(self) -> SlideUids: """ Check that no files or instance in collection is duplicate, and, if