-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding support for Zarr stitching/fusion #41
Open
jwong-nd
wants to merge
7
commits into
google-research:main
Choose a base branch
from
jwong-nd:feat-zarr-processor
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
a760a5b
Adding support for Zarr stitching/fusion
jwong-nd 05e33da
revert core sofima changes, export feat into zarr/
jwong-nd 9fe84c7
env setup, temporary files
jwong-nd 09fad0b
tested changes, removed tmp comments/files
jwong-nd cabcaf0
add zarr package to setup.cfg
jwong-nd 4c03e41
remove relative import
jwong-nd 96f7eda
resolve field error
jwong-nd File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
__pycache__ | ||
dist | ||
sofima.egg-info | ||
_version.py | ||
*.npz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import functools as ft | ||
import gc | ||
import jax | ||
import numpy as np | ||
import tensorstore as ts | ||
|
||
from sofima import flow_field | ||
|
||
|
||
# Query patch geometry: half-extents of the patch cut from the moving tile.
# Presumably in voxels at the loaded pyramid level — TODO confirm units.
QUERY_R_ORTHO = 100
QUERY_OVERLAP_OFFSET = 0  # Overlap = 'starting line' in neighboring tile
QUERY_R_OVERLAP = 100  # Half-extent of the query patch along the overlap axis.

# Search-space geometry on the fixed tile.
SEARCH_OVERLAP = 300  # Boundary - overlap = 'starting line' in search tile
SEARCH_R_ORTHO = 100  # Half-extent orthogonal to the overlap direction.
|
||
|
||
@jax.jit  # plain jax.jit; ft.partial wrapper added nothing
def _estimate_relative_offset_zyx(base,
                                  kernel
                                  ) -> list[float]:
    """Estimate the (z, y, x) offset of `kernel` relative to `base`.

    Cross-correlates the two volumes and returns the displacement of the
    strongest correlation peak from the correlation-image center.

    Args:
        base: Fixed search volume (3D array, zyx — TODO confirm axis order).
        kernel: Moving query volume, scanned against `base`.

    Returns:
        [z, y, x] offset of the strongest peak.
    """
    # Masked cross-correlation between the two volumes.
    xc = flow_field.masked_xcorr(base, kernel, use_jax=True, dim=3)
    xc = xc.astype(np.float32)
    xc = xc[None, ...]  # add batch dimension expected by _batched_peaks

    # Locate the strongest peak, passing the correlation-image center.
    r = flow_field._batched_peaks(
        xc,
        ((xc.shape[1] + 1) // 2, (xc.shape[2] + 1) // 2, xc.shape[3] // 2),
        min_distance=2,
        threshold_rel=0.5)

    # The first peak's first three entries are the offset in xyz order;
    # reverse them to zyx for the caller.
    relative_offset_xyz = r[0][0:3]
    return [relative_offset_xyz[2], relative_offset_xyz[1], relative_offset_xyz[0]]
|
||
|
||
def _estimate_h_offset_zyx(left_tile: ts.TensorStore,
                           right_tile: ts.TensorStore
                           ) -> np.ndarray:
    """Estimate the (z, y, x) offset between horizontally adjacent tiles.

    Cuts a fixed search region from the trailing (overlapping) edge of
    `left_tile` and a query patch from the leading edge of `right_tile`,
    cross-correlates them, and adds the result to the nominal offset
    implied by the configured overlap.

    Tiles are indexed (x, y, z); the trailing `.T` flips the read patches
    to zyx before correlation — TODO confirm both tiles share
    `left_tile.shape` (upstream cropping appears to guarantee this).

    Returns:
        np.ndarray with the [z, y, x] offset of `right_tile` relative to
        `left_tile`. (Note: the previous `tuple[list[float], float]`
        annotation did not match the actual return value.)
    """
    tile_size_xyz = left_tile.shape
    mz = tile_size_xyz[2] // 2  # tile center along z
    my = tile_size_xyz[1] // 2  # tile center along y

    # Search space, fixed: last SEARCH_OVERLAP planes of the left tile.
    left = left_tile[tile_size_xyz[0]-SEARCH_OVERLAP:,
                     my-SEARCH_R_ORTHO:my+SEARCH_R_ORTHO,
                     mz-SEARCH_R_ORTHO:mz+SEARCH_R_ORTHO].read().result().T

    # Query patch, scanned against the search space.
    right = right_tile[QUERY_OVERLAP_OFFSET:QUERY_OVERLAP_OFFSET + QUERY_R_OVERLAP*2,
                       my-QUERY_R_ORTHO:my+QUERY_R_ORTHO,
                       mz-QUERY_R_ORTHO:mz+QUERY_R_ORTHO].read().result().T

    # Nominal offset: where the query patch would land if the tiles abutted
    # exactly at the configured overlap.
    start_zyx = np.array(left.shape) // 2 - np.array(right.shape) // 2
    pc_init_zyx = np.array([0, 0, tile_size_xyz[0] - SEARCH_OVERLAP + start_zyx[2]])
    pc_zyx = np.array(_estimate_relative_offset_zyx(left, right))

    return pc_init_zyx + pc_zyx
|
||
|
||
def _estimate_v_offset_zyx(top_tile: ts.TensorStore,
                           bot_tile: ts.TensorStore,
                           ) -> np.ndarray:
    """Estimate the (z, y, x) offset between vertically adjacent tiles.

    Mirror of `_estimate_h_offset_zyx` along the y axis: the search region
    comes from the bottom (overlapping) edge of `top_tile` and the query
    patch from the top edge of `bot_tile`.

    Returns:
        np.ndarray with the [z, y, x] offset of `bot_tile` relative to
        `top_tile`. (Note: the previous `tuple[list[float], float]`
        annotation did not match the actual return value.)
    """
    tile_size_xyz = top_tile.shape
    mz = tile_size_xyz[2] // 2  # tile center along z
    mx = tile_size_xyz[0] // 2  # tile center along x

    # Search space, fixed: last SEARCH_OVERLAP rows of the top tile.
    top = top_tile[mx-SEARCH_R_ORTHO:mx+SEARCH_R_ORTHO,
                   tile_size_xyz[1]-SEARCH_OVERLAP:,
                   mz-SEARCH_R_ORTHO:mz+SEARCH_R_ORTHO].read().result().T
    # Query patch, scanned against the search space.
    bot = bot_tile[mx-QUERY_R_ORTHO:mx+QUERY_R_ORTHO,
                   0:QUERY_R_OVERLAP*2,
                   mz-QUERY_R_ORTHO:mz+QUERY_R_ORTHO].read().result().T

    # Nominal offset implied by the configured overlap.
    start_zyx = np.array(top.shape) // 2 - np.array(bot.shape) // 2
    pc_init_zyx = np.array([0, tile_size_xyz[1] - SEARCH_OVERLAP + start_zyx[1], 0])
    pc_zyx = np.array(_estimate_relative_offset_zyx(top, bot))

    return pc_init_zyx + pc_zyx
|
||
|
||
def compute_coarse_offsets(tile_layout: np.ndarray,
                           tile_volumes: list[ts.TensorStore]
                           ) -> tuple[np.ndarray, np.ndarray]:
    """Estimate coarse offsets between every pair of neighboring tiles.

    Args:
        tile_layout: 2D grid of tile indices giving each tile's position.
        tile_volumes: Tensorstores, indexed by the values in `tile_layout`.

    Returns:
        (conn_x, conn_y): arrays of shape (3, 1, rows, cols) holding the
        per-pair offset components; entries with no neighbor remain NaN.
        Sofima uses the cartesian convention for these containers.
    """
    layout_y, layout_x = tile_layout.shape

    # Offset containers; axis 0 carries the three offset components.
    conn_x = np.full((3, 1, layout_y, layout_x), np.nan)
    conn_y = np.full((3, 1, layout_y, layout_x), np.nan)

    # Horizontally adjacent pairs within each row.
    for y, x in np.ndindex(layout_y, layout_x - 1):
        left_id, right_id = tile_layout[y, x], tile_layout[y, x + 1]
        left_tile, right_tile = tile_volumes[left_id], tile_volumes[right_id]

        conn_x[:, 0, y, x] = _estimate_h_offset_zyx(left_tile, right_tile)
        gc.collect()  # release the large patches read by the estimator

        print(f'Left Id: {left_id}, Right Id: {right_id}')
        print(f'Left: ({y}, {x}), Right: ({y}, {x + 1})', conn_x[:, 0, y, x])

    # Vertically adjacent pairs within each column (loop order swapped).
    for x, y in np.ndindex(layout_x, layout_y - 1):
        top_id, bot_id = tile_layout[y, x], tile_layout[y + 1, x]
        top_tile, bot_tile = tile_volumes[top_id], tile_volumes[bot_id]

        conn_y[:, 0, y, x] = _estimate_v_offset_zyx(top_tile, bot_tile)
        gc.collect()

        print(f'Top Id: {top_id}, Bottom Id: {bot_id}')
        print(f'Top: ({y}, {x}), Bot: ({y + 1}, {x})', conn_y[:, 0, y, x])

    return conn_x, conn_y
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/bin/bash
# One-shot environment setup for SOFIMA Zarr stitching.
# Creates a conda env with Python 3.11, then installs SOFIMA from GitHub,
# CUDA 11 JAX wheels, and tensorstore into it.
conda create --name py311 -c conda-forge python=3.11 -y
# Install SOFIMA straight from the main branch (not a pinned release).
conda run -n py311 pip install git+https://github.com/google-research/sofima
# GPU JAX: cuda11_pip wheels come from Google's jax-releases index.
conda run -n py311 pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
conda run -n py311 pip install tensorstore
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please completely revert the changes to this file? |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# coding=utf-8 | ||
# Copyright 2022 The Google Research Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
from dataclasses import dataclass | ||
from enum import Enum | ||
import numpy as np | ||
import tensorstore as ts | ||
|
||
from sofima import stitch_elastic | ||
|
||
class CloudStorage(Enum):
    """Supported cloud storage backends for Zarr datasets."""

    S3 = 1   # Amazon S3 (accessed over anonymous HTTP)
    GCS = 2  # Google Cloud Storage
|
||
|
||
@dataclass
class ZarrDataset:
    """
    Parameters for locating Zarr dataset living on the cloud.
    Args:
      cloud_storage: CloudStorage option
      bucket: Name of bucket
      dataset_path: Path to directory containing zarr files within bucket
      tile_names: List of zarr tiles to include in dataset.
          Order of tile_names defines an index that
          is expected to be used in tile_layout.
      tile_layout: 2D array of indices that defines relative position of tiles.
      downsample_exp: Level in image pyramid with each level
          downsampling the original resolution by 2**downsample_exp.
    """

    # Storage backend the tiles live on (S3 or GCS).
    cloud_storage: CloudStorage
    # Bucket name, without any scheme prefix.
    bucket: str
    # Path within the bucket to the directory holding the zarr tiles.
    dataset_path: str
    # Tile names; list order defines the indices used in `tile_layout`.
    tile_names: list[str]
    # 2D array of indices into `tile_names` giving each tile's grid position.
    tile_layout: np.ndarray
    # Pyramid level to load; level k downsamples resolution by 2**k.
    downsample_exp: int
|
||
|
||
def open_zarr_gcs(bucket: str, path: str) -> ts.TensorStore:
    """Open an existing Zarr array stored in a GCS bucket."""
    spec = {
        'driver': 'zarr',
        'kvstore': {
            'driver': 'gcs',
            'bucket': bucket,
        },
        'path': path,
    }
    return ts.open(spec).result()
|
||
|
||
def open_zarr_s3(bucket: str, path: str,
                 region: str = 'us-west-2') -> ts.TensorStore:
    """Open an existing Zarr array stored in an S3 bucket.

    Args:
        bucket: Name of the S3 bucket.
        path: Path to the Zarr array inside the bucket.
        region: AWS region of the bucket. Defaults to 'us-west-2', which
            was previously hard-coded, so existing callers are unaffected.

    Returns:
        An open tensorstore handle.
    """
    # Plain HTTP kvstore, i.e. anonymous access — presumably the bucket is
    # publicly readable; TODO confirm for credentialed buckets.
    return ts.open({
        'driver': 'zarr',
        'kvstore': {
            'driver': 'http',
            'base_url': f'https://{bucket}.s3.{region}.amazonaws.com/{path}',
        },
    }).result()
|
||
|
||
def load_zarr_data(params: ZarrDataset
                   ) -> tuple[list[ts.TensorStore], stitch_elastic.ShapeXYZ]:
    """Open every tile of a cloud Zarr dataset and crop to a common size.

    Returns the tensorstores in the same order as ``params.tile_names``,
    together with the common (x, y, z) tile size. Each volume is cropped
    at the origin down to the smallest tile in the set.
    """

    def load_zarr(bucket: str, tile_location: str) -> ts.TensorStore:
        # Dispatch on the backend recorded in the dataset parameters.
        if params.cloud_storage == CloudStorage.S3:
            return open_zarr_s3(bucket, tile_location)
        return open_zarr_gcs(bucket, tile_location)  # cloud == 'gcs'

    tile_volumes: list[ts.TensorStore] = []
    min_x = min_y = min_z = np.inf
    for t_name in params.tile_names:
        tile_location = f"{params.dataset_path}/{t_name}/{params.downsample_exp}"
        tile = load_zarr(params.bucket, tile_location)
        tile_volumes.append(tile)

        # Tiles are stored tczyx; track the smallest spatial extent seen.
        _, _, tz, ty, tx = tile.shape
        min_x = int(np.minimum(min_x, tx))
        min_y = int(np.minimum(min_y, ty))
        min_z = int(np.minimum(min_z, tz))
    tile_size_xyz = min_x, min_y, min_z

    # Crop every volume at the origin to the common minimum size.
    tile_volumes = [vol[:, :, :min_z, :min_y, :min_x] for vol in tile_volumes]

    return tile_volumes, tile_size_xyz
|
||
|
||
def write_zarr(bucket: str, shape: list, path: str):
    """Create a uint16 Zarr v2 array in a GCS bucket, overwriting any
    existing array at the same path.

    Args:
        bucket: Name of gcs cloud storage bucket
        shape: 5D vector in tczyx order, ex: [1, 1, 3551, 576, 576]
        path: Output path inside bucket

    Returns:
        An open, writable tensorstore handle to the new array.
    """
    # Zarr v2 metadata: zstd-via-blosc compression, '/'-separated chunk
    # keys, little-endian uint16 voxels.
    metadata = {
        'chunks': [1, 1, 128, 256, 256],
        'compressor': {
            'blocksize': 0,
            'clevel': 1,
            'cname': 'zstd',
            'id': 'blosc',
            'shuffle': 1,
        },
        'dimension_separator': '/',
        'dtype': '<u2',
        'fill_value': 0,
        'filters': None,
        'order': 'C',
        'shape': shape,
        'zarr_format': 2,
    }
    spec = {
        'driver': 'zarr',
        'dtype': 'uint16',
        'kvstore': {
            'driver': 'gcs',
            'bucket': bucket,
        },
        'create': True,
        'delete_existing': True,  # clobbers any existing array at `path`
        'path': path,
        'metadata': metadata,
    }
    return ts.open(spec).result()
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please completely revert the changes to this file?