Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update __init__.py to catch attempts to plot species instead of speci… #2468

Merged
merged 13 commits into from
Oct 20, 2023
3 changes: 3 additions & 0 deletions .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ jobs:
run: |
python -m pip install --upgrade pip -r nrn_requirements.txt
python -m pip install --upgrade -r external/nmodl/requirements.txt
python -m pip install --upgrade -r ci_requirements.txt

- name: Set up Python@${{ env.PY_MID_VERSION }}
uses: actions/setup-python@v4
Expand All @@ -108,6 +109,8 @@ jobs:
run: |
python -m pip install --upgrade pip -r nrn_requirements.txt
python -m pip install --upgrade -r external/nmodl/requirements.txt
python -m pip install --upgrade -r ci_requirements.txt


- name: Build & Test
id: build-test
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/neuron-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ jobs:
run: |
python -m pip install --upgrade pip -r nrn_requirements.txt
python -m pip install --upgrade -r external/nmodl/requirements.txt
python -m pip install --upgrade -r ci_requirements.txt

- name: Set up Python@${{ env.PY_MAX_VERSION }}
uses: actions/setup-python@v4
Expand All @@ -155,6 +156,7 @@ jobs:
run: |
python -m pip install --upgrade pip -r nrn_requirements.txt
python -m pip install --upgrade -r external/nmodl/requirements.txt
python -m pip install --upgrade -r ci_requirements.txt

- name: Setup MUSIC@${{ env.MUSIC_VERSION }}
if: matrix.config.music == 'ON'
Expand Down
2 changes: 2 additions & 0 deletions ci_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
plotly
ipywidgets>=7.0.0
12 changes: 12 additions & 0 deletions share/lib/python/neuron/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,7 @@ def __call__(self, graph, *args, **kwargs):

def _get_pyplot_axis3d(fig):
"""requires matplotlib"""
from . import rxd
from matplotlib.pyplot import cm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
Expand Down Expand Up @@ -1087,6 +1088,11 @@ def _do_plot(
lines = {}
lines_list = []
vals = []

if isinstance(variable, rxd.species.Species):
if len(variable.regions) > 1:
raise Exception("Please specify region for the species.")

for sec in sections:
all_seg_pts = _segment_3d_pts(sec)
for seg, (xs, ys, zs, _, _) in zip(sec, all_seg_pts):
Expand Down Expand Up @@ -1227,6 +1233,7 @@ def color_to_hex(col):
def _do_plot_on_plotly(width=2, color=None, cmap=None):
"""requires matplotlib for colormaps if not specified explicitly"""
import ctypes
from . import rxd
import plotly.graph_objects as go

class FigureWidgetWithNEURON(go.FigureWidget):
Expand Down Expand Up @@ -1307,6 +1314,11 @@ def mark(self, segment, marker="or", **kwargs):
val_range = hi - lo

data = []

if isinstance(variable, rxd.species.Species):
if len(variable.regions) > 1:
raise Exception("Please specify region for the species.")

for sec in secs:
all_seg_pts = _segment_3d_pts(sec)
for seg, (xs, ys, zs, _, _) in zip(sec, all_seg_pts):
Expand Down
67 changes: 67 additions & 0 deletions test/rxd/test_pltvar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import pytest
import plotly
from neuron import units
from matplotlib import pyplot


def test_plt_variable(neuron_instance):
"""Test to make sure species with multiple regions is not plotted"""

h, rxd, _, _ = neuron_instance

dend1 = h.Section("dend1")
dend2 = h.Section("dend2")
dend2.connect(dend1(1))

dend1.nseg = dend1.L = dend2.nseg = dend2.L = 11
dend1.diam = dend2.diam = 2 * units.µm

cyt = rxd.Region(dend1.wholetree(), nrn_region="i")
cyt2 = rxd.Region(dend2.wholetree(), nrn_region="i")

ca = rxd.Species(
[cyt, cyt2],
name="ca",
charge=2,
initial=0 * units.mM,
d=1 * units.µm**2 / units.ms,
)

ca.nodes(dend1(0.5))[0].include_flux(1e-13, units="mmol/ms")

h.finitialize(-65 * units.mV)
h.fadvance()

ps = h.PlotShape(False)

# Expecting an error for matplotlib
with pytest.raises(Exception, match="Please specify region for the species."):
ps.variable(ca)
ps.plot(pyplot)

rgourdine marked this conversation as resolved.
Show resolved Hide resolved
# Expecting an error for plotly
with pytest.raises(Exception, match="Please specify region for the species."):
ps.variable(ca)
ps.plot(plotly)

cb = rxd.Species(
[cyt],
name="cb",
charge=2,
initial=0 * units.mM,
d=1 * units.µm**2 / units.ms,
)

# Scenarios that should work
ps.variable(ca[cyt])
ps.plot(plotly) # No error expected here

ps.variable(ca[cyt])
ps.plot(pyplot) # No error expected here

# Test plotting with only one region
ps.variable(cb)
ps.plot(plotly) # No Error expected here

ps.variable(cb)
ps.plot(pyplot) # No Error expected here
Loading