🔀 Merge changes from #578
- Merge changes from #578, which improve the performance of the main branch.
- This should also help simplify #578.
shaneahmed committed Nov 21, 2024
1 parent 32cae0b commit 7532787
Showing 29 changed files with 173 additions and 142 deletions.
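The recurring change across these files is the replacement of the boolean on_gpu flag with an explicit device string in infer_batch and the CLI. A minimal sketch of the migration; select_device's exact behaviour is an assumption, not taken from this commit:

# Sketch of the API migration. "model" in the comments stands in for any
# tiatoolbox architecture exposing infer_batch.
import numpy as np
import torch

from tiatoolbox.utils.misc import select_device

batch = torch.from_numpy(np.zeros((1, 31, 31, 3), dtype=np.uint8))
device = select_device(on_gpu=False)  # assumed to return "cuda" or "cpu"

# Before this commit: output = model.infer_batch(model, batch, on_gpu=False)
# After this commit:  output = model.infer_batch(model, batch, device=device)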
3 changes: 3 additions & 0 deletions .gitignore
@@ -116,3 +116,6 @@ ENV/
 
 # vim/vi generated
 *.swp
+
+# output zarr generated
+*.zarr
4 changes: 2 additions & 2 deletions tests/models/test_arch_mapde.py
@@ -45,7 +45,7 @@ def test_functionality(remote_sample: Callable) -> None:
     model = _load_mapde(name="mapde-conic")
     patch = model.preproc(patch)
     batch = torch.from_numpy(patch)[None]
-    model = model.to(select_device(on_gpu=ON_GPU))
-    output = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    model = model.to()
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
     output = model.postproc(output[0])
     assert np.all(output[0:2] == [[19, 171], [53, 89]])
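Note that model.to() with no argument is a no-op in PyTorch, so under the new API the device placement presumably happens inside infer_batch via its device argument. For reference, Module.to moves parameters in place and returns the module itself:

# torch.nn.Module.to returns the module, so reassignment is safe.
import torch
from torch import nn

layer = nn.Linear(4, 2)
layer = layer.to(torch.device("cpu"))
assert next(layer.parameters()).device.type == "cpu"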
2 changes: 1 addition & 1 deletion tests/models/test_arch_micronet.py
@@ -39,7 +39,7 @@ def test_functionality(
     model = model.to(map_location)
     pretrained = torch.load(weights_path, map_location=map_location)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    output = model.infer_batch(model, batch, device=map_location)
     output, _ = model.postproc(output[0])
     assert np.max(np.unique(output)) == 46
 
3 changes: 2 additions & 1 deletion tests/models/test_arch_nuclick.py
@@ -10,6 +10,7 @@
 from tiatoolbox.models import NuClick
 from tiatoolbox.models.architecture import fetch_pretrained_weights
 from tiatoolbox.utils import imread
+from tiatoolbox.utils.misc import select_device
 
 ON_GPU = False
 
@@ -53,7 +54,7 @@ def test_functional_nuclick(
     model = NuClick(num_input_channels=5, num_output_channels=1)
     pretrained = torch.load(weights_path, map_location="cpu")
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
     postproc_masks = model.postproc(
         output,
         do_reconstruction=True,
17 changes: 13 additions & 4 deletions tests/models/test_arch_sccnn.py
@@ -5,17 +5,18 @@
 import numpy as np
 import torch
 
-from tiatoolbox import utils
 from tiatoolbox.models import SCCNN
 from tiatoolbox.models.architecture import fetch_pretrained_weights
+from tiatoolbox.utils import env_detection
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
 
 
 def _load_sccnn(name: str) -> torch.nn.Module:
     """Loads SCCNN model with specified weights."""
     model = SCCNN()
     weights_path = fetch_pretrained_weights(name)
-    map_location = utils.misc.select_device(on_gpu=utils.env_detection.has_gpu())
+    map_location = select_device(on_gpu=env_detection.has_gpu())
     pretrained = torch.load(weights_path, map_location=map_location)
     model.load_state_dict(pretrained)
 
@@ -40,11 +41,19 @@ def test_functionality(remote_sample: Callable) -> None:
     )
     batch = torch.from_numpy(patch)[None]
     model = _load_sccnn(name="sccnn-crchisto")
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(
+        model,
+        batch,
+        device=select_device(on_gpu=env_detection.has_gpu()),
+    )
     output = model.postproc(output[0])
     assert np.all(output == [[8, 7]])
 
     model = _load_sccnn(name="sccnn-conic")
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(
+        model,
+        batch,
+        device=select_device(on_gpu=env_detection.has_gpu()),
+    )
     output = model.postproc(output[0])
     assert np.all(output == [[7, 8]])
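These tests now follow the environment instead of hard-coding on_gpu=False. A hedged sketch of the assumed has_gpu helper; the real tiatoolbox.utils.env_detection implementation may differ:

# Assumed behaviour of env_detection.has_gpu: a thin wrapper over torch.
import torch

def has_gpu() -> bool:
    """Report whether a CUDA-capable device is visible to torch."""
    return torch.cuda.is_available()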
5 changes: 3 additions & 2 deletions tests/models/test_arch_unet.py
@@ -9,6 +9,7 @@
 
 from tiatoolbox.models.architecture import fetch_pretrained_weights
 from tiatoolbox.models.architecture.unet import UNetModel
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
 
 ON_GPU = False
@@ -48,7 +49,7 @@ def test_functional_unet(remote_sample: Callable) -> None:
     model = UNetModel(3, 2, encoder="resnet50", decoder_block=[3])
     pretrained = torch.load(pretrained_weights, map_location="cpu")
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
     _ = output[0]
 
     # run untrained network to test for architecture
@@ -60,4 +61,4 @@ def test_functional_unet(remote_sample: Callable) -> None:
         encoder_levels=[32, 64],
         skip_type="concat",
     )
-    _ = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    _ = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
11 changes: 6 additions & 5 deletions tests/models/test_arch_vanilla.py
@@ -5,10 +5,11 @@
 import torch
 
 from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel
-from tiatoolbox.utils.misc import model_to
+from tiatoolbox.models.models_abc import model_to
 
 ON_GPU = False
 RNG = np.random.default_rng()  # Numpy Random Generator
+device = "cuda" if ON_GPU else "cpu"
 
 
 def test_functional() -> None:
@@ -43,8 +44,8 @@ def test_functional() -> None:
     try:
         for backbone in backbones:
            model = CNNModel(backbone, num_classes=1)
-            model_ = model_to(on_gpu=ON_GPU, model=model)
-            model.infer_batch(model_, samples, on_gpu=ON_GPU)
+            model_ = model_to(device=device, model=model)
+            model.infer_batch(model_, samples, device=device)
     except ValueError as exc:
         msg = f"Model {backbone} failed."
         raise AssertionError(msg) from exc
@@ -70,8 +71,8 @@ def test_timm_functional() -> None:
     try:
         for backbone in backbones:
            model = TimmModel(backbone=backbone, num_classes=1, pretrained=False)
-            model_ = model_to(on_gpu=ON_GPU, model=model)
-            model.infer_batch(model_, samples, on_gpu=ON_GPU)
+            model_ = model_to(device=device, model=model)
+            model.infer_batch(model_, samples, device=device)
     except ValueError as exc:
         msg = f"Model {backbone} failed."
         raise AssertionError(msg) from exc
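Here model_to moves from tiatoolbox.utils.misc to tiatoolbox.models.models_abc and now takes a device string instead of on_gpu. A minimal sketch of the assumed new signature; the real function may also handle details such as nn.DataParallel wrapping:

# Assumed shape of models_abc.model_to after this change.
import torch
from torch import nn

def model_to(model: nn.Module, device: str = "cpu") -> nn.Module:
    """Place the model on the requested device and return it."""
    return model.to(torch.device(device))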
9 changes: 5 additions & 4 deletions tests/models/test_hovernet.py
@@ -14,6 +14,7 @@
     ResidualBlock,
     TFSamepaddingLayer,
 )
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
 
 
@@ -34,7 +35,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernet_fast-pannuke")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     output = [v[0] for v in output]
     output = model.postproc(output)
     assert len(output[1]) > 0, "Must have some nuclei."
@@ -51,7 +52,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernet_fast-monusac")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     output = [v[0] for v in output]
     output = model.postproc(output)
     assert len(output[1]) > 0, "Must have some nuclei."
@@ -68,7 +69,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernet_original-consep")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     output = [v[0] for v in output]
     output = model.postproc(output)
     assert len(output[1]) > 0, "Must have some nuclei."
@@ -85,7 +86,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernet_original-kumar")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     output = [v[0] for v in output]
     output = model.postproc(output)
     assert len(output[1]) > 0, "Must have some nuclei."
3 changes: 2 additions & 1 deletion tests/models/test_hovernetplus.py
@@ -7,6 +7,7 @@
 from tiatoolbox.models import HoVerNetPlus
 from tiatoolbox.models.architecture import fetch_pretrained_weights
 from tiatoolbox.utils import imread
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.utils.transforms import imresize
 
 
@@ -28,7 +29,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernetplus-oed")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches."
     output = [v[0] for v in output]
     output = model.postproc(output)
11 changes: 3 additions & 8 deletions tests/test_annotation_stores.py
@@ -53,14 +53,6 @@
 FILLED_LEN = 2 * (GRID_SIZE[0] * GRID_SIZE[1])
 RNG = np.random.default_rng(0)  # Numpy Random Generator
 
-# ----------------------------------------------------------------------
-# Resets
-# ----------------------------------------------------------------------
-
-# Reset filters in logger.
-for filter_ in logger.filters:
-    logger.removeFilter(filter_)
-
 # ----------------------------------------------------------------------
 # Helper Functions
 # ----------------------------------------------------------------------
@@ -546,6 +538,9 @@ def test_sqlite_store_compile_options_missing_math(
     caplog: pytest.LogCaptureFixture,
 ) -> None:
     """Test that a warning is shown if the sqlite math module is missing."""
+    # Reset filters in logger.
+    for filter_ in logger.filters[:]:
+        logger.removeFilter(filter_)
     monkeypatch.setattr(
         SQLiteStore,
         "compile_options",
1 change: 1 addition & 0 deletions tests/test_annotation_tilerendering.py
@@ -462,6 +462,7 @@ def test_function_mapper(fill_store: Callable, tmp_path: Path) -> None:
     _, store = fill_store(SQLiteStore, tmp_path / "test.db")
 
     def color_fn(props: dict[str, str]) -> tuple[int, int, int]:
+        """Tests Red for cells, otherwise green."""
         # simple test function that returns red for cells, otherwise green.
         if props["type"] == "cell":
             return 1, 0, 0
2 changes: 1 addition & 1 deletion tests/test_init.py
@@ -114,7 +114,7 @@ def test_duplicate_filter(caplog: pytest.LogCaptureFixture) -> None:
     logger.addFilter(duplicate_filter)
 
     # Reset filters in logger.
-    for filter_ in logger.filters:
+    for filter_ in logger.filters[:]:
         logger.removeFilter(filter_)
 
     for _ in range(2):
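The added [:] matters here: removing items from a list while iterating over it skips elements, so some filters would survive the reset. A small illustration of the failure mode the slice copy avoids:

# Removing from a list while iterating it silently skips entries.
filters = ["a", "b", "c"]
for f in filters:  # no copy: the iterator index shifts as items are removed
    filters.remove(f)
assert filters == ["b"]  # "b" survives

filters = ["a", "b", "c"]
for f in filters[:]:  # iterate over a shallow copy instead
    filters.remove(f)
assert filters == []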
21 changes: 2 additions & 19 deletions tests/test_utils.py
@@ -1336,24 +1336,6 @@ def test_select_device() -> None:
     assert device == "cpu"
 
 
-def test_model_to() -> None:
-    """Test for placing model on device."""
-    import torchvision.models as torch_models
-    from torch import nn
-
-    # Test on GPU
-    # no GPU on Travis so this will crash
-    if not utils.env_detection.has_gpu():
-        model = torch_models.resnet18()
-        with pytest.raises((AssertionError, RuntimeError)):
-            _ = misc.model_to(on_gpu=True, model=model)
-
-    # Test on CPU
-    model = torch_models.resnet18()
-    model = misc.model_to(on_gpu=False, model=model)
-    assert isinstance(model, nn.Module)
-
-
 def test_save_as_json(tmp_path: Path) -> None:
     """Test save data to json."""
     # This should be broken up into separate tests!
@@ -1666,6 +1648,7 @@ def test_patch_pred_store() -> None:
     """Test patch_pred_store."""
     # Define a mock patch_output
     patch_output = {
+        "probabilities": [(0.99, 0.01), (0.01, 0.99), (0.99, 0.01)],
         "predictions": [1, 0, 1],
         "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
         "other": "other",
@@ -1700,7 +1683,7 @@ def test_patch_pred_store_cdict() -> None:
     class_dict = {0: "class0", 1: "class1"}
     store = misc.dict_to_store(patch_output, (1.0, 1.0), class_dict=class_dict)
 
-    # Check that its an SQLiteStore containing the expected annotations
+    # Check that it is an SQLiteStore containing the expected annotations
     assert isinstance(store, SQLiteStore)
     assert len(store) == 3
     for annotation in store.values():
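For context, a sketch of how the mock output above flows into a store; the dict_to_store signature is inferred from the calls visible in this diff, not from documentation:

# Mirrors the test above; assumes misc.dict_to_store(patch_output,
# scale_factor, class_dict=...) as shown in the surrounding code.
from tiatoolbox.utils import misc

patch_output = {
    "probabilities": [(0.99, 0.01), (0.01, 0.99), (0.99, 0.01)],
    "predictions": [1, 0, 1],
    "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
}
store = misc.dict_to_store(patch_output, (1.0, 1.0))
assert len(store) == 3  # one annotation per patch, per the tests above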
1 change: 0 additions & 1 deletion tests/test_wsimeta.py
@@ -8,7 +8,6 @@
 from tiatoolbox.wsicore import WSIMeta, wsimeta, wsireader
 
 
-# noinspection PyTypeChecker
 def test_wsimeta_init_fail() -> None:
     """Test incorrect init for WSIMeta raises TypeError."""
     with pytest.raises(TypeError):
16 changes: 15 additions & 1 deletion tiatoolbox/annotation/storage.py
@@ -2556,7 +2556,21 @@ def _unpack_wkb(
         cx: float,
         cy: float,
     ) -> bytes:
-        """Unpack WKB data."""
+        """Return the geometry as bytes using WKB.
+
+        Args:
+            data (bytes or str):
+                The WKB/WKT data to be unpacked.
+            cx (float):
+                The X coordinate of the centroid/representative point.
+            cy (float):
+                The Y coordinate of the centroid/representative point.
+
+        Returns:
+            bytes:
+                The geometry as bytes.
+
+        """
         return (
             self._decompress_data(data)
             if data
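The expanded docstring documents WKB (well-known binary), the compact geometry serialization the annotation store keeps in its database. As an illustration only, a WKB round trip with shapely (already a tiatoolbox dependency; this snippet is not part of the commit):

# WKB round trip: geometry -> bytes -> geometry.
from shapely import wkb
from shapely.geometry import Point

geom = Point(12.5, 7.25)
data = wkb.dumps(geom)  # bytes, suitable for storage in a BLOB column
restored = wkb.loads(data)
assert restored.equals(geom)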
12 changes: 12 additions & 0 deletions tiatoolbox/cli/common.py
@@ -234,6 +234,18 @@ def cli_pretrained_weights(
     )
 
 
+def cli_device(
+    usage_help: str = "Select the device (cpu/cuda/mps) to use for inference.",
+    default: str = "cpu",
+) -> Callable:
+    """Enables --device option for cli."""
+    return click.option(
+        "--device",
+        help=add_default_to_usage_help(usage_help, default),
+        default=default,
+    )
+
+
 def cli_return_probabilities(
     usage_help: str = "Whether to return raw model probabilities.",
     *,
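A sketch of how the new option factory would attach to a command; the command below is hypothetical, and only the decorator pattern comes from the diff:

# Hypothetical command using cli_device; not part of the commit.
import click

from tiatoolbox.cli.common import cli_device

@click.command()
@cli_device(default="cpu")
def dummy_infer(device: str) -> None:
    """Echo the selected device."""
    click.echo(f"Running on {device}")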
8 changes: 4 additions & 4 deletions tiatoolbox/cli/patch_predictor.py
@@ -6,13 +6,13 @@
 
 from tiatoolbox.cli.common import (
     cli_batch_size,
+    cli_device,
     cli_file_type,
     cli_img_input,
     cli_masks,
     cli_merge_predictions,
     cli_mode,
     cli_num_loader_workers,
-    cli_on_gpu,
     cli_output_path,
     cli_pretrained_model,
     cli_pretrained_weights,
@@ -45,7 +45,7 @@
 @cli_return_probabilities(default=False)
 @cli_merge_predictions(default=True)
 @cli_return_labels(default=True)
-@cli_on_gpu(default=False)
+@cli_device(default="cpu")
 @cli_batch_size(default=1)
 @cli_resolution(default=0.5)
 @cli_units(default="mpp")
@@ -64,11 +64,11 @@ def patch_predictor(
     resolution: float,
     units: str,
     num_loader_workers: int,
+    device: str,
     *,
     return_probabilities: bool,
     return_labels: bool,
     merge_predictions: bool,
-    on_gpu: bool,
     verbose: bool,
 ) -> None:
     """Process an image/directory of input images with a patch classification CNN."""
@@ -100,7 +100,7 @@ def patch_predictor(
         return_labels=return_labels,
         resolution=resolution,
         units=units,
-        on_gpu=on_gpu,
+        device=device,
         save_dir=output_path,
         save_output=True,
     )
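One plausible end-to-end check of the new flag through click's test runner; the option spellings below are inferred from the cli_* helper names and should be treated as assumptions:

# Hypothetical smoke test of --device; flag names assumed from the
# cli_img_input / cli_output_path / cli_device helpers.
from click.testing import CliRunner
from tiatoolbox.cli.patch_predictor import patch_predictor

runner = CliRunner()
result = runner.invoke(
    patch_predictor,
    ["--img-input", "sample.svs", "--output-path", "out", "--device", "cpu"],
)
print(result.exit_code, result.output)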