Skip to content

Commit

Permalink
fix plot_scatter error with 0 or 1 points (and 2 points with extent) (c…
Browse files Browse the repository at this point in the history
…loses #197)
  • Loading branch information
whitews committed Sep 7, 2024
1 parent 8f3f7c3 commit b1b46b4
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 10 deletions.
35 changes: 25 additions & 10 deletions src/flowkit/_utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,18 @@ def plot_scatter(

if len(x) > 0:
x_min, x_max = _calculate_extent(x, d_min=x_min, d_max=x_max, pad=0.02)
else:
# empty array, set extents to 0 to avoid errors
x_min = x_max = 0

# turn off color density
color_density = False

if len(y) > 0:
y_min, y_max = _calculate_extent(y, d_min=y_min, d_max=y_max, pad=0.02)
else:
# empty array, set extents to 0 to avoid errors
y_min = y_max = 0

if y_max > x_max:
radius_dimension = 'y'
Expand Down Expand Up @@ -444,7 +454,11 @@ def plot_scatter(
# re-order the highlight indices to match
highlight_mask = highlight_mask[idx]

z_norm = (z - z.min()) / (z.max() - z.min())
# check if z max - z min is 0 (e.g. a single data point)
if z.max() - z.min() == 0:
z_norm = np.zeros(len(x))
else:
z_norm = (z - z.min()) / (z.max() - z.min())
else:
z_norm = np.zeros(len(x))

Expand Down Expand Up @@ -477,15 +491,16 @@ def plot_scatter(
p.xaxis.axis_label = x_label
p.yaxis.axis_label = y_label

p.circle(
x,
y,
radius=radius,
radius_dimension=radius_dimension,
fill_color=z_colors,
fill_alpha=fill_alpha,
line_color=None
)
if len(x) > 0:
p.circle(
x,
y,
radius=radius,
radius_dimension=radius_dimension,
fill_color=z_colors,
fill_alpha=fill_alpha,
line_color=None
)

return p

Expand Down
29 changes: 29 additions & 0 deletions tests/plot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
"""
import copy
import unittest

import bokeh.models
import numpy as np
from bokeh.plotting import figure as bk_Figure
from bokeh.layouts import GridPlot as bk_GridPlot
import flowkit as fk
Expand All @@ -21,6 +24,32 @@ class PlotTestCase(unittest.TestCase):
pixel-level, this TestCase only tests that plots are returned
from plotting functions.
"""
def test_plot_scatter_zero_points(self):
# from issue #197
arr = np.array([], float)
# noinspection PyProtectedMember
p = fk._utils.plot_utils.plot_scatter(arr, arr)

self.assertIsInstance(p, bk_Figure)

def test_plot_scatter_one_point(self):
# from issue #197
arr = np.array([1., ], float)
# noinspection PyProtectedMember
p = fk._utils.plot_utils.plot_scatter(arr, arr)

self.assertIsInstance(p, bk_Figure)

def test_plot_scatter_two_points_with_extents(self):
# from issue #197
# noinspection PyProtectedMember
p = fk._utils.plot_utils.plot_scatter(
np.array([0.44592386, 0.52033713]),
np.array([0.6131338, 0.60149982]),
x_min=0, x_max=.997, y_min=0, y_max=.991
)

self.assertIsInstance(p, bk_Figure)

def test_sample_plot_histogram(self):
sample = copy.deepcopy(test_sample)
Expand Down

0 comments on commit b1b46b4

Please sign in to comment.