Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

_set_color_source_vec and _render_points Refactor #245

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 12 additions & 57 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
_normalize,
_rasterize_if_necessary,
_set_color_source_vec,
_get_palette,
to_hex,
)

Expand Down Expand Up @@ -230,64 +231,11 @@ def _render_points(
table_name = element_table_mapping.get(e)

coords = ["x", "y"]
# if col_for_color is not None:
if (
col_for_color is not None
and col_for_color not in points.columns
and col_for_color not in sdata_filt[table_name].obs.columns
):
# no error in case there are multiple elements, but onyl some have color key
msg = f"Color key '{col_for_color}' for element '{e}' not been found, using default colors."
logger.warning(msg)
elif col_for_color is None or (table_name is not None and col_for_color in sdata_filt[table_name].obs.columns):
points = points[coords].compute()
if (
col_for_color
and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O"
and not _is_coercable_to_float(color_col)
):
warnings.warn(
f"Converting copy of '{col_for_color}' column to categorical dtype for categorical "
f"plotting. Consider converting before plotting.",
UserWarning,
stacklevel=2,
)
sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category")
else:
coords += [col_for_color]
points = points[coords].compute()
points = points[coords].compute()

if render_params.groups[index][0] is not None and col_for_color is not None:
points = points[points[col_for_color].isin(render_params.groups[index])]

# we construct an anndata to hack the plotting functions
if table_name is None:
adata = AnnData(
X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype
)
else:
adata = AnnData(
X=points[["x", "y"]].values, obs=sdata_filt[table_name].obs, dtype=points[["x", "y"]].values.dtype
)
sdata_filt[table_name] = adata

# we can do this because of dealing with a copy

# Convert back to dask dataframe to modify sdata
points = dask.dataframe.from_pandas(points, npartitions=1)
sdata_filt.points[e] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"})

if col_for_color is not None:
cols = sc.get.obs_df(adata, col_for_color)
# maybe set color based on type
if isinstance(cols.dtype, pd.CategoricalDtype):
_maybe_set_colors(
source=adata,
target=adata,
key=col_for_color,
palette=render_params.palette[index] if render_params.palette[index][0] is not None else None,
)

# when user specified a single color, we overwrite na with it
default_color = (
render_params.color[index]
Expand Down Expand Up @@ -317,9 +265,15 @@ def _render_points(
trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData

norm = copy(render_params.cmap_params.norm)

if type(color_vector[0]) == str:
color_vector = pd.Categorical(color_vector)
category_map = {category: i for i, category in enumerate(color_vector.categories)}
color_vector = color_vector.map(category_map)

_cax = ax.scatter(
adata[:, 0].X.flatten(),
adata[:, 1].X.flatten(),
points["x"],
points["y"],
s=render_params.size,
c=color_vector,
rasterized=sc_settings._vector_friendly,
Expand All @@ -341,7 +295,8 @@ def _render_points(
ax=ax,
cax=cax,
fig_params=fig_params,
adata=adata,
# adata=adata,
adata=None,
value_to_plot=render_params.col_for_color,
color_source_vector=color_source_vector,
palette=palette,
Expand Down
15 changes: 2 additions & 13 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,28 +634,17 @@ def _set_color_source_vec(
color = np.full(len(element), to_hex(na_color)) # type: ignore[arg-type]
return color, color, False

model = get_model(sdata[element_name])

# Figure out where to get the color from
origins = _locate_value(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name)
if model == PointsModel and table_name is not None:
origin = _locate_points_value_in_table(
value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name
)
if origin is not None:
origins.append(origin)

if len(origins) > 1:
raise ValueError(
f"Color key '{value_to_plot}' for element '{element_name}' been found in multiple locations: {origins}."
)

if len(origins) == 1:
if model == PointsModel and table_name is not None:
color_source_vector = get_values_point_table(sdata=sdata, origin=origin, table_name=table_name)
else:
vals = get_values(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name)
color_source_vector = vals[value_to_plot]
vals = get_values(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name)
color_source_vector = vals[value_to_plot]

# numerical case, return early
if color_source_vector is not None and not isinstance(color_source_vector.dtype, pd.CategoricalDtype):
Expand Down