Skip to content

Commit

Permalink
Handle voltage via instance specific copy. (#1532)
Browse files Browse the repository at this point in the history
  • Loading branch information
1uc authored Oct 29, 2024
1 parent 4560899 commit 31a6c48
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,10 @@ std::string CodegenNeuronCppVisitor::global_variable_name(const SymbolType& symb

std::string CodegenNeuronCppVisitor::get_variable_name(const std::string& name,
bool use_instance) const {
const std::string& varname = update_if_ion_variable_name(name);
std::string varname = update_if_ion_variable_name(name);
if (!info.artificial_cell && varname == "v") {
varname = naming::VOLTAGE_UNUSED_VARIABLE;
}

auto name_comparator = [&varname](const auto& sym) { return varname == get_name(sym); };

Expand Down Expand Up @@ -956,9 +959,6 @@ void CodegenNeuronCppVisitor::print_sdlists_init(bool /* print_initializers */)

CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::functor_params() {
auto params = internal_method_parameters();
if (!info.artificial_cell) {
params.push_back({"", "double", "", "v"});
}

return params;
}
Expand Down Expand Up @@ -1822,7 +1822,7 @@ void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) {
printer->add_line("auto* _ppvar = _ml_arg->pdata[id];");
if (!info.artificial_cell) {
printer->add_line("int node_id = node_data.nodeindices[id];");
printer->add_line("auto v = node_data.node_voltages[node_id];");
printer->add_line("inst.v_unused[id] = node_data.node_voltages[node_id];");
}

print_rename_state_vars();
Expand Down Expand Up @@ -2069,7 +2069,9 @@ void CodegenNeuronCppVisitor::print_nrn_state() {
printer->push_block("for (int id = 0; id < nodecount; id++)");
printer->add_line("int node_id = node_data.nodeindices[id];");
printer->add_line("auto* _ppvar = _ml_arg->pdata[id];");
printer->add_line("auto v = node_data.node_voltages[node_id];");
if (!info.artificial_cell) {
printer->add_line("inst.v_unused[id] = node_data.node_voltages[node_id];");
}

/**
* \todo Eigen solver node also emits IonCurVar variable in the functor
Expand Down Expand Up @@ -2142,6 +2144,7 @@ void CodegenNeuronCppVisitor::print_nrn_current(const BreakpointBlock& node) {
printer->fmt_push_block("static inline double nrn_current_{}({})",
info.mod_suffix,
get_parameter_str(args));
printer->add_line("inst.v_unused[id] = v;");
printer->add_line("double current = 0.0;");
print_statement_block(*block, false, false);
for (auto& current: info.currents) {
Expand Down
18 changes: 18 additions & 0 deletions test/usecases/voltage/accessors.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
NEURON {
SUFFIX accessors
NONSPECIFIC_CURRENT il
}

ASSIGNED {
v
il
}

BREAKPOINT {
il = 0.003
}


FUNCTION get_voltage() {
get_voltage = v
}
17 changes: 17 additions & 0 deletions test/usecases/voltage/ode.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
NEURON {
SUFFIX ode
NONSPECIFIC_CURRENT il
}

ASSIGNED {
il
v
}

FUNCTION voltage() {
voltage = 0.001 * v
}

BREAKPOINT {
il = voltage()
}
31 changes: 31 additions & 0 deletions test/usecases/voltage/state_ode.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
NEURON {
SUFFIX state_ode
NONSPECIFIC_CURRENT il
}

STATE {
X
}

ASSIGNED {
il
v
}

INITIAL {
X = v
}

BREAKPOINT {
SOLVE eqn
il = 0.001 * X
}

NONLINEAR eqn { LOCAL c
c = rate()
~ X = c
}

FUNCTION rate() {
rate = v
}
55 changes: 55 additions & 0 deletions test/usecases/voltage/test_voltage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from neuron import h, gui

import numpy as np


def test_voltage_access():
s = h.Section()
s.insert("accessors")

h.finitialize()
v = s(0.5).v
vinst = s(0.5).accessors.get_voltage()
# The voltage will be consistent right after
# finitialize.
assert vinst == v

for _ in range(4):
v = s(0.5).v
h.fadvance()
vinst = s(0.5).accessors.get_voltage()

# During timestepping the internal copy
# of the voltage lags behind the current
# voltage by some timestep.
assert vinst == v, f"{vinst = }, {v = }, delta = {vinst - v}"


def check_ode(mech_name, step):
s = h.Section()
s.insert(mech_name)

h.finitialize()

c = -0.001 / 1e-3

for _ in range(4):
v_expected = step(s(0.5).v, c)
h.fadvance()
np.testing.assert_approx_equal(s(0.5).v, v_expected, significant=10)


def test_breakpoint():
# Results in backward Euler.
check_ode("ode", lambda v, c: (1.0 - c * h.dt) ** (-1.0) * v)


def test_state():
# Effectively, the timing when states are computed results in backward Euler.
check_ode("state_ode", lambda v, c: (1.0 + c * h.dt) * v)


if __name__ == "__main__":
test_voltage_access()
test_breakpoint()
test_state()

0 comments on commit 31a6c48

Please sign in to comment.