From 8c127321b7036bb556f429aef19209d1fed05986 Mon Sep 17 00:00:00 2001 From: Benjamin Rombaut Date: Wed, 27 Nov 2024 15:57:16 +0100 Subject: [PATCH] Add asv benchmark code (#784) * init for asv * ignore asv folder * add basic benchmarks * improve cluster_blobs creation time * update readme * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * lint code * pass pre-commit by ignore benchmark files. * add n_transcripts_per_cell to benchmark * lower benchmark size * add more benchmark documentation * add more benchmark documentation * updates from feedback * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor edits docstrings --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> --- .gitignore | 2 + asv.conf.json | 203 ++++++++++++++++ benchmarks/README.md | 75 ++++++ benchmarks/__init__.py | 0 benchmarks/spatialdata_benchmark.py | 75 ++++++ benchmarks/utils.py | 351 ++++++++++++++++++++++++++++ pyproject.toml | 4 + 7 files changed, 710 insertions(+) create mode 100644 asv.conf.json create mode 100644 benchmarks/README.md create mode 100644 benchmarks/__init__.py create mode 100644 benchmarks/spatialdata_benchmark.py create mode 100644 benchmarks/utils.py diff --git a/.gitignore b/.gitignore index 328be6b2..439f8ae3 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,5 @@ _version.py # other node_modules/ + +.asv/ \ No newline at end of file diff --git a/asv.conf.json b/asv.conf.json new file mode 100644 index 00000000..8a108478 --- /dev/null +++ b/asv.conf.json @@ -0,0 +1,203 @@ +{ + // The version of the config file format. Do not change, unless + // you know what you are doing. 
+ "version": 1, + + // The name of the project being benchmarked + "project": "spatialdata", + + // The project's homepage + "project_url": "https://spatialdata.scverse.org/", + + // The URL or local path of the source code repository for the + // project being benchmarked + "repo": ".", + + // The Python project's subdirectory in your repo. If missing or + // the empty string, the project is assumed to be located at the root + // of the repository. + // "repo_subdir": "", + + // Customizable commands for building the project. + // See asv.conf.json documentation. + // To build the package using pyproject.toml (PEP518), uncomment the following lines + // "build_command": [ + // "python -m pip install build", + // "python -m build", + // "python -mpip wheel -w {build_cache_dir} {build_dir}" + // ], + // To build the package using setuptools and a setup.py file, uncomment the following lines + // "build_command": [ + // "python setup.py build", + // "python -mpip wheel -w {build_cache_dir} {build_dir}" + // ], + "build_command": ["python -V"], // skip build stage + + // Customizable commands for installing and uninstalling the project. + // See asv.conf.json documentation. + // "install_command": ["in-dir={env_dir} python -mpip install {wheel_file}"], + // "uninstall_command": ["return-code=any python -mpip uninstall -y {project}"], + + // Install using default install + "install_command": [ + "in-dir={env_dir} python -m pip install {build_dir}[test]" + ], + "uninstall_command": [ + "in-dir={env_dir} python -m pip uninstall -y {project}" + ], + + // List of branches to benchmark. If not provided, defaults to "main" + // (for git) or "default" (for mercurial). + "branches": ["main"], // for git + // "branches": ["default"], // for mercurial + + // The DVCS being used. If not set, it will be automatically + // determined from "repo" by looking at the protocol in the URL + // (if remote), or by looking for special directories, such as + // ".git" (if local). 
+ // "dvcs": "git", + + // The tool to use to create environments. May be "conda", + // "virtualenv", "mamba" (above 3.8) + // or other value depending on the plugins in use. + // If missing or the empty string, the tool will be automatically + // determined by looking for tools on the PATH environment + // variable. + "environment_type": "virtualenv", + + // timeout in seconds for installing any dependencies in environment + // defaults to 10 min + // "install_timeout": 600, + + // the base URL to show a commit for the project. + // "show_commit_url": "http://github.com/owner/project/commit/", + + // The Pythons you'd like to test against. If not provided, defaults + // to the current version of Python used to run `asv`. + "pythons": ["3.12"], + + // The list of conda channel names to be searched for benchmark + // dependency packages in the specified order + // "conda_channels": ["conda-forge", "defaults"], + + // A conda environment file that is used for environment creation. + // "conda_environment_file": "environment.yml", + + // The matrix of dependencies to test. Each key of the "req" + // requirements dictionary is the name of a package (in PyPI) and + // the values are version numbers. An empty list or empty string + // indicates to just test against the default (latest) + // version. null indicates that the package is to not be + // installed. If the package to be tested is only available from + // PyPi, and the 'environment_type' is conda, then you can preface + // the package name by 'pip+', and the package will be installed + // via pip (with all the conda available packages installed first, + // followed by the pip installed packages). + // + // The ``@env`` and ``@env_nobuild`` keys contain the matrix of + // environment variables to pass to build and benchmark commands. + // An environment will be created for every combination of the + // cartesian product of the "@env" variables in this matrix. 
+ // Variables in "@env_nobuild" will be passed to every environment + // during the benchmark phase, but will not trigger creation of + // new environments. A value of ``null`` means that the variable + // will not be set for the current combination. + // + // "matrix": { + // "req": { + // "numpy": ["1.6", "1.7"], + // "six": ["", null], // test with and without six installed + // "pip+emcee": [""] // emcee is only available for install with pip. + // }, + // "env": {"ENV_VAR_1": ["val1", "val2"]}, + // "env_nobuild": {"ENV_VAR_2": ["val3", null]}, + // }, + + // Combinations of libraries/python versions can be excluded/included + // from the set to test. Each entry is a dictionary containing additional + // key-value pairs to include/exclude. + // + // An exclude entry excludes entries where all values match. The + // values are regexps that should match the whole string. + // + // An include entry adds an environment. Only the packages listed + // are installed. The 'python' key is required. The exclude rules + // do not apply to includes. + // + // In addition to package names, the following keys are available: + // + // - python + // Python version, as in the *pythons* variable above. + // - environment_type + // Environment type, as above. + // - sys_platform + // Platform, as in sys.platform. Possible values for the common + // cases: 'linux', 'win32', 'cygwin', 'darwin'.
+ // - req + // Required packages + // - env + // Environment variables + // - env_nobuild + // Non-build environment variables + // + // "exclude": [ + // {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows + // {"environment_type": "conda", "req": {"six": null}}, // don't run without six on conda + // {"env": {"ENV_VAR_1": "val2"}}, // skip val2 for ENV_VAR_1 + // ], + // + // "include": [ + // // additional env for python3.12 + // {"python": "3.12", "req": {"numpy": "1.26"}, "env_nobuild": {"FOO": "123"}}, + // // additional env if run on windows+conda + // {"sys_platform": "win32", "environment_type": "conda", "python": "3.12", "req": {"libpython": ""}}, + // ], + + // The directory (relative to the current directory) that benchmarks are + // stored in. If not provided, defaults to "benchmarks" + // "benchmark_dir": "benchmarks", + + // The directory (relative to the current directory) to cache the Python + // environments in. If not provided, defaults to "env" + "env_dir": ".asv/env", + + // The directory (relative to the current directory) that raw benchmark + // results are stored in. If not provided, defaults to "results". + "results_dir": ".asv/results", + + // The directory (relative to the current directory) that the html tree + // should be written to. If not provided, defaults to "html". + "html_dir": ".asv/html", + + // The number of characters to retain in the commit hashes. + "hash_length": 8, + + // `asv` will cache results of the recent builds in each + // environment, making them faster to install next time. This is + // the number of builds to keep, per environment. + "build_cache_size": 2 + + // The commits after which the regression search in `asv publish` + // should start looking for regressions. Dictionary whose keys are + // regexps matching to benchmark names, and values corresponding to + // the commit (exclusive) after which to start looking for + // regressions.
The default is to start from the first commit + // with results. If the commit is `null`, regression detection is + // skipped for the matching benchmark. + // + // "regressions_first_commits": { + // "some_benchmark": "352cdf", // Consider regressions only after this commit + // "another_benchmark": null, // Skip regression detection altogether + // }, + + // The thresholds for relative change in results, after which `asv + // publish` starts reporting regressions. Dictionary of the same + // form as in ``regressions_first_commits``, with values + // indicating the thresholds. If multiple entries match, the + // maximum is taken. If no entry matches, the default is 5%. + // + // "regressions_thresholds": { + // "some_benchmark": 0.01, // Threshold of 1% + // "another_benchmark": 0.5, // Threshold of 50% + // }, +} diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 00000000..6c69a786 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,75 @@ +# Benchmarking SpatialData code + +This `benchmarks/` folder contains code to benchmark the performance of the SpatialData code. You can use it to see how code behaves for different options or data sizes. For more information, check the [SpatialData Contributing Guide](https://spatialdata.scverse.org/en/stable/contributing.html). + +Note that to run the benchmarks, your current working directory should be the root of the SpatialData repo, not this `benchmarks/` folder. + +## Installation + +The benchmarks use the [airspeed velocity](https://asv.readthedocs.io/en/stable/) (asv) framework. Install it with the `benchmark` option: + +``` +pip install -e '.[docs,test,benchmark]' +``` + +## Usage + +Running all the benchmarks is usually not needed. You can run benchmarks using `asv run`.
See the [asv documentation](https://asv.readthedocs.io/en/stable/commands.html#asv-run) for useful arguments, such as selecting the benchmarks you're interested in by passing a regex pattern to `-b` or `--bench` that matches a function or class method, e.g. the option `-b timeraw_import_inspect` selects the function `timeraw_import_inspect` in `benchmarks/spatialdata_benchmark.py`. You can run the benchmark in your current environment with `--python=same`. Some example benchmarks: + +Importing the SpatialData library can take around 4 seconds: + +``` +$ PYTHONWARNINGS="ignore" asv run --python=same --show-stderr -b timeraw_import_inspect +Couldn't load asv.plugins._mamba_helpers because +No module named 'conda' +· Discovering benchmarks +· Running 1 total benchmarks (1 commits * 1 environments * 1 benchmarks) +[ 0.00%] ·· Benchmarking existing-py_opt_homebrew_Caskroom_mambaforge_base_envs_spatialdata2_bin_python3.12 +[50.00%] ··· Running (spatialdata_benchmark.timeraw_import_inspect--). +[100.00%] ··· spatialdata_benchmark.timeraw_import_inspect 3.65±0.2s +``` + +Querying with a bounding box without a spatial index is affected far more by the number of points (transcripts) than by the number of table rows (cells). + +``` +$ PYTHONWARNINGS="ignore" asv run --python=same --show-stderr -b time_query_bounding_box + +[100.00%] ··· ======== ============ ============= ============= ============== + -- filter_table / n_transcripts_per_cell + -------- ------------------------------------------------------- + length True / 100 True / 1000 False / 100 False / 1000 + ======== ============ ============= ============= ============== + 100 177±5ms 195±4ms 168±0.5ms 186±2ms + 1000 195±3ms 402±2ms 187±3ms 374±4ms + 10000 722±3ms 2.65±0.01s 389±3ms 2.22±0.02s + ======== ============ ============= ============= ============== +``` + +You can use `asv` to run all the benchmarks in their own environment.
This can take a long time, so it is not recommended for regular use: + +``` +$ asv run +Couldn't load asv.plugins._mamba_helpers because +No module named 'conda' +· Creating environments.... +· Discovering benchmarks.. +·· Uninstalling from virtualenv-py3.12 +·· Building a89d16d8 for virtualenv-py3.12 +·· Installing a89d16d8 into virtualenv-py3.12............. +· Running 6 total benchmarks (1 commits * 1 environments * 6 benchmarks) +[ 0.00%] · For spatialdata commit a89d16d8 : +[ 0.00%] ·· Benchmarking virtualenv-py3.12 +[25.00%] ··· Running (spatialdata_benchmark.TimeMapRaster.time_map_blocks--)... +... +[100.00%] ··· spatialdata_benchmark.timeraw_import_inspect 3.33±0.06s +``` + +## Notes + +When using PyCharm, remember to set [Configuration](https://www.jetbrains.com/help/pycharm/run-debug-configuration.html) to include the benchmark module, as this is separate from the main code module. + +In Python, you can run a module using the following command: + +``` +python -m benchmarks.spatialdata_benchmark +``` diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/benchmarks/spatialdata_benchmark.py b/benchmarks/spatialdata_benchmark.py new file mode 100644 index 00000000..af383556 --- /dev/null +++ b/benchmarks/spatialdata_benchmark.py @@ -0,0 +1,75 @@ +# type: ignore + +# Write the benchmarking functions here. +# See "Writing benchmarks" in the asv docs for more information. +import spatialdata as sd + +from .utils import cluster_blobs + + +class MemorySpatialData: + # TODO: see what the memory overhead is e.g. Python interpreter... 
+ """Calculate the peak memory usage for artificial datasets with an increasing number of channels.""" + + def peakmem_list(self): + sdata: sd.SpatialData = sd.datasets.blobs(n_channels=1) + return sdata + + def peakmem_list2(self): + sdata: sd.SpatialData = sd.datasets.blobs(n_channels=2) + return sdata + + +def timeraw_import_inspect(): + """Time the import of the spatialdata module.""" + return """ + import spatialdata + """ + + +class TimeMapRaster: + """Time `map_raster` on the blobs image for increasing values of `length`.""" + + params = [100, 1000, 10_000] + param_names = ["length"] + + def setup(self, length): + self.sdata = cluster_blobs(length=length) + + def teardown(self, _): + del self.sdata + + def time_map_blocks(self, _): + sd.map_raster(self.sdata["blobs_image"], lambda x: x + 1) + + +class TimeQueries: + + params = ([100, 1_000, 10_000], [True, False], [100, 1_000]) + param_names = ["length", "filter_table", "n_transcripts_per_cell"] + + def setup(self, length, filter_table, n_transcripts_per_cell): + import shapely + + self.sdata = cluster_blobs(length=length, n_transcripts_per_cell=n_transcripts_per_cell) + self.polygon = shapely.box(0, 0, length // 2, length // 2) + + def teardown(self, length, filter_table, n_transcripts_per_cell): + del self.sdata + + def time_query_bounding_box(self, length, filter_table, n_transcripts_per_cell): + self.sdata.query.bounding_box( + axes=["x", "y"], + min_coordinate=[0, 0], + max_coordinate=[length // 2, length // 2], + target_coordinate_system="global", + filter_table=filter_table, + ) + + def time_query_polygon_box(self, length, filter_table, n_transcripts_per_cell): + sd.polygon_query( + self.sdata, + self.polygon, + target_coordinate_system="global", + filter_table=filter_table, + ) diff --git a/benchmarks/utils.py b/benchmarks/utils.py new file mode 100644 index 00000000..75f566b7 --- /dev/null +++ b/benchmarks/utils.py @@ -0,0 +1,351 @@ +# type: ignore +""" +This utility module contains functions that are used in the benchmarks.
+ +Functions that make running and debugging benchmarks easier: +- class Skip is used to skip benchmarks based on environment variables. +- function run_benchmark_from_module is used to run benchmarks from a module. +- function run_benchmark is used to run a benchmark selected on the command line. + +Performant dataset generation functions, so that the benchmarks run fast even for large artificial datasets. +The goal is to generate a dataset containing many cells. By copying the same cell values instead of +applying a Gaussian blur to the whole image, we can generate the same dataset in a fraction of the time. +- function labeled_particles is used to generate labeled blobs. +- function _generate_ball is used to generate a ball of given radius and dimension. +- function _generate_density is used to generate gaussian density of given radius and dimension. +- function cluster_blobs is used to generate a SpatialData object with blobs. +- function _structure_at_coordinates is used to update data with structure at given coordinates. +- function _get_slices_at is used to get slices at a given point. +- function _update_data_with_mask is used to update data with struct where struct is nonzero.
+""" + +import itertools +import os +from collections.abc import Sequence +from functools import lru_cache +from types import ModuleType +from typing import Callable, Literal, Optional, Union, overload + +import anndata as ad +import numpy as np +import pandas as pd +from skimage import morphology + +import spatialdata as sd +from spatialdata import SpatialData +from spatialdata.models import Image2DModel, TableModel +from spatialdata.transformations import Identity + + +def always_false(*_): + return False + + +class Skip: + def __init__( + self, + if_in_pr: Callable[..., bool] = always_false, + if_on_ci: Callable[..., bool] = always_false, + always: Callable[..., bool] = always_false, + ): + self.func_pr = if_in_pr if "PR" in os.environ else always_false + self.func_ci = if_on_ci if "CI" in os.environ else always_false + self.func_always = always + + def __contains__(self, item): + return self.func_pr(*item) or self.func_ci(*item) or self.func_always(*item) + + +def _generate_ball(radius: int, ndim: int) -> np.ndarray: + """Generate a ball of given radius and dimension. + + Parameters + ---------- + radius : int + Radius of the ball. + ndim : int + Dimension of the ball. + + Returns + ------- + ball : ndarray of uint8 + Binary array of the hyper ball. 
+ """ + if ndim == 2: + return morphology.disk(radius) + if ndim == 3: + return morphology.ball(radius) + shape = (2 * radius + 1,) * ndim + radius_sq = radius**2 + coords = np.indices(shape) - radius + return (np.sum(coords**2, axis=0) <= radius_sq).astype(np.uint8) + + +def _generate_density(radius: int, ndim: int) -> np.ndarray: + """Generate gaussian density of given radius and dimension.""" + shape = (2 * radius + 1,) * ndim + coords = np.indices(shape) - radius + dist = np.sqrt(np.sum(coords**2 / ((radius / 4) ** 2), axis=0)) + res = np.exp(-dist) + res[res < 0.02] = 0 + return res + + +def _structure_at_coordinates( + shape: tuple[int], + coordinates: np.ndarray, + structure: np.ndarray, + *, + multipliers: Sequence = itertools.repeat(1), + dtype=None, + reduce_fn: Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], np.ndarray], +): + """Create a zero array of the given shape and add the structure at the given coordinates. + + Parameters + ---------- + shape : tuple of int + Shape of the new array that is filled and returned. + coordinates : ndarray + Coordinates of the points. The structures will be added at these + points (center). + structure : ndarray + Array with encoded structure. For example, ball (boolean) or density + (0,1) float. + multipliers : ndarray + These values are multiplied by the values in the structure before + updating the array. Can be used to generate different labels, or to + vary the intensity of floating point gaussian densities. + reduce_fn : function + Function with which to update the array at a particular position. It + should take two arrays as input and an optional output array.
+ """ + radius = (structure.shape[0] - 1) // 2 + data = np.zeros(shape, dtype=dtype) + + for point, value in zip(coordinates, multipliers): + slice_im, slice_ball = _get_slices_at(shape, point, radius) + reduce_fn(data[slice_im], value * structure[slice_ball], out=data[slice_im]) + return data + + +def _get_slices_at(shape, point, radius): + slice_im = [] + slice_ball = [] + for i, p in enumerate(point): + slice_im.append(slice(max(0, p - radius), min(shape[i], p + radius + 1))) + ball_start = max(0, radius - p) + ball_stop = slice_im[-1].stop - slice_im[-1].start + ball_start + slice_ball.append(slice(ball_start, ball_stop)) + return tuple(slice_im), tuple(slice_ball) + + +def _update_data_with_mask(data, struct, out=None): + """Update ``data`` with ``struct`` where ``struct`` is nonzero.""" + # these branches are needed because np.where does not support + # an out= keyword argument + if out is None: + return np.where(struct, struct, data) + else: # noqa: RET505 + nz = struct != 0 + out[nz] = struct[nz] + return out + + +def _smallest_dtype(n: int) -> np.dtype: + """Find the smallest dtype that can hold n values.""" + for dtype in [np.uint8, np.uint16, np.uint32, np.uint64]: + if np.iinfo(dtype).max >= n: + return dtype + break + else: + raise ValueError(f"{n=} is too large for any dtype.") + + +@overload +def labeled_particles( + shape: Sequence[int], + dtype: Optional[np.dtype] = None, + n: int = 144, + seed: Optional[int] = None, + return_density: Literal[False] = False, +) -> np.ndarray: ... + + +@overload +def labeled_particles( + shape: Sequence[int], + dtype: Optional[np.dtype] = None, + n: int = 144, + seed: Optional[int] = None, + return_density: Literal[True] = True, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ... 
+ + +@lru_cache +def labeled_particles( + shape: Sequence[int], + dtype: Optional[np.dtype] = None, + n: int = 144, + seed: Optional[int] = None, + return_density: bool = False, +) -> Union[np.ndarray, tuple[np.ndarray, np.ndarray, np.ndarray]]: + """Generate labeled blobs of given shape and dtype. + + Parameters + ---------- + shape : Sequence[int] + Shape of the resulting array. + dtype : Optional[np.dtype] + Dtype of the resulting array. + n : int + Number of blobs to generate. + seed : Optional[int] + Seed for the random number generator. + return_density : bool + Whether to return the density array and center coordinates. + """ + if dtype is None: + dtype = _smallest_dtype(n) + rng = np.random.default_rng(seed) + ndim = len(shape) + points = rng.integers(shape, size=(n, ndim)) + # create values from 1 to max of number of points + values = np.linspace(1, n, n, dtype=dtype) + rng.shuffle(values) + # values = rng.integers( + # np.iinfo(dtype).min + 1, np.iinfo(dtype).max, size=n, dtype=dtype + # ) + sigma = int(max(shape) / (4.0 * n ** (1 / ndim))) + ball = _generate_ball(sigma, ndim) + + labels = _structure_at_coordinates( + shape, + points, + ball, + multipliers=values, + reduce_fn=_update_data_with_mask, + dtype=dtype, + ) + + if return_density: + dens = _generate_density(sigma * 2, ndim) + densities = _structure_at_coordinates(shape, points, dens, reduce_fn=np.maximum, dtype=np.float32) + + return labels, densities, points, values + else: # noqa: RET505 + return labels + + +def run_benchmark_from_module(module: ModuleType, klass_name: str, method_name: str): + klass = getattr(module, klass_name) + if getattr(klass, "params", None): + skip_if = getattr(klass, "skip_params", {}) + if isinstance(klass.params[0], Sequence): + params = itertools.product(*klass.params) + else: + params = ((i,) for i in klass.params) + for param in params: + if param in skip_if: + continue + obj = klass() + try: + obj.setup(*param) + except NotImplementedError: + continue + 
getattr(obj, method_name)(*param) + getattr(obj, "teardown", lambda: None)() + else: + obj = klass() + try: + obj.setup() + except NotImplementedError: + return + getattr(obj, method_name)() + getattr(obj, "teardown", lambda: None)() + + +def run_benchmark(): + import argparse + import inspect + + parser = argparse.ArgumentParser(description="Run benchmark") + parser.add_argument("benchmark", type=str, help="Name of the benchmark to run", default="") + + args = parser.parse_args() + + benchmark_selection = args.benchmark.split(".") + + # get module of parent frame + call_module = inspect.getmodule(inspect.currentframe().f_back) + run_benchmark_from_module(call_module, *benchmark_selection) + + +# TODO: merge functionality of this cluster_blobs with the one in SpatialData https://github.com/scverse/spatialdata/issues/796 +@lru_cache +def cluster_blobs( + length=512, + n_cells=None, + region_key="region_key", + instance_key="instance_key", + image_name="blobs_image", + labels_name="blobs_labels", + points_name="blobs_points", + n_transcripts_per_cell=None, + table_name="table", + coordinate_system="global", +): + """Faster `spatialdata.datasets.make_blobs` using napari.datasets code.""" + if n_cells is None: + n_cells = length + # cells + labels, density, points, values = labeled_particles((length, length), return_density=True, n=n_cells) + + im_el = Image2DModel.parse( + data=density[None, ...], + dims="cyx", + transformations={coordinate_system: Identity()}, + ) + label_el = sd.models.Labels2DModel.parse(labels, dims="yx", transformations={coordinate_system: Identity()}) + points_cells_el = sd.models.PointsModel.parse(points, transformations={coordinate_system: Identity()}) + + # generate dummy table + adata = ad.AnnData(X=np.ones((length, 10))) + adata.obs[region_key] = pd.Categorical([labels_name] * len(adata)) + # adata.obs_names = values.astype(np.uint64) + adata.obs[instance_key] = adata.obs_names.values + adata.obs.index = adata.obs.index.astype(str) + 
+ adata.obs.index.name = instance_key + # del adata.uns[TableModel.ATTRS_KEY] + table = TableModel.parse( + adata, + region=labels_name, + region_key=region_key, + instance_key=instance_key, + ) + + sdata = SpatialData( + images={ + image_name: im_el, + }, + labels={ + labels_name: label_el, + }, + points={points_name: points_cells_el}, + tables={table_name: table}, + ) + + if n_transcripts_per_cell: + # transcript points + # generate n_transcripts_per_cell randomly placed transcripts per cell + rng = np.random.default_rng(None) + points_transcripts = rng.integers(length, size=(n_cells * n_transcripts_per_cell, 2)) + points_transcripts_el = sd.models.PointsModel.parse( + points_transcripts, transformations={coordinate_system: Identity()} + ) + sdata["transcripts_" + points_name] = points_transcripts_el + + # if shapes_name: + # sdata[shapes_name] = sd.to_circles(sdata[labels_name]) + # add_regionprop_features(sdata, labels_layer=labels_name, table_layer=table_name) + return sdata diff --git a/pyproject.toml b/pyproject.toml index ddb4b0b7..da2f6559 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,9 @@ test = [ "pytest-mock", "torch", ] +benchmark = [ + "asv", +] torch = [ "torch" ] @@ -204,6 +207,7 @@ convention = "numpy" "src/spatialdata/dataloader/datasets.py" = ["D101"] "tests/test_models/test_models.py" = ["NPY002"] "tests/conftest.py"= ["E402"] + "benchmarks/*" = ["ALL"] # pyupgrade typing rewrite TODO: remove at some point from per-file ignore