From debb1a16f67620234eddbee710413036a0d23ab1 Mon Sep 17 00:00:00 2001 From: Dylan Lam Date: Fri, 5 Apr 2024 09:29:53 -0700 Subject: [PATCH 1/2] _set_color_source_vec() refactor Refactored _set_color_source_vec to remove catches for tables annotating point elements. get_values can now grab annotations for points from tables. --- src/spatialdata_plot/pl/utils.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index caeb4c00..52767dd7 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -634,16 +634,8 @@ def _set_color_source_vec( color = np.full(len(element), to_hex(na_color)) # type: ignore[arg-type] return color, color, False - model = get_model(sdata[element_name]) - # Figure out where to get the color from origins = _locate_value(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name) - if model == PointsModel and table_name is not None: - origin = _locate_points_value_in_table( - value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name - ) - if origin is not None: - origins.append(origin) if len(origins) > 1: raise ValueError( @@ -651,11 +643,8 @@ def _set_color_source_vec( ) if len(origins) == 1: - if model == PointsModel and table_name is not None: - color_source_vector = get_values_point_table(sdata=sdata, origin=origin, table_name=table_name) - else: - vals = get_values(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name) - color_source_vector = vals[value_to_plot] + vals = get_values(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name) + color_source_vector = vals[value_to_plot] # numerical case, return early if color_source_vector is not None and not isinstance(color_source_vector.dtype, pd.CategoricalDtype): From c7f39226d68401ecc25f51f895e7ad75778c7c25 Mon Sep 17 00:00:00 2001 From: Dylan Lam Date: Fri, 5 Apr 2024 09:30:58 -0700 Subject: [PATCH 2/2] _render_points() refactor refactored _render points to be able to plot points and color by annotations from the element itself and from tables annotating points. --- src/spatialdata_plot/pl/render.py | 69 ++++++------------------------- 1 file changed, 12 insertions(+), 57 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 736d3aa5..5b4d08d5 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -45,6 +45,7 @@ _normalize, _rasterize_if_necessary, _set_color_source_vec, + _get_palette, to_hex, ) @@ -230,64 +231,11 @@ def _render_points( table_name = element_table_mapping.get(e) coords = ["x", "y"] - # if col_for_color is not None: - if ( - col_for_color is not None - and col_for_color not in points.columns - and col_for_color not in sdata_filt[table_name].obs.columns - ): - # no error in case there are multiple elements, but onyl some have color key - msg = f"Color key '{col_for_color}' for element '{e}' not been found, using default colors." - logger.warning(msg) - elif col_for_color is None or (table_name is not None and col_for_color in sdata_filt[table_name].obs.columns): - points = points[coords].compute() - if ( - col_for_color - and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O" - and not _is_coercable_to_float(color_col) - ): - warnings.warn( - f"Converting copy of '{col_for_color}' column to categorical dtype for categorical " - f"plotting. Consider converting before plotting.", - UserWarning, - stacklevel=2, - ) - sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category") - else: - coords += [col_for_color] - points = points[coords].compute() + points = points[coords].compute() if render_params.groups[index][0] is not None and col_for_color is not None: points = points[points[col_for_color].isin(render_params.groups[index])] - # we construct an anndata to hack the plotting functions - if table_name is None: - adata = AnnData( - X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype - ) - else: - adata = AnnData( - X=points[["x", "y"]].values, obs=sdata_filt[table_name].obs, dtype=points[["x", "y"]].values.dtype - ) - sdata_filt[table_name] = adata - - # we can do this because of dealing with a copy - - # Convert back to dask dataframe to modify sdata - points = dask.dataframe.from_pandas(points, npartitions=1) - sdata_filt.points[e] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"}) - - if col_for_color is not None: - cols = sc.get.obs_df(adata, col_for_color) - # maybe set color based on type - if isinstance(cols.dtype, pd.CategoricalDtype): - _maybe_set_colors( - source=adata, - target=adata, - key=col_for_color, - palette=render_params.palette[index] if render_params.palette[index][0] is not None else None, - ) - # when user specified a single color, we overwrite na with it default_color = ( render_params.color[index] @@ -317,9 +265,15 @@ def _render_points( trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData norm = copy(render_params.cmap_params.norm) + + if type(color_vector[0]) == str: + color_vector = pd.Categorical(color_vector) + category_map = {category: i for i, category in enumerate(color_vector.categories)} + color_vector = color_vector.map(category_map) + _cax = ax.scatter( - adata[:, 0].X.flatten(), - adata[:, 1].X.flatten(), + points["x"], + points["y"], s=render_params.size, c=color_vector, rasterized=sc_settings._vector_friendly, @@ -341,7 +295,8 @@ def _render_points( ax=ax, cax=cax, fig_params=fig_params, - adata=adata, + # adata=adata, + adata=None, value_to_plot=render_params.col_for_color, color_source_vector=color_source_vector, palette=palette,