Skip to content

Commit

Permalink
Codegen for CVODE (#1493)
Browse files Browse the repository at this point in the history
  • Loading branch information
JCGoran authored Oct 31, 2024
1 parent 800d098 commit 3df34b9
Show file tree
Hide file tree
Showing 9 changed files with 314 additions and 4 deletions.
5 changes: 5 additions & 0 deletions src/codegen/codegen_helper_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,11 @@ void CodegenHelperVisitor::visit_nrn_state_block(const ast::NrnStateBlock& node)
node.visit_children(*this);
}

void CodegenHelperVisitor::visit_cvode_block(const ast::CvodeBlock& node) {
info.cvode_block = &node;
node.visit_children(*this);
}


void CodegenHelperVisitor::visit_procedure_block(const ast::ProcedureBlock& node) {
info.procedures.push_back(&node);
Expand Down
1 change: 1 addition & 0 deletions src/codegen/codegen_helper_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class CodegenHelperVisitor: public visitor::ConstAstVisitor {
void visit_program(const ast::Program& node) override;
void visit_factor_def(const ast::FactorDef& node) override;
void visit_nrn_state_block(const ast::NrnStateBlock& node) override;
void visit_cvode_block(const ast::CvodeBlock& node) override;
void visit_linear_block(const ast::LinearBlock& node) override;
void visit_non_linear_block(const ast::NonLinearBlock& node) override;
void visit_discrete_block(const ast::DiscreteBlock& node) override;
Expand Down
3 changes: 3 additions & 0 deletions src/codegen/codegen_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,9 @@ struct CodegenInfo {
/// nrn_state block
const ast::NrnStateBlock* nrn_state_block = nullptr;

/// the CVODE block
const ast::CvodeBlock* cvode_block = nullptr;

/// net receive block for point process
const ast::NetReceiveBlock* net_receive_node = nullptr;

Expand Down
21 changes: 21 additions & 0 deletions src/codegen/codegen_naming.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,27 @@ static constexpr char THREAD_ARGS_PROTO[] = "_threadargsproto_";
/// prefix for ion variable
static constexpr char ION_VARNAME_PREFIX[] = "ion_";

/// name of CVODE method for counting # of ODEs
static constexpr char CVODE_COUNT_NAME[] = "ode_count";

/// name of CVODE method for updating non-stiff systems
static constexpr char CVODE_UPDATE_NON_STIFF_NAME[] = "ode_update_nonstiff";

/// name of CVODE method for updating stiff systems
static constexpr char CVODE_UPDATE_STIFF_NAME[] = "ode_update_stiff";

/// name of CVODE method for setting up non-stiff systems
static constexpr char CVODE_SETUP_NON_STIFF_NAME[] = "ode_setup_nonstiff";

/// name of CVODE method for setting up stiff systems
static constexpr char CVODE_SETUP_STIFF_NAME[] = "ode_setup_stiff";

/// name of CVODE method for setting up tolerances
static constexpr char CVODE_SETUP_TOLERANCES_NAME[] = "ode_setup_tolerances";

/// name of the CVODE variable (can be arbitrary)
static constexpr char CVODE_VARIABLE_NAME[] = "cvode_ieq";

/// commonly used variables in verbatim block and how they
/// should be mapped to new code generation backends
// clang-format off
Expand Down
156 changes: 152 additions & 4 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,13 @@ void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_in
printer->fmt_line("static Symbol* _{}_sym;", ion.name);
}

if (info.emit_cvode) {
printer->add_line("static Symbol** _atollist;");
printer->push_block("static HocStateTolerance _hoc_state_tol[] =");
printer->add_line("{0, 0}");
printer->pop_block(";");
}

printer->add_line("static int mech_type;");

if (info.point_process) {
Expand Down Expand Up @@ -1422,16 +1429,22 @@ void CodegenNeuronCppVisitor::print_mechanism_register_regular() {
i));
}

if (info.emit_cvode) {
mech_register_args.push_back(
fmt::format("_nrn_mechanism_field<int>{{\"{}\", \"cvodeieq\"}} /* {} */",
naming::CVODE_VARIABLE_NAME,
codegen_int_variables_size));
}

printer->add_multi_line(fmt::format("{}", fmt::join(mech_register_args, ",\n")));

printer->decrease_indent();
printer->add_line(");");
printer->add_newline();


printer->fmt_line("hoc_register_prop_size(mech_type, {}, {});",
float_variables_size(),
int_variables_size());
int_variables_size() + static_cast<int>(info.emit_cvode));

for (int i = 0; i < codegen_int_variables_size; ++i) {
if (i != info.semantics[i].index) {
Expand Down Expand Up @@ -1495,8 +1508,20 @@ void CodegenNeuronCppVisitor::print_mechanism_register_regular() {
printer->fmt_line("{}._morphology_sym = hoc_lookup(\"morphology\");",
global_struct_instance());
}

if (info.emit_cvode) {
printer->fmt_line("hoc_register_dparam_semantics(mech_type, {}, \"cvodeieq\");",
codegen_int_variables_size);
printer->fmt_line("hoc_register_cvode(mech_type, {}, {}, {}, {});",
method_name(naming::CVODE_COUNT_NAME),
method_name(naming::CVODE_SETUP_TOLERANCES_NAME),
method_name(naming::CVODE_SETUP_NON_STIFF_NAME),
method_name(naming::CVODE_SETUP_STIFF_NAME));
printer->fmt_line("hoc_register_tolerance(mech_type, _hoc_state_tol, &_atollist);");
}
}


void CodegenNeuronCppVisitor::print_mechanism_register_nothing() {
printer->add_line("hoc_register_var(hoc_scalar_double, hoc_vector_double, hoc_intfunc);");
}
Expand Down Expand Up @@ -1936,9 +1961,9 @@ void CodegenNeuronCppVisitor::print_nrn_alloc() {
)CODE");
printer->chain_block("else");
}
if (info.semantic_variable_count) {
if (info.semantic_variable_count || info.emit_cvode) {
printer->fmt_line("_ppvar = nrn_prop_datum_alloc(mech_type, {}, _prop);",
info.semantic_variable_count);
info.semantic_variable_count + static_cast<int>(info.emit_cvode));
printer->add_line("_nrn_mechanism_access_dparam(_prop) = _ppvar;");
}
printer->add_multi_line(R"CODE(
Expand Down Expand Up @@ -2384,6 +2409,8 @@ void CodegenNeuronCppVisitor::print_mechanism_variables_macros() {
if (info.table_count > 0) {
printer->add_line("void _nrn_thread_table_reg(int, nrn_thread_table_check_t);");
}
// for CVODE
printer->add_line("extern void _cvode_abstol(Symbol**, double*, int);");
if (info.for_netcon_used) {
printer->add_line("int _nrn_netcon_args(void*, double***);");
}
Expand Down Expand Up @@ -2457,6 +2484,7 @@ void CodegenNeuronCppVisitor::print_codegen_routines_regular() {
print_nrn_alloc();
print_function_prototypes();
print_longitudinal_diffusion_callbacks();
print_cvode_definitions();
print_point_process_function_definitions();
print_setdata_functions();
print_check_table_entrypoint();
Expand Down Expand Up @@ -2605,6 +2633,126 @@ void CodegenNeuronCppVisitor::print_net_receive_common_code() {
printer->add_line("double t = nt->_t;");
}

CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::cvode_setup_parameters() {
return {{"", "const _nrn_model_sorted_token&", "", "_sorted_token"},
{"", "NrnThread*", "", "nt"},
{"", "Memb_list*", "", "_ml_arg"},
{"", "int", "", "_type"}};
}

CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::cvode_update_parameters() {
ParamVector args = {{"", "_nrn_mechanism_cache_range&", "", "_lmc"},
{"", fmt::format("{}_Instance&", info.mod_suffix), "", "inst"},
{"", fmt::format("{}_NodeData&", info.mod_suffix), "", "node_data"},
{"", "size_t", "", "id"},
{"", "Datum*", "", "_ppvar"},
{"", "Datum*", "", "_thread"},
{"", "NrnThread*", "", "nt"}};

if (info.thread_callback_register) {
auto type_name = fmt::format("{}&", thread_variables_struct());
args.emplace_back("", type_name, "", "_thread_vars");
}
return args;
}


void CodegenNeuronCppVisitor::print_cvode_count() {
printer->fmt_push_block("static constexpr int {}(int _type)",
method_name(naming::CVODE_COUNT_NAME));
printer->fmt_line("return {};", info.cvode_block->get_n_odes()->get_value());
printer->pop_block();
printer->add_newline(2);
}

void CodegenNeuronCppVisitor::print_cvode_tolerances() {
const CodegenNeuronCppVisitor::ParamVector tolerances_parameters = {
{"", "Prop*", "", "_prop"},
{"", "int", "", "equation_index"},
{"", "neuron::container::data_handle<double>*", "", "_pv"},
{"", "neuron::container::data_handle<double>*", "", "_pvdot"},
{"", "double*", "", "_atol"},
{"", "int", "", "_type"}};

auto get_param_name = [](const auto& item) { return std::get<3>(item); };

auto prop_name = get_param_name(tolerances_parameters[0]);
auto eqindex_name = get_param_name(tolerances_parameters[1]);
auto pv_name = get_param_name(tolerances_parameters[2]);
auto pvdot_name = get_param_name(tolerances_parameters[3]);
auto atol_name = get_param_name(tolerances_parameters[4]);

printer->fmt_push_block("static void {}({})",
method_name(naming::CVODE_SETUP_TOLERANCES_NAME),
get_parameter_str(tolerances_parameters));
printer->fmt_line("auto* _ppvar = _nrn_mechanism_access_dparam({});", prop_name);
printer->fmt_line("_ppvar[{}].literal_value<int>() = {};", int_variables_size(), eqindex_name);
printer->fmt_push_block("for (int i = 0; i < {}(0); i++)",
method_name(naming::CVODE_COUNT_NAME));
printer->fmt_line("{}[i] = _nrn_mechanism_get_param_handle({}, _slist1[i]);",
pv_name,
prop_name);
printer->fmt_line("{}[i] = _nrn_mechanism_get_param_handle({}, _dlist1[i]);",
pvdot_name,
prop_name);
printer->fmt_line("_cvode_abstol(_atollist, {}, i);", atol_name);
printer->pop_block();
printer->pop_block();
printer->add_newline(2);
}

void CodegenNeuronCppVisitor::print_cvode_update(const std::string& name,
const ast::StatementBlock& block) {
printer->fmt_push_block("static int {}({})",
method_name(name),
get_parameter_str(cvode_update_parameters()));
printer->add_line(
"auto v = node_data.node_voltages ? "
"node_data.node_voltages[node_data.nodeindices[id]] : 0.0;");

print_statement_block(block, false, false);

printer->add_line("return 0;");
printer->pop_block();
printer->add_newline(2);
}

void CodegenNeuronCppVisitor::print_cvode_setup(const std::string& setup_name,
const std::string& update_name) {
printer->fmt_push_block("static void {}({})",
method_name(setup_name),
get_parameter_str(cvode_setup_parameters()));
print_entrypoint_setup_code_from_memb_list();
printer->fmt_line("auto nodecount = _ml_arg->nodecount;");
printer->push_block("for (int id = 0; id < nodecount; id++)");
printer->add_line("auto* _ppvar = _ml_arg->pdata[id];");
printer->add_line(
"auto v = node_data.node_voltages ? "
"node_data.node_voltages[node_data.nodeindices[id]] : 0.0;");
printer->fmt_line("{}({});", method_name(update_name), get_arg_str(cvode_update_parameters()));

printer->pop_block();
printer->pop_block();

printer->add_newline(2);
}

void CodegenNeuronCppVisitor::print_cvode_definitions() {
if (!info.emit_cvode) {
return;
}

printer->add_newline(2);
printer->add_line("/* Functions related to CVODE codegen */");
print_cvode_count();
print_cvode_tolerances();
print_cvode_update(naming::CVODE_UPDATE_NON_STIFF_NAME,
*info.cvode_block->get_non_stiff_block());
print_cvode_update(naming::CVODE_UPDATE_STIFF_NAME, *info.cvode_block->get_stiff_block());
print_cvode_setup(naming::CVODE_SETUP_NON_STIFF_NAME, naming::CVODE_UPDATE_NON_STIFF_NAME);
print_cvode_setup(naming::CVODE_SETUP_STIFF_NAME, naming::CVODE_UPDATE_STIFF_NAME);
}

void CodegenNeuronCppVisitor::print_net_receive() {
printing_net_receive = true;
auto node = info.net_receive_node;
Expand Down
42 changes: 42 additions & 0 deletions src/codegen/codegen_neuron_cpp_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,48 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
void print_ion_variable() override;


/**
* Get the parameters for functions that setup (initialize) CVODE
*
*/
ParamVector cvode_setup_parameters();


/**
* Get the parameters for functions that update state at given timestep in CVODE
*
*/
ParamVector cvode_update_parameters();


/**
* Print all callbacks for CVODE
*
*/
void print_cvode_definitions();

/**
* Print the CVODE function returning the number of ODEs to solve
*/
void print_cvode_count();

/**
* Print the CVODE function for setup of tolerances
*/
void print_cvode_tolerances();

/**
* Print the CVODE update function \c name contained in \c block
*/
void print_cvode_update(const std::string& name, const ast::StatementBlock& block);

/**
* Print the CVODE setup function \c setup_name that calls the CVODE update function
* \c update_name
*/
void print_cvode_setup(const std::string& setup_name, const std::string& update_name);


/****************************************************************************************/
/* Overloaded visitor routines */
/****************************************************************************************/
Expand Down
1 change: 1 addition & 0 deletions test/usecases/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set(NMODL_USECASE_DIRS
builtin_functions
constant
constructor
cvode
electrode_current
external
function
Expand Down
37 changes: 37 additions & 0 deletions test/usecases/cvode/derivative.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
NEURON {
SUFFIX scalar
}

PARAMETER {
freq = 10
a = 5
v1 = -1
v2 = 5
v3 = 15
v4 = 0.8
v5 = 0.3
r = 3
k = 0.2
}

STATE {var1 var2 var3}

INITIAL {
var1 = v1
var2 = v2
var3 = v3
}

BREAKPOINT {
SOLVE equation METHOD derivimplicit
}


DERIVATIVE equation {
: eq with a function on RHS
var1' = -sin(freq * t)
: simple ODE (nonzero Jacobian)
var2' = -var2 * a
: logistic ODE
var3' = r * var3 * (1 - var3 / k)
}
Loading

0 comments on commit 3df34b9

Please sign in to comment.