Skip to content

Commit

Permalink
Merge pull request #198 from ianhi/scatter-kwargs
Browse files Browse the repository at this point in the history
handle all scatter kwargs
  • Loading branch information
ianhi authored Jun 2, 2021
2 parents 47542f6 + 28aa8e3 commit 74966bc
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
- uses: actions/setup-python@v2
- uses: psf/black@stable
with:
black_args: ". --check --line-length=100"
options: ". --check --line-length=100"
codespell:
name: Check for spelling errors
runs-on: ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion mpl_interactions/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
version_info = (0, 17, 11)
version_info = (0, 17, 12)
__version__ = ".".join(map(str, version_info))
2 changes: 2 additions & 0 deletions mpl_interactions/mpl_kwargs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from matplotlib.artist import ArtistInspector
from matplotlib.collections import Collection
from matplotlib.image import AxesImage

# this is a list of options to Line2D partially taken from
Expand Down Expand Up @@ -61,6 +62,7 @@
]

imshow_kwargs_list = ArtistInspector(AxesImage).get_setters()
collection_kwargs_list = ArtistInspector(Collection).get_setters()

Text_kwargs_list = [
"agg_filter",
Expand Down
51 changes: 43 additions & 8 deletions mpl_interactions/pyplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from matplotlib.colors import to_rgba_array
from matplotlib.patches import Rectangle
from matplotlib.pyplot import sca
import matplotlib.markers as mmarkers

from .controller import Controls, gogogo_controls, prep_scalars

Expand All @@ -23,7 +24,13 @@
notebook_backend,
update_datalim_from_bbox,
)
from .mpl_kwargs import Line2D_kwargs_list, imshow_kwargs_list, Text_kwargs_list, kwarg_popper
from .mpl_kwargs import (
Line2D_kwargs_list,
imshow_kwargs_list,
Text_kwargs_list,
collection_kwargs_list,
kwarg_popper,
)

__all__ = [
"interactive_plot",
Expand Down Expand Up @@ -403,7 +410,9 @@ def interactive_scatter(
vmin=None,
vmax=None,
alpha=None,
marker=None,
edgecolors=None,
facecolors=None,
label=None,
parametric=False,
ax=None,
Expand All @@ -424,18 +433,21 @@ def interactive_scatter(
x, y : function or float or array-like
shape (n, ) for array-like. Functions must return the correct shape as well. If y is None
then parametric must be True and the function for x must return x, y
c : array-like or list of colors or color, broadcastable
Must be broadcastable to x,y and any other plotting kwargs.
valid input to plt.scatter
c : array-like or list of colors or color or Callable
Valid input to plt.scatter or a function
s : float, array-like, function, or index controls object
valid input to plt.scatter, or a function
alpha : float, None, or function(s), broadcastable
Affects all scatter points. This will compound with any alpha introduced by
the ``c`` argument
edgecolors : colorlike, broadcastable
marker : MarkerStyle, or Callable, optional
The marker style or a function returning marker styles.
edgecolor[s] : callable or valid argument to scatter
passed through to scatter.
label : string(s) broadcastable
labels for the functions being plotted.
facecolor[s] : callable or valid argument to scatter
Valid input to plt.scatter, or a function
label : string
Passed through to Matplotlib
parametric : boolean
If True then the function expects to have only received a value for y and that that function will
return an array for both x and y, or will return an array with shape (N, 2)
Expand Down Expand Up @@ -483,15 +495,22 @@ def interactive_scatter(
else:
stretch_y = False

# yanked from https://github.com/matplotlib/matplotlib/blob/bcc1ce8461f5b6e874baaaa02ef776d0243a4abe/lib/matplotlib/axes/_axes.py#L4271-L4273
facecolors = kwargs.pop("facecolor", facecolors)
edgecolors = kwargs.pop("edgecolor", edgecolors)

kwargs, collection_kwargs = kwarg_popper(kwargs, collection_kwargs_list)

ipympl = notebook_backend()
fig, ax = gogogo_figure(ipympl, ax)
use_ipywidgets = ipympl or force_ipywidgets
slider_formats = create_slider_format_dict(slider_formats)

extra_ctrls = []
funcs, extra_ctrls, param_excluder = prep_scalars(kwargs, s=s, alpha=alpha)
funcs, extra_ctrls, param_excluder = prep_scalars(kwargs, s=s, alpha=alpha, marker=marker)
s = funcs["s"]
alpha = funcs["alpha"]
marker = funcs["marker"]

controls, params = gogogo_controls(
kwargs, controls, display_controls, slider_formats, play_buttons, extra_ctrls
Expand All @@ -509,7 +528,16 @@ def update(params, indices, cache):
c_ = check_callable_xy(c, x_, y_, param_excluder(params), cache)
s_ = check_callable_xy(s, x_, y_, param_excluder(params, "s"), cache)
ec_ = check_callable_xy(edgecolors, x_, y_, param_excluder(params), cache)
fc_ = check_callable_xy(facecolors, x_, y_, param_excluder(params), cache)
a_ = check_callable_alpha(alpha, param_excluder(params, "alpha"), cache)
marker_ = callable_else_value_no_cast(marker, param_excluder(params), cache)

if marker_ is not None:
if not isinstance(marker_, mmarkers.MarkerStyle):
marker_ = mmarkers.MarkerStyle(marker_)
path = marker_.get_path().transformed(marker_.get_transform())
scatter.set_paths((path,))

if c_ is not None:
try:
c_ = to_rgba_array(c_)
Expand All @@ -524,6 +552,8 @@ def update(params, indices, cache):
scatter.set_facecolor(c_)
if ec_ is not None:
scatter.set_edgecolor(ec_)
if fc_ is not None:
scatter.set_facecolor(c_)
if s_ is not None:
if isinstance(s_, Number):
s_ = np.broadcast_to(s_, (len(x_),))
Expand Down Expand Up @@ -565,7 +595,9 @@ def check_callable_alpha(alpha_, params, cache):
c_ = check_callable_xy(c, x_, y_, p, {})
s_ = check_callable_xy(s, x_, y_, param_excluder(params, "s"), {})
ec_ = check_callable_xy(edgecolors, x_, y_, p, {})
fc_ = check_callable_xy(facecolors, x_, y_, p, {})
a_ = check_callable_alpha(alpha, params, {})
marker_ = callable_else_value_no_cast(marker, p, {})
scatter = ax.scatter(
x_,
y_,
Expand All @@ -574,9 +606,12 @@ def check_callable_alpha(alpha_, params, cache):
vmin=vmin,
vmax=vmax,
cmap=cmap,
marker=marker_,
alpha=a_,
edgecolors=ec_,
facecolors=fc_,
label=label,
**collection_kwargs,
)
# this is necessary to make calls to plt.colorbar behave as expected
sca(ax)
Expand Down

0 comments on commit 74966bc

Please sign in to comment.