Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get voltage via prop->node. #1414

Merged
merged 8 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 51 additions & 7 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@
printer->push_block();
printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml, _type};");
printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix);
if (!info.artificial_cell) {
printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml);", info.mod_suffix);
}
if (!codegen_thread_variables.empty()) {
printer->fmt_line("auto _thread_vars = {}(_thread[{}].get<double*>());",
thread_variables_struct(),
Expand Down Expand Up @@ -251,7 +254,9 @@
printer->fmt_line("int ret_{} = 0;", name);
}

printer->fmt_line("auto v = inst.{}[id];", naming::VOLTAGE_UNUSED_VARIABLE);
if (!info.artificial_cell) {
printer->add_line("auto v = node_data.node_voltages[node_data.nodeindices[id]];");
}

print_statement_block(*node.get_statement_block(), false, false);
printer->fmt_line("return ret_{};", name);
Expand Down Expand Up @@ -294,6 +299,8 @@
Datum* _thread;
NrnThread* nt;
)CODE");

std::string prop_name;
if (info.point_process) {
printer->add_multi_line(R"CODE(
auto* const _pnt = static_cast<Point_process*>(_vptr);
Expand All @@ -307,6 +314,8 @@
_thread = _extcall_thread.data();
nt = static_cast<NrnThread*>(_pnt->_vnt);
)CODE");

prop_name = "_p";
} else if (wrapper_type == InterpreterWrapper::HOC) {
if (program_symtab->lookup(block_name)->has_all_properties(NmodlType::use_range_ptr_var)) {
printer->push_block("if (!_prop_id)");
Expand All @@ -328,16 +337,22 @@
_thread = _extcall_thread.data();
nt = nrn_threads;
)CODE");
prop_name = "_local_prop";
} else { // wrapper_type == InterpreterWrapper::Python
printer->add_multi_line(R"CODE(
_nrn_mechanism_cache_instance _lmc{_prop};
size_t const id{};
size_t const id = 0;
_ppvar = _nrn_mechanism_access_dparam(_prop);
_thread = _extcall_thread.data();
nt = nrn_threads;
)CODE");
prop_name = "_prop";
}

printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix);
if (!info.artificial_cell) {
printer->fmt_line("auto node_data = make_node_data_{}({});", info.mod_suffix, prop_name);
}
if (!codegen_thread_variables.empty()) {
printer->fmt_line("auto _thread_vars = {}(_thread[{}].get<double*>());",
thread_variables_struct(),
Expand Down Expand Up @@ -415,6 +430,9 @@
ParamVector params;
params.emplace_back("", "_nrn_mechanism_cache_range&", "", "_lmc");
params.emplace_back("", fmt::format("{}&", instance_struct()), "", "inst");
if (!info.artificial_cell) {
params.emplace_back("", fmt::format("{}&", node_data_struct()), "", "node_data");
}
params.emplace_back("", "size_t", "", "id");
params.emplace_back("", "Datum*", "", "_ppvar");
params.emplace_back("", "Datum*", "", "_thread");
Expand Down Expand Up @@ -804,10 +822,6 @@
// TODO implement these when needed.
}

if (!info.vectorize && !info.top_local_variables.empty()) {
throw std::runtime_error("Not implemented, global vectorize something.");
}

if (!info.thread_variables.empty()) {
size_t prefix_sum = 0;
for (size_t i = 0; i < info.thread_variables.size(); ++i) {
Expand All @@ -834,6 +848,14 @@
}
}

if (!info.vectorize && !info.top_local_variables.empty()) {
for (size_t i = 0; i < info.top_local_variables.size(); ++i) {
const auto& var = info.top_local_variables[i];
codegen_global_variables.push_back(var);
}
}


if (!codegen_thread_variables.empty()) {
if (!info.vectorize) {
// MOD files that aren't "VECTORIZED" don't have thread data.
Expand Down Expand Up @@ -1266,7 +1288,7 @@

void CodegenNeuronCppVisitor::print_mechanism_range_var_structure(bool print_initializers) {
auto const value_initialize = print_initializers ? "{}" : "";
auto int_type = default_int_data_type();

Check warning on line 1291 in src/codegen/codegen_neuron_cpp_visitor.cpp

View workflow job for this annotation

GitHub Actions / { "flag_warnings": "ON", "glibc_asserts": "ON", "os": "ubuntu-22.04" }

unused variable ‘int_type’ [-Wunused-variable]

Check warning on line 1291 in src/codegen/codegen_neuron_cpp_visitor.cpp

View workflow job for this annotation

GitHub Actions / { "flag_warnings": "ON", "os": "ubuntu-22.04", "sanitizer": "undefined" }

unused variable 'int_type' [-Wunused-variable]
printer->add_newline(2);
printer->add_line("/** all mechanism instance variables and global variables */");
printer->fmt_push_block("struct {} ", instance_struct());
Expand Down Expand Up @@ -1385,6 +1407,26 @@

printer->pop_block(";");
printer->pop_block();


printer->fmt_push_block("static {} make_node_data_{}(Prop * _prop)",
node_data_struct(),
info.mod_suffix);
printer->add_line("static std::vector<int> node_index{0};");
printer->add_line("Node* _node = _nrn_mechanism_access_node(_prop);");

make_node_data_args = {"node_index.data()",
"&_nrn_mechanism_access_voltage(_node)",
"&_nrn_mechanism_access_d(_node)",
"&_nrn_mechanism_access_rhs(_node)",
"1"};

printer->fmt_push_block("return {}", node_data_struct());
printer->add_multi_line(fmt::format("{}", fmt::join(make_node_data_args, ",\n")));

printer->pop_block(";");
printer->pop_block();
printer->add_newline();
}

void CodegenNeuronCppVisitor::print_thread_variables_structure(bool print_initializers) {
Expand Down Expand Up @@ -1475,7 +1517,6 @@
if (!info.artificial_cell) {
printer->add_line("int node_id = node_data.nodeindices[id];");
printer->add_line("auto v = node_data.node_voltages[node_id];");
printer->fmt_line("inst.{}[id] = v;", naming::VOLTAGE_UNUSED_VARIABLE);
}

print_rename_state_vars();
Expand Down Expand Up @@ -1573,7 +1614,7 @@
}
const auto& var_name = var->get_name();
auto var_pos = position_of_float_var(var_name);
double var_value = var->get_value() == nullptr ? 0.0 : *var->get_value();

Check warning on line 1617 in src/codegen/codegen_neuron_cpp_visitor.cpp

View workflow job for this annotation

GitHub Actions / { "flag_warnings": "ON", "glibc_asserts": "ON", "os": "ubuntu-22.04" }

unused variable ‘var_value’ [-Wunused-variable]

Check warning on line 1617 in src/codegen/codegen_neuron_cpp_visitor.cpp

View workflow job for this annotation

GitHub Actions / { "flag_warnings": "ON", "os": "ubuntu-22.04", "sanitizer": "undefined" }

unused variable 'var_value' [-Wunused-variable]

printer->fmt_line("_lmc.template fpfield<{}>(_iml) = {}; /* {} */",
var_pos,
Expand Down Expand Up @@ -2164,6 +2205,9 @@
printer->add_line("auto * _ppvar = _nrn_mechanism_access_dparam(_pnt->prop);");

printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix);
if (!info.artificial_cell) {
printer->fmt_line("auto node_data = make_node_data_{}(_pnt->prop);", info.mod_suffix);
}
printer->fmt_line("// nocmodl has a nullptr dereference for thread variables.");
printer->fmt_line("// NMODL will fail to compile at a later point, because of");
printer->fmt_line("// missing '_thread_vars'.");
Expand Down
34 changes: 34 additions & 0 deletions test/usecases/function/artificial_functions.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
: Test fixture: an ARTIFICIAL_CELL exposing FUNCTIONs that read RANGE and
: GLOBAL variables. Artificial cells are not attached to a node, so no
: membrane voltage is available; codegen must not emit node-data access
: for this mechanism (note: no function here reads `v`).
NEURON {
ARTIFICIAL_CELL art_functions
RANGE x
GLOBAL gbl
}

ASSIGNED {
gbl
v
x
}

: Returns x + a; exercises reading the RANGE variable x.
FUNCTION x_plus_a(a) {
x_plus_a = x + a
}

: The parameter is deliberately named `v`: it shadows the ASSIGNED v,
: so this must compile without any voltage lookup.
FUNCTION identity(v) {
identity = v
}

INITIAL {
x = 1.0
gbl = 42.0
}

: A LINEAR block makes a MOD file not VECTORIZED.
STATE {
z
}

LINEAR lin {
~ z = 2
}

38 changes: 38 additions & 0 deletions test/usecases/function/non_threadsafe.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
: Test fixture: a non-threadsafe (non-VECTORIZED, see LINEAR block below)
: density mechanism whose FUNCTIONs read RANGE, GLOBAL and the membrane
: potential `v`; exercises voltage access via prop->node in generated
: HOC/Python wrappers.
NEURON {
SUFFIX non_threadsafe
RANGE x
GLOBAL gbl
}

ASSIGNED {
gbl
v
x
}

: Returns x + a; exercises reading the RANGE variable x.
FUNCTION x_plus_a(a) {
x_plus_a = x + a
}

: Reads the membrane potential `v` of the node this instance sits on.
FUNCTION v_plus_a(a) {
v_plus_a = v + a
}

: The parameter is deliberately named `v` and shadows the voltage.
FUNCTION identity(v) {
identity = v
}

INITIAL {
x = 1.0
gbl = 42.0
}

: A LINEAR block makes a MOD file not VECTORIZED.
STATE {
z
}

LINEAR lin {
~ z = 2
}

38 changes: 38 additions & 0 deletions test/usecases/function/point_non_threadsafe.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
: Test fixture: the POINT_PROCESS variant of non_threadsafe.mod.
: Non-VECTORIZED (see LINEAR block below); its FUNCTIONs read RANGE,
: GLOBAL and the membrane potential `v`, exercising voltage access via
: the point process' Prop in generated HOC/Python wrappers.
NEURON {
POINT_PROCESS point_non_threadsafe
RANGE x
GLOBAL gbl
}

ASSIGNED {
gbl
v
x
}

: Returns x + a; exercises reading the RANGE variable x.
FUNCTION x_plus_a(a) {
x_plus_a = x + a
}

: Reads the membrane potential `v` of the node this instance is attached to.
FUNCTION v_plus_a(a) {
v_plus_a = v + a
}

: The parameter is deliberately named `v` and shadows the voltage.
FUNCTION identity(v) {
identity = v
}

INITIAL {
x = 1.0
gbl = 42.0
}

: A LINEAR block makes a MOD file not VECTORIZED.
STATE {
z
}

LINEAR lin {
~ z = 2
}

22 changes: 15 additions & 7 deletions test/usecases/function/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from neuron import h


def check_functions(get_instance):
def check_callable(get_instance, has_voltage=True):
for x, value in zip(coords, values):
get_instance(x).x = value

Expand All @@ -19,22 +19,30 @@ def check_functions(get_instance):
actual = get_instance(x).identity(expected)
assert actual == expected, f"{actual} == {expected}"

# Check `f` using `v`.
expected = -2.0
actual = get_instance(x).v_plus_a(40.0)
assert actual == expected, f"{actual} == {expected}"
if has_voltage:
# Check `f` using `v`.
expected = -2.0
actual = get_instance(x).v_plus_a(40.0)
assert actual == expected, f"{actual} == {expected}"


nseg = 5
s = h.Section()
s.nseg = nseg

s.insert("functions")
s.insert("non_threadsafe")

coords = [(0.5 + k) * 1.0 / nseg for k in range(nseg)]
values = [0.1 + k for k in range(nseg)]

point_processes = {x: h.point_functions(s(x)) for x in coords}
point_non_threadsafe = {x: h.point_non_threadsafe(s(x)) for x in coords}

art_cells = {x: h.art_functions() for x in coords}

check_functions(lambda x: s(x).functions)
check_functions(lambda x: point_processes[x])
check_callable(lambda x: s(x).functions)
check_callable(lambda x: s(x).non_threadsafe)
check_callable(lambda x: point_processes[x])
check_callable(lambda x: point_non_threadsafe[x])
check_callable(lambda x: art_cells[x], has_voltage=False)
43 changes: 43 additions & 0 deletions test/usecases/global/non_threadsafe.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
: Test fixture: a non-VECTORIZED mechanism exercising the three kinds of
: non-instance variables — GLOBAL, top-level LOCAL, and PARAMETER — each
: readable through a FUNCTION so Python tests can observe their values.
NEURON {
SUFFIX non_threadsafe
GLOBAL gbl
}

: Top-level LOCAL: file-scope state, not exposed to HOC/Python.
LOCAL top_local

PARAMETER {
parameter = 41.0
}

ASSIGNED {
gbl
}

FUNCTION get_gbl() {
get_gbl = gbl
}

FUNCTION get_top_local() {
get_top_local = top_local
}

FUNCTION get_parameter() {
get_parameter = parameter
}

INITIAL {
gbl = 42.0
top_local = 43.0
}

: A LINEAR block makes the MOD file not thread-safe and not
: vectorized. We don't otherwise care about anything below
: this comment.
STATE {
z
}

LINEAR lin {
~ z = 2
}

33 changes: 33 additions & 0 deletions test/usecases/global/test_non_threadsafe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np

from neuron import h, gui
from neuron.units import ms


def test_non_threadsafe():
    """Check GLOBAL/PARAMETER/top-LOCAL access in a non-vectorized mechanism.

    Inserts the ``non_threadsafe`` mechanism, verifies the values set in
    its INITIAL block, then reassigns the HOC-visible globals and verifies
    the getters observe the new values.
    """
    nseg = 1

    s = h.Section()
    s.insert("non_threadsafe")
    s.nseg = nseg

    h.finitialize()

    instance = s(0.5).non_threadsafe

    # Check INITIAL values.
    assert instance.get_parameter() == 41.0
    assert instance.get_gbl() == 42.0
    assert instance.get_top_local() == 43.0

    # Check reassigning a value. Top LOCAL variables
    # are not exposed to HOC/Python, so only PARAMETER
    # and GLOBAL are reassigned here.
    h.parameter_non_threadsafe = 32.1
    h.gbl_non_threadsafe = 33.2

    assert instance.get_parameter() == 32.1
    assert instance.get_gbl() == 33.2


if __name__ == "__main__":
    test_non_threadsafe()
Loading