New tiling method for inference on large images
jommarin committed Jun 24, 2024
1 parent a8835a1 commit 7eecf62
Showing 2 changed files with 266 additions and 3 deletions.
64 changes: 61 additions & 3 deletions deepliif/models/__init__.py
@@ -369,8 +369,8 @@ def get_net_tiles(n):
    return images


def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
              color_dapi=False, color_marker=False, opt=None):
def inference_old2(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
                   color_dapi=False, color_marker=False, opt=None):
    if not opt:
        opt = get_opt(model_path)
        #print_options(opt)
@@ -489,6 +489,63 @@ def get_net_tiles(n):
        raise Exception(f'inference() not implemented for model {opt.model}')


def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
              eager_mode=False, color_dapi=False, color_marker=False, opt=None):
    if not opt:
        opt = get_opt(model_path)
        #print_options(opt)

    run_fn = run_torchserve if use_torchserve else run_dask

    if opt.model == 'SDG':
        # SDG could have multiple input images/modalities, hence the input could be a rectangle.
        # We split the input to get each modality image, then create tiles for each set of input images.
        w, h = int(img.width / opt.input_no), img.height
        orig = [img.crop((w * i, 0, w * (i + 1), h)) for i in range(opt.input_no)]
    else:
        # Otherwise expect a single input image, which is used directly.
        orig = img
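
    # e.g. (illustrative sizes, not from the commit): with opt.input_no == 3, a
    # 1536x512 side-by-side SDG input is split into three 512x512 modality images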

    tiler = InferenceTiler(orig, tile_size, overlap_size)
    for tile in tiler:
        tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt))
    results = tiler.results()

    if opt.model == 'DeepLIIF':
        images = {
            'Hema': results['G1'],
            'DAPI': results['G2'],
            'Lap2': results['G3'],
            'Marker': results['G4'],
            'Seg': results['G5'],
        }
        if color_dapi:
            matrix = (       0,        0,        0, 0,
                      299/1000, 587/1000, 114/1000, 0,
                      299/1000, 587/1000, 114/1000, 0)
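            # note (added for clarity): each 12-tuple is a 3x4 per-channel affine map
            # for Image.convert('RGB', matrix); routing the luma weights
            # (0.299, 0.587, 0.114) into only G and B tints the DAPI signal cyan,
            # and into only R and G below tints the Marker signal yellow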
            images['DAPI'] = images['DAPI'].convert('RGB', matrix)
        if color_marker:
            matrix = (299/1000, 587/1000, 114/1000, 0,
                      299/1000, 587/1000, 114/1000, 0,
                             0,        0,        0, 0)
            images['Marker'] = images['Marker'].convert('RGB', matrix)
        return images

    elif opt.model == 'DeepLIIFExt':
        images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
        if opt.seg_gen:
            images.update({f'Seg{i}': results[f'GS_{i}'] for i in range(1, opt.modalities_no + 1)})
        return images

    elif opt.model == 'SDG':
        images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
        return images

    else:
        #raise Exception(f'inference() not implemented for model {opt.model}')
        return results  # return result images with default key names (i.e., net names)


def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='auto', marker_thresh='auto', size_thresh_upper=None):
    if model == 'DeepLIIF':
        resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
@@ -546,7 +603,8 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
    images = inference(
        img,
        tile_size=tile_size,
        #overlap_size=compute_overlap(img_size, tile_size),
        overlap_size=tile_size//16,
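        # note (added for clarity): fixed overlap ratio, e.g. tile_size 512 -> 32 px per side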
        model_path=model_dir,
        eager_mode=eager_mode,
        color_dapi=color_dapi,
205 changes: 205 additions & 0 deletions deepliif/util/__init__.py
@@ -118,6 +118,211 @@ def stitch_tile(img, tile, tile_size, overlap_size, i, j):
    img.paste(tile, (i * tile_size, j * tile_size))


class InferenceTiler:
    """
    Iterable class to tile image(s) and stitch result tiles together.

    To perform inference on a large image, that image will need to be
    tiled into smaller tiles that can be run individually and then
    stitched back together. This class wraps that functionality as an
    iterable object that can accept a single image or a list of images
    if multiple images are taken as input for inference.

    An overlap size can be specified so that neighboring tiles will
    overlap at the edges, helping to reduce seams or other artifacts
    near the edge of a tile. Padding of a solid color around the
    perimeter of the tile is also possible, if needed. The specified
    tile size includes these overlap and pad sizes, so a tile size of
    512 with an overlap size of 32 and a pad size of 16 would have a
    central area of 416 pixels that is stitched into the result image.

    Example Usage
    -------------
    tiler = InferenceTiler(img, 512, 32)
    for tile in tiler:
        result_tiles = infer(tile)
        tiler.stitch(result_tiles)
    images = tiler.results()
    """

    def __init__(self, orig, tile_size, overlap_size=0, pad_size=0, pad_color=(255, 255, 255)):
        """
        Initialize for tiling an image or list of images.

        Parameters
        ----------
        orig : Image | list(Image)
            Original image or list of images to be tiled.
        tile_size : int
            Size (width and height) of the tiles to be generated.
        overlap_size : int [default: 0]
            Amount of overlap on each side of the tile.
        pad_size : int [default: 0]
            Amount of solid color padding around the perimeter of the tile.
        pad_color : tuple(int, int, int) [default: (255, 255, 255)]
            RGB color to use for the padding.
        """

        if tile_size <= 0:
            raise ValueError('InferenceTiler input tile_size must be positive')
        if overlap_size < 0:
            raise ValueError('InferenceTiler input overlap_size must be non-negative')
        if pad_size < 0:
            raise ValueError('InferenceTiler input pad_size must be non-negative')

        self.single_orig = not isinstance(orig, list)
        if self.single_orig:
            orig = [orig]

        for i in range(1, len(orig)):
            if orig[i].size != orig[0].size:
                raise ValueError('InferenceTiler input images do not have the same size.')
        self.orig_width = orig[0].width
        self.orig_height = orig[0].height

        # patch size to extract from input image, which is then padded to tile size
        patch_size = tile_size - (2 * pad_size)

        # make sure width and height are both at least patch_size
        if orig[0].width < patch_size:
            for i in range(len(orig)):
                while orig[i].width < patch_size:
                    mirrored = ImageOps.mirror(orig[i])
                    orig[i] = ImageOps.expand(orig[i], (0, 0, orig[i].width, 0))
                    orig[i].paste(mirrored, (mirrored.width, 0))
                orig[i] = orig[i].crop((0, 0, patch_size, orig[i].height))
        if orig[0].height < patch_size:
            for i in range(len(orig)):
                while orig[i].height < patch_size:
                    flipped = ImageOps.flip(orig[i])
                    orig[i] = ImageOps.expand(orig[i], (0, 0, 0, orig[i].height))
                    orig[i].paste(flipped, (0, flipped.height))
                orig[i] = orig[i].crop((0, 0, orig[i].width, patch_size))
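        # e.g. (illustration): a 300 px wide input with patch_size 480 is mirror-
        # extended to 600 px by the loop above, then cropped back to exactly 480 px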
        self.image_width = orig[0].width
        self.image_height = orig[0].height

        overlap_width = 0 if patch_size >= self.image_width else overlap_size
        overlap_height = 0 if patch_size >= self.image_height else overlap_size
        center_width = patch_size - (2 * overlap_width)
        center_height = patch_size - (2 * overlap_height)
        if center_width <= 0 or center_height <= 0:
            raise ValueError('InferenceTiler combined overlap_size and pad_size are too large')

        self.c0x = pad_size                                # crop offset for left of non-pad content in result tile
        self.c0y = pad_size                                # crop offset for top of non-pad content in result tile
        self.c1x = overlap_width + pad_size                # crop offset for left of center region in result tile
        self.c1y = overlap_height + pad_size               # crop offset for top of center region in result tile
        self.c2x = patch_size - overlap_width + pad_size   # crop offset for right of center region in result tile
        self.c2y = patch_size - overlap_height + pad_size  # crop offset for bottom of center region in result tile
        self.c3x = patch_size + pad_size                   # crop offset for right of non-pad content in result tile
        self.c3y = patch_size + pad_size                   # crop offset for bottom of non-pad content in result tile
        self.p1x = overlap_width                           # paste offset for left of center region w.r.t. (x, y) coord
        self.p1y = overlap_height                          # paste offset for top of center region w.r.t. (x, y) coord
        self.p2x = patch_size - overlap_width              # paste offset for right of center region w.r.t. (x, y) coord
        self.p2y = patch_size - overlap_height             # paste offset for bottom of center region w.r.t. (x, y) coord
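        # worked example (added for clarity): tile_size 512, overlap_size 32,
        # pad_size 16 gives patch_size 480; crop offsets c0 = 16, c1 = 48,
        # c2 = 464, c3 = 496 and paste offsets p1 = 32, p2 = 448 in each dimension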

        self.overlap_width = overlap_width
        self.overlap_height = overlap_height
        self.patch_size = patch_size
        self.center_width = center_width
        self.center_height = center_height

        self.orig = orig
        self.tile_size = tile_size
        self.pad_size = pad_size
        self.pad_color = pad_color
        self.res = {}

    def __iter__(self):
        """
        Generate the tiles as an iterable.

        Tiles are created and iterated over from top left to bottom
        right, going across the rows. The yielded tile(s) match the
        type of the original input when initialized (either a single
        image or a list of images in the same order as initialized).
        The (x, y) coordinate of the current tile is maintained
        internally for use in the stitch function.
        """

        for y in range(0, self.image_height, self.center_height):
            for x in range(0, self.image_width, self.center_width):
                if x + self.patch_size > self.image_width:
                    x = self.image_width - self.patch_size
                if y + self.patch_size > self.image_height:
                    y = self.image_height - self.patch_size
                self.x = x
                self.y = y
                tiles = [im.crop((x, y, x + self.patch_size, y + self.patch_size)) for im in self.orig]
                if self.pad_size != 0:
                    tiles = [ImageOps.expand(t, self.pad_size, self.pad_color) for t in tiles]
                yield tiles[0] if self.single_orig else tiles
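
    # e.g. (illustration): with image_width 1000, patch_size 480, and center_width
    # 416, tiles start at x = 0, 416, and 520; the final start is clamped by the
    # checks above so the 480 px patch stays inside the image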

    def stitch(self, result_tiles):
        """
        Stitch result tiles into the result images.

        The key names for the dictionary of result tiles are used to
        stitch each tile into its corresponding final image in the
        results attribute. If a result image does not exist for a
        result tile key name, it will be created. The result tiles
        are stitched at the location from which the last iterated tile
        was extracted.

        Parameters
        ----------
        result_tiles : dict(str: Image)
            Dictionary of result tiles from the inference.
        """

        for k, tile in result_tiles.items():
            if k not in self.res:
                self.res[k] = Image.new('RGB', (self.image_width, self.image_height))
            if tile.size != (self.tile_size, self.tile_size):
                tile = tile.resize((self.tile_size, self.tile_size))
            self.res[k].paste(tile.crop((self.c1x, self.c1y, self.c2x, self.c2y)), (self.x + self.p1x, self.y + self.p1y))

            # top left corner
            if self.x == 0 and self.y == 0:
                self.res[k].paste(tile.crop((self.c0x, self.c0y, self.c1x, self.c1y)), (self.x, self.y))
            # top row
            if self.y == 0:
                self.res[k].paste(tile.crop((self.c1x, self.c0y, self.c2x, self.c1y)), (self.x + self.p1x, self.y))
            # top right corner
            if self.x == self.image_width - self.patch_size and self.y == 0:
                self.res[k].paste(tile.crop((self.c2x, self.c0y, self.c3x, self.c1y)), (self.x + self.p2x, self.y))
            # left column
            if self.x == 0:
                self.res[k].paste(tile.crop((self.c0x, self.c1y, self.c1x, self.c2y)), (self.x, self.y + self.p1y))
            # right column
            if self.x == self.image_width - self.patch_size:
                self.res[k].paste(tile.crop((self.c2x, self.c1y, self.c3x, self.c2y)), (self.x + self.p2x, self.y + self.p1y))
            # bottom left corner
            if self.x == 0 and self.y == self.image_height - self.patch_size:
                self.res[k].paste(tile.crop((self.c0x, self.c2y, self.c1x, self.c3y)), (self.x, self.y + self.p2y))
            # bottom row
            if self.y == self.image_height - self.patch_size:
                self.res[k].paste(tile.crop((self.c1x, self.c2y, self.c2x, self.c3y)), (self.x + self.p1x, self.y + self.p2y))
            # bottom right corner
            if self.x == self.image_width - self.patch_size and self.y == self.image_height - self.patch_size:
                self.res[k].paste(tile.crop((self.c2x, self.c2y, self.c3x, self.c3y)), (self.x + self.p2x, self.y + self.p2y))

    def results(self):
        """
        Return a dictionary of result images.

        The keys for the result images are the same as those used for
        the result tiles in the stitch function. This function should
        only be called once, since the stitched images will be cropped
        if the original image size was less than the patch size.
        """

        if self.orig_width != self.image_width or self.orig_height != self.image_height:
            return {k: im.crop((0, 0, self.orig_width, self.orig_height)) for k, im in self.res.items()}
        else:
            return {k: im for k, im in self.res.items()}
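
# Hedged usage sketch (editor's illustration, not part of this commit). `run_model`
# is a hypothetical callable mapping one input tile to a dict of result tiles keyed
# by net name, standing in for run_wrapper(...) in deepliif/models/__init__.py:
#
#     tiler = InferenceTiler(img, tile_size=512, overlap_size=32, pad_size=16)
#     for tile in tiler:                 # 512x512 tiles, incl. overlap and padding
#         tiler.stitch(run_model(tile))  # paste each result tile at its (x, y)
#     images = tiler.results()           # {name: stitched Image at original size}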


def calculate_background_mean_value(img):
    img = cv2.fastNlMeansDenoisingColored(np.array(img), None, 10, 10, 7, 21)
    img = np.array(img, dtype=float)
