
improve data loader performance #565

Closed · wants to merge 4 commits

Conversation
@giovp (Member) commented May 24, 2024

So I've been wanting to take another look at this for a long time. I used https://github.com/benfred/py-spy with the speedscope output format; see the screenshot below.
[screenshot: py-spy speedscope profile of the dataloader run]

I've been doing this on the xenium_rep_1 dataset from the paper, using the following code (adapted from @LucaMarconato's code):

import json

import numpy as np
import pandas as pd
import torchvision.transforms.v2 as T
from spatialdata.dataloader.datasets import ImageTilesDataset
from spatialdata.transformations import Scale, get_transformation
from spatialdata.transformations import Sequence as SequenceTransformation
from torch.utils.data import DataLoader
from tqdm import tqdm

from pathlib import Path
import spatialdata as sd

xeniumrep1 = Path(
  "/path/to/xenium_rep1_data_aligned.zarr"
)
sdata1 = sd.read_zarr(xeniumrep1)

visium = Path(
    "/path/to/visium_data_aligned.zarr"
)
sdata3 = sd.read_zarr(visium)

TILE_SCALE = 10.0
REGION = "xeniumrep1"
sdata = sdata1
sdata.images["hne"] = sdata3.images["CytAssist_FFPE_Human_Breast_Cancer_full_image"]

def get_ds(sdata: sd.SpatialData):
    img_size = 224

    transform_tv = T.Compose(
        [
            T.ToImage(),
            T.Resize((img_size, img_size), antialias=True, interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
        ]
    )

    def transform(output):
        image, anno = output
        instance_id, celltype = anno[:, 0].squeeze(), anno[:, 1].squeeze()
        image = transform_tv(image.data.transpose(1, 2, 0).compute(scheduler="single-threaded"))
        out = {"img": image, "instance_id": instance_id.tolist(), "celltype": celltype.tolist()}
        return out

    mu = sdata.shapes["cell_circles"]["radius"].mean()
    std = sdata.shapes["cell_circles"]["radius"].std()
    # use a large radius to cover most of the cells
    large_radius = mu + 2 * std
    neighbors_context = large_radius
    sdata.shapes["cell_circles"]["radius"] = neighbors_context
    instance_key = sdata.tables["table"].uns["spatialdata_attrs"]["instance_key"]

    ds = ImageTilesDataset(
        sdata=sdata,
        regions_to_images={"cell_circles": "hne"},
        regions_to_coordinate_systems={"cell_circles": "aligned"},
        return_annotations=[instance_key, "celltype_major"],
        tile_scale=TILE_SCALE,
        transform=transform,
        table_name="table",
    )
    return ds

ds = get_ds(sdata)
dl = DataLoader(
    ds,
    batch_size=256,
    num_workers=0,
    shuffle=False,
)

This made me realize that, if we want to return the array, there is an unnecessary step of instantiating the SpatialImage|MultiscaleSpatialImage: the dask array could simply be returned instead. This roughly halved the fetch step (across 6 iterations) from ~43s to ~23s total, see below.
[screenshot: speedscope profile showing the fetch step reduced from ~43s to ~23s]
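Concretely, the idea is along these lines (a minimal sketch, not the exact diff; names are illustrative):

def _get_tile(cropped):
    # `cropped` is the lazy image returned by the bounding box query.
    # Before: the cropped data was parsed back into a SpatialImage /
    # MultiscaleSpatialImage model, adding per-tile validation overhead.
    # After: skip the model instantiation; computation stays deferred.
    return cropped.data  # a dask.array.Array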

I think the fetch step is ultimately what we want to improve, as it's the one that streams the tiles from the zarr store to the GPU. The two main blocks are now the transform call and the compute call. The transform call is visualized under compute, but it's effectively the wrapper call where all the DataArray.isel calls happen: that's where the crops are defined, transformed and set, before the computation is actually triggered with compute.
[screenshot: speedscope zoom on the transform and compute blocks of the fetch step]
I wonder what the next step to chase performance gains could be: I think one option would be to basically "prepare" the transformation upfront on the full array, and then trigger it only at tile creation inside compute (whereas now, transformation and tile creation are done jointly for each tile). I think this would require significant refactoring though, so I wonder if it makes sense at all, and whether anyone has other ideas to explore @scverse/spatialdata
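To make this option concrete, a sketch of "prepare once, trigger per tile", assuming the transformation is an affine handled by dask_image.ndinterp.affine_transform (the core function @LucaMarconato mentions below); the function names here are hypothetical:

import dask.array as da
from dask_image.ndinterp import affine_transform

def prepare_transformed(full_image: da.Array, matrix, offset) -> da.Array:
    # done once at __init__: build the lazy task graph for the whole
    # transformed image; nothing is computed yet
    return affine_transform(full_image, matrix=matrix, offset=offset)

def fetch_tile(transformed: da.Array, ymin: int, ymax: int, xmin: int, xmax: int):
    # done per __getitem__: slicing the prepared array computes only the
    # chunks that overlap the tile
    return transformed[..., ymin:ymax, xmin:xmax].compute()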

codecov bot commented May 24, 2024

Codecov Report

Attention: Patch coverage is 71.42857%, with 2 lines in your changes missing coverage. Please review.

Project coverage is 92.52%. Comparing base (8d902d4) to head (7adc03f).
Report is 8 commits behind head on main.

Current head 7adc03f differs from pull request most recent head 7feb03b

Please upload reports for the commit 7feb03b to get more accurate results.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #565      +/-   ##
==========================================
- Coverage   92.53%   92.52%   -0.02%     
==========================================
  Files          43       42       -1     
  Lines        6003     6008       +5     
==========================================
+ Hits         5555     5559       +4     
- Misses        448      449       +1     
Files                                          Coverage Δ
src/spatialdata/dataloader/datasets.py         90.73% <100.00%> (+0.04%) ⬆️
src/spatialdata/_core/query/spatial_query.py   94.67% <50.00%> (-0.51%) ⬇️

... and 6 files with indirect coverage changes

@LucaMarconato (Member)

Super cool analysis! I'll also try it out (which commands did you use to open py-spy? Or did you set it up to be integrated with your IDE?)

If most of the time is spent outside dask_image.ndinterp.affine_transform() (the core function used in transform()), then I think that preparing everything upfront and calling affine_transform() at the end would be a good approach.

But my bet (I need to check by running the profiler) is that the problem is that we load the same chunks multiple times. I think that using .persist() to automatically cache some Dask chunks, and ordering the cells so that we randomize the chunks first and then the cells inside each chunk, would lead to performance improvements.

This second approach has the advantage that it involves only the dataloader class and does not require changes in the transformation code.
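A rough sketch of this chunk-aware ordering (the chunk_of mapping from cell to dask chunk index is hypothetical):

import numpy as np

rng = np.random.default_rng(0)

def chunk_aware_order(cell_ids, chunk_of):
    # group cells by the dask chunk that contains their tile
    by_chunk = {}
    for cid in cell_ids:
        by_chunk.setdefault(chunk_of[cid], []).append(cid)
    chunks = list(by_chunk)
    rng.shuffle(chunks)  # randomize the chunks first...
    order = []
    for chunk in chunks:
        cells = by_chunk[chunk]
        rng.shuffle(cells)  # ...then the cells inside each chunk
        order.extend(cells)
    return order

Combined with image.persist(), consecutive tiles would then mostly hit chunks that are already in memory.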

@LucaMarconato (Member)

I reviewed the code, looks good to me. We could merge this already, or first explore the .persist() approach above in this PR.

@@ -81,6 +81,10 @@ class ImageTilesDataset(Dataset):
system; this back-transforms the target tile into the pixel coordinates. If the back-transformed tile is not
aligned with the pixel grid, the returned tile will correspond to the bounding box of the back-transformed tile
(so that the returned tile is axis-aligned to the pixel grid).
return_genes:
@LucaMarconato (Member)

Nice! Two comments:

  1. I would specify that the layers are AnnData layers and that the default layer is X.
  2. I would also allow passing just a list instead of a dict, which would be interpreted as {'X': genes_list} (see the sketch below).
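The list-to-dict normalization could be as simple as this sketch (return_genes is the parameter from the diff above; "X" stands for the default AnnData layer; the helper name is hypothetical):

def _normalize_return_genes(return_genes):
    # a bare list of genes is read from the default layer, X
    if isinstance(return_genes, list):
        return {"X": return_genes}
    return dict(return_genes)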

@giovp (Member, Author) commented May 28, 2024

I'll also try it out (which commands did you use to open py-spy? Or did you set it up to be integrated with your IDE?)

I just changed the output format in py-spy:
py-spy record --format speedscope -o profile.speedscope.json -- python process_xenium.py
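The resulting profile.speedscope.json can then be loaded at https://www.speedscope.app/ to get the views shown in the screenshots above.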

This was just a push to get the code onto another machine. But let me explain what's next.

I've realized that the calculation of the transformed bounding box in the implicit coordinate system takes a fair amount of time, and it could in fact be done only once, in the same way the tile coords dataset is built. I will therefore:

  • Move the transformation out of the bounding box query and do it only once at init.
  • Enable returning gexp data from different layers.

The dataset will have only one type of output, a dictionary of the following form:

{
    "tile": tile,
    "annotations": list of annotations,
    "gexp": list of gexp,
}

wdyt?

What I won't do here but would be useful to work on next is:

@LucaMarconato (Member)

Thanks for the explanation. Yes, I think that operating on the transformation at the preprocessing stage is a good approach to improve performance. Also, the option to specify the layer will be useful.

Regarding the return type, would you remove the SpatialData return type or still leave it as an option?

@giovp (Member, Author) commented May 28, 2024

Regarding the return type, would you remove the SpatialData return type or still leave it as an option?

That's a good question. I would potentially leave it, but then technically the dataloader would fail, as the default collate_fn only accepts array / Mapping[str, array] / list[array]. wdyt?
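For illustration, keeping SpatialData as an opt-in return type would then need a custom collate_fn along these lines (a sketch):

import spatialdata as sd
from torch.utils.data import default_collate

def sdata_collate(batch):
    # default_collate only understands tensors/arrays, mappings and
    # sequences thereof, so keep SpatialData objects as a plain list
    if isinstance(batch[0], sd.SpatialData):
        return list(batch)
    return default_collate(batch)

This would be passed to the DataLoader via collate_fn=sdata_collate.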

@LucaMarconato (Member)

Ok, then I would probably move the default away from returning SpatialData (but still leave it as an option for users). I think a good default would be one compatible with the default collate_fn.

@giovp (Member, Author) commented Sep 3, 2024

Closing in favour of #687.

@giovp closed this Sep 3, 2024
@giovp deleted the giovp/dataloader branch September 3, 2024 18:08