Points not transformed when method="datashader" #337

Open
clwgg opened this issue Aug 29, 2024 · 5 comments · May be fixed by #378
Labels
bug (Something isn't working) · points 🧮 (Anything related to Points) · priority: medium

Comments

clwgg commented Aug 29, 2024

As per the title, I just ran into a case where datashader was chosen as the method for render_points, which led to my points being plotted without the relevant transformation being applied. I stole the example from #182 for testing below.

from spatialdata import SpatialData
from spatialdata.models import Image2DModel, PointsModel
from spatialdata.transformations import Scale
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import spatialdata_plot

sdata = SpatialData(
    images={
        "image1": Image2DModel.parse(
            np.full((10, 10, 3), fill_value=128), dims=("y", "x", "c")
        )
    },
    points={
        "points1": PointsModel.parse(
            pd.DataFrame({"y": [0.1, 0.1, 0.9, 0.9], "x": [0.1, 0.9, 0.9, 0.1]}),
            transformations={"global": Scale([10, 10], ("y", "x"))},
        )
    },
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
sdata.pl.render_images("image1").pl.render_points("points1", method="datashader").pl.show(ax=ax1, title="datashader")
sdata.pl.render_images("image1").pl.render_points("points1", method="matplotlib").pl.show(ax=ax2, title="matplotlib")

With current main:
[Screenshot 2024-08-28 at 10:01:40 PM]

With #309:
[Screenshot 2024-08-28 at 10:02:09 PM]
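A pure NumPy/pandas sketch (no datashader involved) of what the report boils down to: `Scale([10, 10], ("y", "x"))` multiplies both axes by 10, so the points have to be moved into the target coordinate system *before* they are aggregated onto the canvas. `apply_scale` and `rasterize` are hypothetical helpers written for illustration, and `np.histogram2d` merely stands in for datashader's aggregation step; this is not how spatialdata-plot implements it.

```python
import numpy as np
import pandas as pd

# Points from the report, in their intrinsic coordinate system.
points = pd.DataFrame({"y": [0.1, 0.1, 0.9, 0.9], "x": [0.1, 0.9, 0.9, 0.1]})


def apply_scale(df: pd.DataFrame, scale: dict) -> pd.DataFrame:
    """Hypothetical helper: multiply each named axis by its factor,
    mimicking the effect of Scale([10, 10], ("y", "x")) on coordinates."""
    out = df.copy()
    for axis, factor in scale.items():
        out[axis] = out[axis] * factor
    return out


def rasterize(df: pd.DataFrame, height: int, width: int) -> np.ndarray:
    """Count points per pixel on a canvas spanning [0, height] x [0, width]."""
    counts, _, _ = np.histogram2d(
        df["y"], df["x"], bins=(height, width), range=((0, height), (0, width))
    )
    return counts.astype(int)


global_points = apply_scale(points, {"y": 10, "x": 10})
good = rasterize(global_points, 10, 10)  # one point near each corner of the 10x10 image
bad = rasterize(points, 10, 10)  # untransformed: all four points collapse into pixel (0, 0)
```

With the transformation applied, the counts sit at pixels (1, 1), (1, 9), (9, 9) and (9, 1); without it, `bad[0, 0]` holds all four points, which is the kind of mismatch visible between the two panels.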

timtreis added the bug, points 🧮, and priority: medium labels on Aug 29, 2024
LucaMarconato (Member) commented:

Thanks for reporting. Please also see the related discussion in #291.

Marius1311 commented:

I think I'm seeing a consequence of that in my own data. Calling

(
    sdata_cropped
    .pl.render_points(TRANSCRIPT_KEY, size=1, color="red", method="matplotlib")
    .pl.show()
)

works just fine, but when using method="datashader", I get

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[45], line 4
      1 (
      2     sdata_cropped
      3     .pl.render_points(TRANSCRIPT_KEY, size=1, color="red", method="datashader")
----> 4     .pl.show()
      5 )

File /nas/groups/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/basic.py:895, in PlotAccessor.show(self, coordinate_systems, legend_fontsize, legend_fontweight, legend_loc, legend_fontoutline, na_in_legend, colorbar, wspace, hspace, ncols, frameon, figsize, dpi, fig, title, share_extent, pad_extent, ax, return_ax, save)
    890     wanted_elements, wanted_points_on_this_cs, wants_points = _get_wanted_render_elements(
    891         sdata, wanted_elements, params_copy, cs, "points"
    892     )
    894     if wanted_points_on_this_cs:
--> 895         _render_points(
    896             sdata=sdata,
    897             render_params=params_copy,
    898             coordinate_system=cs,
    899             ax=ax,
    900             fig_params=fig_params,
    901             scalebar_params=scalebar_params,
    902             legend_params=legend_params,
    903         )
    905 elif cmd == "render_labels" and has_labels:
    906     wanted_elements, wanted_labels_on_this_cs, wants_labels = _get_wanted_render_elements(
    907         sdata, wanted_elements, params_copy, cs, "labels"
    908     )

File /nas/groups/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/render.py:483, in _render_points(sdata, render_params, coordinate_system, ax, fig_params, scalebar_params, legend_params)
    466     color_vector = np.asarray([x[:-2] for x in color_vector])
    468 ds_result = (
    469     ds.tf.shade(
    470         ds.tf.spread(agg, px=px),
   (...)
    481     )
    482 )
--> 483 rbga_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
    484 cax = ax.imshow(rbga_image, zorder=render_params.zorder, alpha=render_params.alpha)
    485 if aggregate_with_sum is not None:

File /links/groups/treutlein/USERS/mlange/miniforge3/envs/spatialdata/lib/python3.11/site-packages/numpy/core/fromnumeric.py:655, in transpose(a, axes)
    588 @array_function_dispatch(_transpose_dispatcher)
    589 def transpose(a, axes=None):
    590     """
    591     Returns an array with axes transposed.
    592 
   (...)
    653 
    654     """
--> 655     return _wrapfunc(a, 'transpose', axes)

File /links/groups/treutlein/USERS/mlange/miniforge3/envs/spatialdata/lib/python3.11/site-packages/numpy/core/fromnumeric.py:56, in _wrapfunc(obj, method, *args, **kwds)
     54 bound = getattr(obj, method, None)
     55 if bound is None:
---> 56     return _wrapit(obj, method, *args, **kwds)
     58 try:
     59     return bound(*args, **kwds)

File /links/groups/treutlein/USERS/mlange/miniforge3/envs/spatialdata/lib/python3.11/site-packages/numpy/core/fromnumeric.py:45, in _wrapit(obj, method, *args, **kwds)
     43 except AttributeError:
     44     wrap = None
---> 45 result = getattr(asarray(obj), method)(*args, **kwds)
     46 if wrap:
     47     if not isinstance(result, mu.ndarray):

ValueError: axes don't match array
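In the last frames, `np.transpose` falls through to `_wrapit`, which NumPy only does when the object has no `transpose` method of its own. That suggests `ds_result.to_numpy().base` was not an array in this case (for instance `None`, which is what `.base` returns for an array that owns its memory). Below is a minimal sketch of a conversion that does not depend on `.base`, assuming the shaded image is a 2D `uint32` array of packed RGBA pixels, which is how datashader images are usually stored. `unpack_rgba` is a hypothetical helper, not part of spatialdata-plot.

```python
import numpy as np


def unpack_rgba(packed: np.ndarray) -> np.ndarray:
    """Turn an (H, W) uint32 array of packed RGBA pixels into an (H, W, 4) uint8 array.

    Assumes the pixels were packed by viewing an (H, W, 4) uint8 array as uint32,
    so viewing them back as uint8 restores the original channel order.
    Does not rely on ``.base``, which may be None.
    """
    # Only copies if the input is not already a C-contiguous uint32 array.
    packed = np.ascontiguousarray(packed, dtype=np.uint32)
    # Reinterpret each 4-byte pixel as 4 uint8 channels, then restore the channel axis.
    return packed.view(np.uint8).reshape(packed.shape + (4,))
```

Viewing instead of copying keeps this cheap for large canvases, and it works the same whether or not the array happens to be a view of some other buffer.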

LucaMarconato (Member) commented Sep 30, 2024

@Marius1311 thanks for reporting. How did you construct sdata_cropped? It would help us if you could reproduce the bug using the blobs dataset.

You can access it via one of these two functions:

CC @melonora

Sonja-Stockhaus (Collaborator) commented:

@clwgg Thanks for reporting! I reproduced the problem without the image in the background. Without it, the datashader points were additionally shifted by 0.5 (because of #216).

from spatialdata import SpatialData
from spatialdata.models import PointsModel
from spatialdata.transformations import Scale
import pandas as pd
import spatialdata_plot  # registers the .pl accessor on SpatialData

sdata = SpatialData(
    points={
        "points1": PointsModel.parse(
            pd.DataFrame({"y": [0, 0, 10, 10, 4, 6, 4, 6], "x": [0, 10, 10, 0, 4, 6, 6, 4]}),
            transformations={"global": Scale([2, 2], ("y", "x"))},
        )
    },
)
sdata.pl.render_points("points1", method="matplotlib", size=50, color="lightgrey").pl.render_points("points1", method="datashader", size=10, color="red").pl.show()

With this, I get:

a) before:
[Image]

b) after my fix (#378):
[Image]
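The thread attributes the 0.5 shift to #216 without describing it. Assuming it is the common mismatch between pixel edges and pixel centers, the arithmetic below shows where a half-unit offset comes from: matplotlib's `imshow` without an `extent` centers pixel `i` on the integer `i` (default x-extent `(-0.5, n - 0.5)`), whereas a canvas covering `[lo, hi]` with `n` pixels has its centers at `lo + (i + 0.5) * (hi - lo) / n`. Both helpers are hypothetical, and the 20-pixel canvas over `[0, 20]` is an illustrative choice matching the scaled point range in the example above.

```python
import numpy as np


def pixel_centers(lo: float, hi: float, n: int) -> np.ndarray:
    """Data-space centers of n equal-width pixels covering [lo, hi]."""
    width = (hi - lo) / n
    return lo + (np.arange(n) + 0.5) * width


def imshow_default_centers(n: int) -> np.ndarray:
    """Where imshow draws pixel centers when no extent is given:
    pixel i is centered on the integer i."""
    return np.arange(n, dtype=float)


true_centers = pixel_centers(0.0, 20.0, 20)
drawn_centers = imshow_default_centers(20)
offset = true_centers - drawn_centers  # 0.5 for every pixel
```

Passing `extent=(x_min, x_max, y_max, y_min)` to `ax.imshow` pins the image edges to the canvas bounds instead, which removes this particular offset.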

timtreis (Member) commented:

@clwgg could you verify that Sonja's branch fixes the issue for you as well? :) Thanks!


5 participants