diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index ae15e0428b..5c8a2dc9ed 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -92,6 +92,7 @@ jobs: run: | python -m pip install --upgrade pip -r nrn_requirements.txt python -m pip install --upgrade -r external/nmodl/requirements.txt + python -m pip install --upgrade -r ci_requirements.txt - name: Set up Python@${{ env.PY_MID_VERSION }} uses: actions/setup-python@v4 @@ -108,6 +109,8 @@ jobs: run: | python -m pip install --upgrade pip -r nrn_requirements.txt python -m pip install --upgrade -r external/nmodl/requirements.txt + python -m pip install --upgrade -r ci_requirements.txt + - name: Build & Test id: build-test diff --git a/.github/workflows/neuron-ci.yml b/.github/workflows/neuron-ci.yml index 54c3777cea..ecc62ed386 100644 --- a/.github/workflows/neuron-ci.yml +++ b/.github/workflows/neuron-ci.yml @@ -144,6 +144,7 @@ jobs: run: | python -m pip install --upgrade pip -r nrn_requirements.txt python -m pip install --upgrade -r external/nmodl/requirements.txt + python -m pip install --upgrade -r ci_requirements.txt - name: Set up Python@${{ env.PY_MAX_VERSION }} uses: actions/setup-python@v4 @@ -155,6 +156,7 @@ jobs: run: | python -m pip install --upgrade pip -r nrn_requirements.txt python -m pip install --upgrade -r external/nmodl/requirements.txt + python -m pip install --upgrade -r ci_requirements.txt - name: Setup MUSIC@${{ env.MUSIC_VERSION }} if: matrix.config.music == 'ON' diff --git a/ci_requirements.txt b/ci_requirements.txt new file mode 100644 index 0000000000..f6bd81c65b --- /dev/null +++ b/ci_requirements.txt @@ -0,0 +1,2 @@ +plotly +ipywidgets>=7.0.0 diff --git a/share/lib/python/neuron/__init__.py b/share/lib/python/neuron/__init__.py index 25bb4b65fd..df12031798 100644 --- a/share/lib/python/neuron/__init__.py +++ b/share/lib/python/neuron/__init__.py @@ -1021,6 +1021,7 @@ def __call__(self, graph, *args, **kwargs): def _get_pyplot_axis3d(fig): """requires matplotlib""" + from . import rxd from matplotlib.pyplot import cm import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D @@ -1087,6 +1088,11 @@ def _do_plot( lines = {} lines_list = [] vals = [] + + if isinstance(variable, rxd.species.Species): + if len(variable.regions) > 1: + raise Exception("Please specify region for the species.") + for sec in sections: all_seg_pts = _segment_3d_pts(sec) for seg, (xs, ys, zs, _, _) in zip(sec, all_seg_pts): @@ -1227,6 +1233,7 @@ def color_to_hex(col): def _do_plot_on_plotly(width=2, color=None, cmap=None): """requires matplotlib for colormaps if not specified explicitly""" import ctypes + from . import rxd import plotly.graph_objects as go class FigureWidgetWithNEURON(go.FigureWidget): @@ -1307,6 +1314,11 @@ def mark(self, segment, marker="or", **kwargs): val_range = hi - lo data = [] + + if isinstance(variable, rxd.species.Species): + if len(variable.regions) > 1: + raise Exception("Please specify region for the species.") + for sec in secs: all_seg_pts = _segment_3d_pts(sec) for seg, (xs, ys, zs, _, _) in zip(sec, all_seg_pts): diff --git a/test/rxd/test_pltvar.py b/test/rxd/test_pltvar.py new file mode 100644 index 0000000000..99e3a80791 --- /dev/null +++ b/test/rxd/test_pltvar.py @@ -0,0 +1,67 @@ +import pytest +import plotly +from neuron import units +from matplotlib import pyplot + + +def test_plt_variable(neuron_instance): + """Test to make sure species with multiple regions is not plotted""" + + h, rxd, _, _ = neuron_instance + + dend1 = h.Section("dend1") + dend2 = h.Section("dend2") + dend2.connect(dend1(1)) + + dend1.nseg = dend1.L = dend2.nseg = dend2.L = 11 + dend1.diam = dend2.diam = 2 * units.µm + + cyt = rxd.Region(dend1.wholetree(), nrn_region="i") + cyt2 = rxd.Region(dend2.wholetree(), nrn_region="i") + + ca = rxd.Species( + [cyt, cyt2], + name="ca", + charge=2, + initial=0 * units.mM, + d=1 * units.µm**2 / units.ms, + ) + + ca.nodes(dend1(0.5))[0].include_flux(1e-13, units="mmol/ms") + + h.finitialize(-65 * units.mV) + h.fadvance() + + ps = h.PlotShape(False) + + # Expecting an error for matplotlib + with pytest.raises(Exception, match="Please specify region for the species."): + ps.variable(ca) + ps.plot(pyplot) + + # Expecting an error for plotly + with pytest.raises(Exception, match="Please specify region for the species."): + ps.variable(ca) + ps.plot(plotly) + + cb = rxd.Species( + [cyt], + name="cb", + charge=2, + initial=0 * units.mM, + d=1 * units.µm**2 / units.ms, + ) + + # Scenarios that should work + ps.variable(ca[cyt]) + ps.plot(plotly) # No error expected here + + ps.variable(ca[cyt]) + ps.plot(pyplot) # No error expected here + + # Test plotting with only one region + ps.variable(cb) + ps.plot(plotly) # No Error expected here + + ps.variable(cb) + ps.plot(pyplot) # No Error expected here