From 417acac77a6c63f706e8e942a410cad7d7b58362 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 24 Sep 2024 18:03:56 +0200 Subject: [PATCH 01/74] Add `DerivativeOriginalFunctionBlock` and `DerivativeVisitor` --- src/language/code_generator.cmake | 1 + src/language/codegen.yaml | 25 +++- src/main.cpp | 8 ++ src/visitors/CMakeLists.txt | 1 + src/visitors/derivative_original_visitor.cpp | 129 +++++++++++++++++++ src/visitors/derivative_original_visitor.hpp | 64 +++++++++ src/visitors/sympy_solver_visitor.cpp | 4 + src/visitors/sympy_solver_visitor.hpp | 2 + 8 files changed, 233 insertions(+), 1 deletion(-) create mode 100644 src/visitors/derivative_original_visitor.cpp create mode 100644 src/visitors/derivative_original_visitor.hpp diff --git a/src/language/code_generator.cmake b/src/language/code_generator.cmake index a3dea0767f..992d5b0cb1 100644 --- a/src/language/code_generator.cmake +++ b/src/language/code_generator.cmake @@ -74,6 +74,7 @@ set(AST_GENERATED_SOURCES ${PROJECT_BINARY_DIR}/src/ast/constructor_block.hpp ${PROJECT_BINARY_DIR}/src/ast/define.hpp ${PROJECT_BINARY_DIR}/src/ast/derivative_block.hpp + ${PROJECT_BINARY_DIR}/src/ast/derivative_original_function_block.hpp ${PROJECT_BINARY_DIR}/src/ast/derivimplicit_callback.hpp ${PROJECT_BINARY_DIR}/src/ast/destructor_block.hpp ${PROJECT_BINARY_DIR}/src/ast/diff_eq_expression.hpp diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index 477df7fa65..ac92afb517 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -87,7 +87,30 @@ type: StatementBlock - finalize_block: brief: "Statement block to be executed after calling linear solver" - type: StatementBlock + type: StatementBlock + - DerivativeOriginalFunctionBlock: + nmodl: "DERIVATIVE_ORIGINAL_FUNCTION " + members: + - name: + brief: "Name of the derivative block" + type: Name + node_name: true + suffix: {value: " "} + - statement_block: + brief: "Block with statements vector" + type: StatementBlock + getter: {override: true} + brief: "Represents the original, unmodified `DERIVATIVE` block in the NMODL" + description: | + The original `DERIVATIVE` block in NMODL is + replaced in-place if the system of ODEs is + solvable analytically. Therefore, this + block's sole purpose is to keep the + original, unsolved block in the AST. This is + primarily useful when we need to solve the + ODE system using implicit methods, for + instance, CVode. + - WrappedExpression: brief: "Wrap any other expression type" members: diff --git a/src/main.cpp b/src/main.cpp index f12bfe35dd..f150753479 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -25,6 +25,7 @@ #include "visitors/after_cvode_to_cnexp_visitor.hpp" #include "visitors/ast_visitor.hpp" #include "visitors/constant_folder_visitor.hpp" +#include "visitors/derivative_original_visitor.hpp" #include "visitors/function_callpath_visitor.hpp" #include "visitors/global_var_visitor.hpp" #include "visitors/implicit_argument_visitor.hpp" @@ -497,6 +498,13 @@ int run_nmodl(int argc, const char* argv[]) { const bool sympy_linear = node_exists(*ast, ast::AstNodeType::LINEAR_BLOCK); const bool sympy_sparse = solver_exists(*ast, "sparse"); + if (neuron_code) { + logger->info("Running derivative visitor"); + DerivativeOriginalVisitor().visit_program(*ast); + SymtabVisitor(update_symtab).visit_program(*ast); + ast_to_nmodl(*ast, filepath("derivative_original")); + } + if (sympy_conductance || sympy_analytic || sympy_sparse || sympy_derivimplicit || sympy_linear) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() diff --git a/src/visitors/CMakeLists.txt b/src/visitors/CMakeLists.txt index 262b6a623a..ede77671eb 100644 --- a/src/visitors/CMakeLists.txt +++ b/src/visitors/CMakeLists.txt @@ -11,6 +11,7 @@ add_library( visitor STATIC after_cvode_to_cnexp_visitor.cpp constant_folder_visitor.cpp + derivative_original_visitor.cpp defuse_analyze_visitor.cpp function_callpath_visitor.cpp global_var_visitor.cpp diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp new file mode 100644 index 0000000000..bd6d135350 --- /dev/null +++ b/src/visitors/derivative_original_visitor.cpp @@ -0,0 +1,129 @@ +/* + * Copyright 2023 Blue Brain Project, EPFL. + * See the top-level LICENSE file for details. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "visitors/derivative_original_visitor.hpp" + +#include "ast/all.hpp" +#include "lexer/token_mapping.hpp" +#include "pybind/pyembed.hpp" +#include "utils/logger.hpp" +#include "visitors/visitor_utils.hpp" +#include +#include + +namespace pywrap = nmodl::pybind_wrappers; + +namespace nmodl { +namespace visitor { + +static int get_index(const ast::IndexedName& node) { + return std::stoi(to_nmodl(node.get_length())); +} + +static auto get_name_map(const ast::Expression& node, const std::string& name) { + std::unordered_map name_map; + // all of the "reserved" symbols + auto reserved_symbols = get_external_functions(); + // all indexed vars + auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME}); + for (const auto& var: indexed_vars) { + if (!name_map.count(var->get_node_name()) && var->get_node_name() != name && + std::none_of(reserved_symbols.begin(), reserved_symbols.end(), [&var](const auto item) { + return var->get_node_name() == item; + })) { + logger->debug( + "DerivativeOriginalVisitor :: adding INDEXED_VARIABLE {} to " + "node_map", + var->get_node_name()); + name_map[var->get_node_name()] = get_index( + *std::dynamic_pointer_cast(var)); + } + } + return name_map; +} + + +void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) { + node.visit_children(*this); + der_block_function = node.clone(); +} + + +void DerivativeOriginalVisitor::visit_derivative_original_function_block( + ast::DerivativeOriginalFunctionBlock& node) { + derivative_block = true; + node_type = node.get_node_type(); + node.visit_children(*this); + node_type = ast::AstNodeType::NODE; + derivative_block = false; +} + +void DerivativeOriginalVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { + differential_equation = true; + node.visit_children(*this); + differential_equation = false; +} + + +void DerivativeOriginalVisitor::visit_binary_expression(ast::BinaryExpression& node) { + const auto& lhs = node.get_lhs(); + + /// we have to only solve ODEs under original derivative block where lhs is variable + if (!derivative_block || !differential_equation || !lhs->is_var_name()) { + return; + } + + auto name = std::dynamic_pointer_cast(lhs)->get_name(); + + if (name->is_prime_name() || name->is_indexed_name()) { + std::string varname; + if (name->is_prime_name()) { + varname = "D" + name->get_node_name(); + logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", + name->get_node_name(), + varname, + to_nmodl(node)); + node.set_lhs(std::make_shared(new ast::String(varname))); + if (program_symtab->lookup(varname) == nullptr) { + auto symbol = std::make_shared(varname, ModToken()); + symbol->set_original_name(name->get_node_name()); + program_symtab->insert(symbol); + } + } else { + varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\''); + // we discard the RHS here so it can be anything (as long as NMODL considers it valid) + auto statement = fmt::format("{} = {}", varname, varname); + logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", + to_nmodl(node.get_lhs()), + varname, + to_nmodl(node)); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); + // TODO add symbol? + } + } +} + +void DerivativeOriginalVisitor::visit_program(ast::Program& node) { + program_symtab = node.get_symbol_table(); + node.visit_children(*this); + if (der_block_function) { + auto der_node = + new ast::DerivativeOriginalFunctionBlock(der_block_function->get_name(), + der_block_function->get_statement_block()); + node.emplace_back_node(der_node); + } + + // re-visit the AST since we now inserted the DERIVATIVE_ORIGINAL block + node.visit_children(*this); +} + +} // namespace visitor +} // namespace nmodl diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/derivative_original_visitor.hpp new file mode 100644 index 0000000000..d483ab845b --- /dev/null +++ b/src/visitors/derivative_original_visitor.hpp @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Blue Brain Project, EPFL. + * See the top-level LICENSE file for details. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +/** + * \file + * \brief \copybrief nmodl::visitor::DerivativeOriginalVisitor + */ + +#include "symtab/decl.hpp" +#include "visitors/ast_visitor.hpp" +#include + +namespace nmodl { +namespace visitor { + +/** + * \addtogroup visitor_classes + * \{ + */ + +/** + * \class DerivativeOriginalVisitor + * \brief Make a copy of the `DERIVATIVE` block (if it exists), and insert back as + * `DERIVATIVE_ORIGINAL_FUNCTION` block. + * + * If \ref SympySolverVisitor runs successfully, it replaces the original + * solution. This block is inserted before that to prevent losing access to + * information about the block. + */ +class DerivativeOriginalVisitor: public AstVisitor { + private: + /// The copy of the derivative block we are solving + ast::DerivativeBlock* der_block_function = nullptr; + + /// true while visiting differential equation + bool differential_equation = false; + + /// global symbol table + symtab::SymbolTable* program_symtab = nullptr; + + /// visiting derivative block + bool derivative_block = false; + + ast::AstNodeType node_type = ast::AstNodeType::NODE; + + public: + void visit_derivative_block(ast::DerivativeBlock& node) override; + void visit_program(ast::Program& node) override; + void visit_derivative_original_function_block( + ast::DerivativeOriginalFunctionBlock& node) override; + void visit_diff_eq_expression(ast::DiffEqExpression& node) override; + void visit_binary_expression(ast::BinaryExpression& node) override; +}; + +/** \} */ // end of visitor_classes + +} // namespace visitor +} // namespace nmodl diff --git a/src/visitors/sympy_solver_visitor.cpp b/src/visitors/sympy_solver_visitor.cpp index f2d6260c21..e7b955a5c0 100644 --- a/src/visitors/sympy_solver_visitor.cpp +++ b/src/visitors/sympy_solver_visitor.cpp @@ -399,6 +399,10 @@ void SympySolverVisitor::visit_var_name(ast::VarName& node) { } } +// Skip visiting DERIVATIVE_ORIGINAL block +void SympySolverVisitor::visit_derivative_original_function_block( + ast::DerivativeOriginalFunctionBlock& node) {} + void SympySolverVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { const auto& lhs = node.get_expression()->get_lhs(); diff --git a/src/visitors/sympy_solver_visitor.hpp b/src/visitors/sympy_solver_visitor.hpp index ecb326ab63..627451d4b7 100644 --- a/src/visitors/sympy_solver_visitor.hpp +++ b/src/visitors/sympy_solver_visitor.hpp @@ -185,6 +185,8 @@ class SympySolverVisitor: public AstVisitor { void visit_expression_statement(ast::ExpressionStatement& node) override; void visit_statement_block(ast::StatementBlock& node) override; void visit_program(ast::Program& node) override; + void visit_derivative_original_function_block( + ast::DerivativeOriginalFunctionBlock& node) override; }; /** @} */ // end of visitor_classes From d33a594575291e4b7cc1fcb10f76501c17ec6143 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 24 Sep 2024 18:18:33 +0200 Subject: [PATCH 02/74] Remove unused functions --- src/visitors/derivative_original_visitor.cpp | 26 -------------------- 1 file changed, 26 deletions(-) diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp index bd6d135350..2e7b6942a2 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/derivative_original_visitor.cpp @@ -20,32 +20,6 @@ namespace pywrap = nmodl::pybind_wrappers; namespace nmodl { namespace visitor { -static int get_index(const ast::IndexedName& node) { - return std::stoi(to_nmodl(node.get_length())); -} - -static auto get_name_map(const ast::Expression& node, const std::string& name) { - std::unordered_map name_map; - // all of the "reserved" symbols - auto reserved_symbols = get_external_functions(); - // all indexed vars - auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME}); - for (const auto& var: indexed_vars) { - if (!name_map.count(var->get_node_name()) && var->get_node_name() != name && - std::none_of(reserved_symbols.begin(), reserved_symbols.end(), [&var](const auto item) { - return var->get_node_name() == item; - })) { - logger->debug( - "DerivativeOriginalVisitor :: adding INDEXED_VARIABLE {} to " - "node_map", - var->get_node_name()); - name_map[var->get_node_name()] = get_index( - *std::dynamic_pointer_cast(var)); - } - } - return name_map; -} - void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) { node.visit_children(*this); From b9f08d05225b37c9cd288c3471112cda51861bb4 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 09:27:13 +0200 Subject: [PATCH 03/74] Add test for DerivativeOriginalVisitor --- test/unit/CMakeLists.txt | 1 + test/unit/visitor/derivative_original.cpp | 55 +++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 test/unit/visitor/derivative_original.cpp diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 44d57fe91f..9ed95d8aff 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -45,6 +45,7 @@ add_executable( visitor/kinetic_block.cpp visitor/localize.cpp visitor/localrename.cpp + visitor/derivative_original.cpp visitor/local_to_assigned.cpp visitor/lookup.cpp visitor/loop_unroll.cpp diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/derivative_original.cpp new file mode 100644 index 0000000000..d2f5e17cf2 --- /dev/null +++ b/test/unit/visitor/derivative_original.cpp @@ -0,0 +1,55 @@ +#include + +#include "ast/program.hpp" +#include "parser/nmodl_driver.hpp" +#include "test/unit/utils/test_utils.hpp" +#include "visitors/checkparent_visitor.hpp" +#include "visitors/nmodl_visitor.hpp" +#include "visitors/symtab_visitor.hpp" +#include "visitors/derivative_original_visitor.hpp" +#include "visitors/visitor_utils.hpp" + +using namespace nmodl; +using namespace visitor; +using namespace test; +using namespace test_utils; + +using nmodl::parser::NmodlDriver; + + +auto run_derivative_original_visitor(const std::string& text) { + NmodlDriver driver; + const auto& ast = driver.parse_string(text); + SymtabVisitor().visit_program(*ast); + DerivativeOriginalVisitor().visit_program(*ast); + + return ast; +} + + +TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative_original]") { + GIVEN("DERIVATIVE block") { + std::string nmodl_text = R"( + NEURON { + SUFFIX example + } + + STATE {x z[2]} + + DERIVATIVE equation { + x' = -x + z'[0] = x + z'[1] = x + z[0] + } +)"; + auto ast = run_derivative_original_visitor(nmodl_text); + THEN("DERIVATIVE_ORIGINAL_FUNCTION block is added") { + auto block = collect_nodes(*ast, {ast::AstNodeType::DERIVATIVE_ORIGINAL_FUNCTION_BLOCK}); + REQUIRE(!block.empty()); + THEN("No primed variables exist in the DERIVATIVE_ORIGINAL_FUNCTION block") { + auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); + REQUIRE(primed_vars.empty()); + } + } + } +} From c5dc45e9a3a63e2f4e63996d4254426f6e2f6fa6 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 09:28:40 +0200 Subject: [PATCH 04/74] Fmt --- test/unit/visitor/derivative_original.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/derivative_original.cpp index d2f5e17cf2..58c0bd9c4f 100644 --- a/test/unit/visitor/derivative_original.cpp +++ b/test/unit/visitor/derivative_original.cpp @@ -4,9 +4,9 @@ #include "parser/nmodl_driver.hpp" #include "test/unit/utils/test_utils.hpp" #include "visitors/checkparent_visitor.hpp" +#include "visitors/derivative_original_visitor.hpp" #include "visitors/nmodl_visitor.hpp" #include "visitors/symtab_visitor.hpp" -#include "visitors/derivative_original_visitor.hpp" #include "visitors/visitor_utils.hpp" using namespace nmodl; @@ -28,8 +28,8 @@ auto run_derivative_original_visitor(const std::string& text) { TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative_original]") { - GIVEN("DERIVATIVE block") { - std::string nmodl_text = R"( + GIVEN("DERIVATIVE block") { + std::string nmodl_text = R"( NEURON { SUFFIX example } @@ -42,14 +42,15 @@ TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative z'[1] = x + z[0] } )"; - auto ast = run_derivative_original_visitor(nmodl_text); - THEN("DERIVATIVE_ORIGINAL_FUNCTION block is added") { - auto block = collect_nodes(*ast, {ast::AstNodeType::DERIVATIVE_ORIGINAL_FUNCTION_BLOCK}); - REQUIRE(!block.empty()); - THEN("No primed variables exist in the DERIVATIVE_ORIGINAL_FUNCTION block") { - auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); - REQUIRE(primed_vars.empty()); - } + auto ast = run_derivative_original_visitor(nmodl_text); + THEN("DERIVATIVE_ORIGINAL_FUNCTION block is added") { + auto block = collect_nodes(*ast, + {ast::AstNodeType::DERIVATIVE_ORIGINAL_FUNCTION_BLOCK}); + REQUIRE(!block.empty()); + THEN("No primed variables exist in the DERIVATIVE_ORIGINAL_FUNCTION block") { + auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); + REQUIRE(primed_vars.empty()); } } + } } From 1dadd7a21b43b97b9caec984e1fb48ea2cf001db Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 10:44:01 +0200 Subject: [PATCH 05/74] Fix leak --- src/visitors/derivative_original_visitor.cpp | 2 +- src/visitors/derivative_original_visitor.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp index 2e7b6942a2..9377641851 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/derivative_original_visitor.cpp @@ -23,7 +23,7 @@ namespace visitor { void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) { node.visit_children(*this); - der_block_function = node.clone(); + der_block_function = std::shared_ptr(node.clone()); } diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/derivative_original_visitor.hpp index d483ab845b..7178390ca8 100644 --- a/src/visitors/derivative_original_visitor.hpp +++ b/src/visitors/derivative_original_visitor.hpp @@ -36,7 +36,7 @@ namespace visitor { class DerivativeOriginalVisitor: public AstVisitor { private: /// The copy of the derivative block we are solving - ast::DerivativeBlock* der_block_function = nullptr; + std::shared_ptr der_block_function = nullptr; /// true while visiting differential equation bool differential_equation = false; From 1125fdf43c886b9250faabd7dbb55241f610ecac Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 11:32:07 +0200 Subject: [PATCH 06/74] Remove unused stuff `DERIVATIVE` blocks can't have array variables in NOCMODL by default, so let's go with that. --- src/visitors/derivative_original_visitor.cpp | 41 ++++++-------------- src/visitors/derivative_original_visitor.hpp | 2 - test/unit/visitor/derivative_original.cpp | 7 ++-- 3 files changed, 14 insertions(+), 36 deletions(-) diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp index 9377641851..43b3eb6df8 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/derivative_original_visitor.cpp @@ -30,9 +30,7 @@ void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& nod void DerivativeOriginalVisitor::visit_derivative_original_function_block( ast::DerivativeOriginalFunctionBlock& node) { derivative_block = true; - node_type = node.get_node_type(); node.visit_children(*this); - node_type = ast::AstNodeType::NODE; derivative_block = false; } @@ -53,34 +51,17 @@ void DerivativeOriginalVisitor::visit_binary_expression(ast::BinaryExpression& n auto name = std::dynamic_pointer_cast(lhs)->get_name(); - if (name->is_prime_name() || name->is_indexed_name()) { - std::string varname; - if (name->is_prime_name()) { - varname = "D" + name->get_node_name(); - logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", - name->get_node_name(), - varname, - to_nmodl(node)); - node.set_lhs(std::make_shared(new ast::String(varname))); - if (program_symtab->lookup(varname) == nullptr) { - auto symbol = std::make_shared(varname, ModToken()); - symbol->set_original_name(name->get_node_name()); - program_symtab->insert(symbol); - } - } else { - varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\''); - // we discard the RHS here so it can be anything (as long as NMODL considers it valid) - auto statement = fmt::format("{} = {}", varname, varname); - logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", - to_nmodl(node.get_lhs()), - varname, - to_nmodl(node)); - auto expr_statement = std::dynamic_pointer_cast( - create_statement(statement)); - const auto bin_expr = std::dynamic_pointer_cast( - expr_statement->get_expression()); - node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); - // TODO add symbol? + if (name->is_prime_name()) { + auto varname = "D" + name->get_node_name(); + logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", + name->get_node_name(), + varname, + to_nmodl(node)); + node.set_lhs(std::make_shared(new ast::String(varname))); + if (program_symtab->lookup(varname) == nullptr) { + auto symbol = std::make_shared(varname, ModToken()); + symbol->set_original_name(name->get_node_name()); + program_symtab->insert(symbol); } } } diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/derivative_original_visitor.hpp index 7178390ca8..2fb3b26297 100644 --- a/src/visitors/derivative_original_visitor.hpp +++ b/src/visitors/derivative_original_visitor.hpp @@ -47,8 +47,6 @@ class DerivativeOriginalVisitor: public AstVisitor { /// visiting derivative block bool derivative_block = false; - ast::AstNodeType node_type = ast::AstNodeType::NODE; - public: void visit_derivative_block(ast::DerivativeBlock& node) override; void visit_program(ast::Program& node) override; diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/derivative_original.cpp index 58c0bd9c4f..4533de36d5 100644 --- a/test/unit/visitor/derivative_original.cpp +++ b/test/unit/visitor/derivative_original.cpp @@ -34,12 +34,11 @@ TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative SUFFIX example } - STATE {x z[2]} + STATE {x z} DERIVATIVE equation { - x' = -x - z'[0] = x - z'[1] = x + z[0] + x' = -x + z * z + z' = z * x } )"; auto ast = run_derivative_original_visitor(nmodl_text); From 0267fbdf9c1f5a2fca08da574bd40a5d7949305c Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 11:45:09 +0200 Subject: [PATCH 07/74] Update block description --- src/language/codegen.yaml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index ac92afb517..e31cb4aca4 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -100,16 +100,15 @@ brief: "Block with statements vector" type: StatementBlock getter: {override: true} - brief: "Represents the original, unmodified `DERIVATIVE` block in the NMODL" + brief: "Represents a copy of the `DERIVATIVE` block in NMODL with prime vars replaced by D vars" description: | The original `DERIVATIVE` block in NMODL is replaced in-place if the system of ODEs is solvable analytically. Therefore, this - block's sole purpose is to keep the - original, unsolved block in the AST. This is - primarily useful when we need to solve the - ODE system using implicit methods, for - instance, CVode. + block's sole purpose is to keep the unsolved + block in the AST. This is primarily useful + when we need to solve the ODE system using + implicit methods, for instance, CVode. - WrappedExpression: brief: "Wrap any other expression type" From e58070f1b398139120336e4c8af74ba81b41d0e8 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 27 Sep 2024 15:29:28 +0200 Subject: [PATCH 08/74] Rename DERIVATIVE_ORIGINAL to CVODE --- src/language/code_generator.cmake | 2 +- src/language/codegen.yaml | 24 ++++++++------------ src/main.cpp | 6 ++--- src/visitors/derivative_original_visitor.cpp | 21 ++++++++--------- src/visitors/derivative_original_visitor.hpp | 7 +++--- src/visitors/sympy_solver_visitor.cpp | 5 ++-- src/visitors/sympy_solver_visitor.hpp | 3 +-- test/unit/visitor/derivative_original.cpp | 15 ++++++------ 8 files changed, 36 insertions(+), 47 deletions(-) diff --git a/src/language/code_generator.cmake b/src/language/code_generator.cmake index 992d5b0cb1..83ecff2eac 100644 --- a/src/language/code_generator.cmake +++ b/src/language/code_generator.cmake @@ -72,9 +72,9 @@ set(AST_GENERATED_SOURCES ${PROJECT_BINARY_DIR}/src/ast/constant_statement.hpp ${PROJECT_BINARY_DIR}/src/ast/constant_var.hpp ${PROJECT_BINARY_DIR}/src/ast/constructor_block.hpp + ${PROJECT_BINARY_DIR}/src/ast/cvode_block.hpp ${PROJECT_BINARY_DIR}/src/ast/define.hpp ${PROJECT_BINARY_DIR}/src/ast/derivative_block.hpp - ${PROJECT_BINARY_DIR}/src/ast/derivative_original_function_block.hpp ${PROJECT_BINARY_DIR}/src/ast/derivimplicit_callback.hpp ${PROJECT_BINARY_DIR}/src/ast/destructor_block.hpp ${PROJECT_BINARY_DIR}/src/ast/diff_eq_expression.hpp diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index e31cb4aca4..292cb567c8 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -88,27 +88,21 @@ - finalize_block: brief: "Statement block to be executed after calling linear solver" type: StatementBlock - - DerivativeOriginalFunctionBlock: - nmodl: "DERIVATIVE_ORIGINAL_FUNCTION " + - CvodeBlock: + nmodl: "CVODE_BLOCK " members: - name: - brief: "Name of the derivative block" + brief: "Name of the block" type: Name node_name: true suffix: {value: " "} - - statement_block: - brief: "Block with statements vector" + - function_block: + brief: "Block with statements of the form Dvar = f(var)" type: StatementBlock - getter: {override: true} - brief: "Represents a copy of the `DERIVATIVE` block in NMODL with prime vars replaced by D vars" - description: | - The original `DERIVATIVE` block in NMODL is - replaced in-place if the system of ODEs is - solvable analytically. Therefore, this - block's sole purpose is to keep the unsolved - block in the AST. This is primarily useful - when we need to solve the ODE system using - implicit methods, for instance, CVode. + - diagonal_jacobian_block: + brief: "Block with statements of the form Dvar = Dvar / (1 - dt * J(f))" + type: StatementBlock + brief: "Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks" - WrappedExpression: brief: "Wrap any other expression type" diff --git a/src/main.cpp b/src/main.cpp index f150753479..b3620dc46b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -499,10 +499,10 @@ int run_nmodl(int argc, const char* argv[]) { const bool sympy_sparse = solver_exists(*ast, "sparse"); if (neuron_code) { - logger->info("Running derivative visitor"); - DerivativeOriginalVisitor().visit_program(*ast); + logger->info("Running cvode visitor"); + CvodeVisitor().visit_program(*ast); SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("derivative_original")); + ast_to_nmodl(*ast, filepath("cvode")); } if (sympy_conductance || sympy_analytic || sympy_sparse || sympy_derivimplicit || diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp index 43b3eb6df8..97af9c32f7 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/derivative_original_visitor.cpp @@ -21,27 +21,26 @@ namespace nmodl { namespace visitor { -void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) { +void CvodeVisitor::visit_derivative_block(ast::DerivativeBlock& node) { node.visit_children(*this); - der_block_function = std::shared_ptr(node.clone()); + der_block = std::shared_ptr(node.clone()); } -void DerivativeOriginalVisitor::visit_derivative_original_function_block( - ast::DerivativeOriginalFunctionBlock& node) { +void CvodeVisitor::visit_cvode_block(ast::CvodeBlock& node) { derivative_block = true; node.visit_children(*this); derivative_block = false; } -void DerivativeOriginalVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { +void CvodeVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { differential_equation = true; node.visit_children(*this); differential_equation = false; } -void DerivativeOriginalVisitor::visit_binary_expression(ast::BinaryExpression& node) { +void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { const auto& lhs = node.get_lhs(); /// we have to only solve ODEs under original derivative block where lhs is variable @@ -66,13 +65,13 @@ void DerivativeOriginalVisitor::visit_binary_expression(ast::BinaryExpression& n } } -void DerivativeOriginalVisitor::visit_program(ast::Program& node) { +void CvodeVisitor::visit_program(ast::Program& node) { program_symtab = node.get_symbol_table(); node.visit_children(*this); - if (der_block_function) { - auto der_node = - new ast::DerivativeOriginalFunctionBlock(der_block_function->get_name(), - der_block_function->get_statement_block()); + if (der_block) { + auto der_node = new ast::CvodeBlock(der_block->get_name(), + der_block->get_statement_block(), + der_block->get_statement_block()); node.emplace_back_node(der_node); } diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/derivative_original_visitor.hpp index 2fb3b26297..9edb792186 100644 --- a/src/visitors/derivative_original_visitor.hpp +++ b/src/visitors/derivative_original_visitor.hpp @@ -33,10 +33,10 @@ namespace visitor { * solution. This block is inserted before that to prevent losing access to * information about the block. */ -class DerivativeOriginalVisitor: public AstVisitor { +class CvodeVisitor: public AstVisitor { private: /// The copy of the derivative block we are solving - std::shared_ptr der_block_function = nullptr; + std::shared_ptr der_block = nullptr; /// true while visiting differential equation bool differential_equation = false; @@ -50,8 +50,7 @@ class DerivativeOriginalVisitor: public AstVisitor { public: void visit_derivative_block(ast::DerivativeBlock& node) override; void visit_program(ast::Program& node) override; - void visit_derivative_original_function_block( - ast::DerivativeOriginalFunctionBlock& node) override; + void visit_cvode_block(ast::CvodeBlock& node) override; void visit_diff_eq_expression(ast::DiffEqExpression& node) override; void visit_binary_expression(ast::BinaryExpression& node) override; }; diff --git a/src/visitors/sympy_solver_visitor.cpp b/src/visitors/sympy_solver_visitor.cpp index e7b955a5c0..42936ae5e6 100644 --- a/src/visitors/sympy_solver_visitor.cpp +++ b/src/visitors/sympy_solver_visitor.cpp @@ -399,9 +399,8 @@ void SympySolverVisitor::visit_var_name(ast::VarName& node) { } } -// Skip visiting DERIVATIVE_ORIGINAL block -void SympySolverVisitor::visit_derivative_original_function_block( - ast::DerivativeOriginalFunctionBlock& node) {} +// Skip visiting CVODE block +void SympySolverVisitor::visit_cvode_block(ast::CvodeBlock& node) {} void SympySolverVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { const auto& lhs = node.get_expression()->get_lhs(); diff --git a/src/visitors/sympy_solver_visitor.hpp b/src/visitors/sympy_solver_visitor.hpp index 627451d4b7..7642b79411 100644 --- a/src/visitors/sympy_solver_visitor.hpp +++ b/src/visitors/sympy_solver_visitor.hpp @@ -185,8 +185,7 @@ class SympySolverVisitor: public AstVisitor { void visit_expression_statement(ast::ExpressionStatement& node) override; void visit_statement_block(ast::StatementBlock& node) override; void visit_program(ast::Program& node) override; - void visit_derivative_original_function_block( - ast::DerivativeOriginalFunctionBlock& node) override; + void visit_cvode_block(ast::CvodeBlock& node) override; }; /** @} */ // end of visitor_classes diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/derivative_original.cpp index 4533de36d5..5fbc60fcaa 100644 --- a/test/unit/visitor/derivative_original.cpp +++ b/test/unit/visitor/derivative_original.cpp @@ -17,17 +17,17 @@ using namespace test_utils; using nmodl::parser::NmodlDriver; -auto run_derivative_original_visitor(const std::string& text) { +auto run_cvode_visitor(const std::string& text) { NmodlDriver driver; const auto& ast = driver.parse_string(text); SymtabVisitor().visit_program(*ast); - DerivativeOriginalVisitor().visit_program(*ast); + CvodeVisitor().visit_program(*ast); return ast; } -TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative_original]") { +TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][cvode]") { GIVEN("DERIVATIVE block") { std::string nmodl_text = R"( NEURON { @@ -41,12 +41,11 @@ TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative z' = z * x } )"; - auto ast = run_derivative_original_visitor(nmodl_text); - THEN("DERIVATIVE_ORIGINAL_FUNCTION block is added") { - auto block = collect_nodes(*ast, - {ast::AstNodeType::DERIVATIVE_ORIGINAL_FUNCTION_BLOCK}); + auto ast = run_cvode_visitor(nmodl_text); + THEN("CVODE block is added") { + auto block = collect_nodes(*ast, {ast::AstNodeType::CVODE_BLOCK}); REQUIRE(!block.empty()); - THEN("No primed variables exist in the DERIVATIVE_ORIGINAL_FUNCTION block") { + THEN("No primed variables exist in the CVODE block") { auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); REQUIRE(primed_vars.empty()); } From 044dfd93250209ae794fd37fa8ff8702574e5c48 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 30 Sep 2024 09:14:34 +0200 Subject: [PATCH 09/74] Finish renaming --- src/main.cpp | 2 +- src/visitors/CMakeLists.txt | 2 +- ...ative_original_visitor.cpp => cvode_visitor.cpp} | 6 +++--- ...ative_original_visitor.hpp => cvode_visitor.hpp} | 13 ++++--------- test/unit/CMakeLists.txt | 2 +- .../visitor/{derivative_original.cpp => cvode.cpp} | 2 +- 6 files changed, 11 insertions(+), 16 deletions(-) rename src/visitors/{derivative_original_visitor.cpp => cvode_visitor.cpp} (91%) rename src/visitors/{derivative_original_visitor.hpp => cvode_visitor.hpp} (70%) rename test/unit/visitor/{derivative_original.cpp => cvode.cpp} (96%) diff --git a/src/main.cpp b/src/main.cpp index b3620dc46b..1ac71b752c 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -25,7 +25,7 @@ #include "visitors/after_cvode_to_cnexp_visitor.hpp" #include "visitors/ast_visitor.hpp" #include "visitors/constant_folder_visitor.hpp" -#include "visitors/derivative_original_visitor.hpp" +#include "visitors/cvode_visitor.hpp" #include "visitors/function_callpath_visitor.hpp" #include "visitors/global_var_visitor.hpp" #include "visitors/implicit_argument_visitor.hpp" diff --git a/src/visitors/CMakeLists.txt b/src/visitors/CMakeLists.txt index ede77671eb..f51a65b732 100644 --- a/src/visitors/CMakeLists.txt +++ b/src/visitors/CMakeLists.txt @@ -11,7 +11,7 @@ add_library( visitor STATIC after_cvode_to_cnexp_visitor.cpp constant_folder_visitor.cpp - derivative_original_visitor.cpp + cvode_visitor.cpp defuse_analyze_visitor.cpp function_callpath_visitor.cpp global_var_visitor.cpp diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/cvode_visitor.cpp similarity index 91% rename from src/visitors/derivative_original_visitor.cpp rename to src/visitors/cvode_visitor.cpp index 97af9c32f7..ee60c9451b 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "visitors/derivative_original_visitor.hpp" +#include "visitors/cvode_visitor.hpp" #include "ast/all.hpp" #include "lexer/token_mapping.hpp" @@ -52,7 +52,7 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { if (name->is_prime_name()) { auto varname = "D" + name->get_node_name(); - logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", + logger->debug("CvodeVisitor :: replacing {} with {} on LHS of {}", name->get_node_name(), varname, to_nmodl(node)); @@ -75,7 +75,7 @@ void CvodeVisitor::visit_program(ast::Program& node) { node.emplace_back_node(der_node); } - // re-visit the AST since we now inserted the DERIVATIVE_ORIGINAL block + // re-visit the AST since we now inserted the CVODE block node.visit_children(*this); } diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/cvode_visitor.hpp similarity index 70% rename from src/visitors/derivative_original_visitor.hpp rename to src/visitors/cvode_visitor.hpp index 9edb792186..baeed0f84f 100644 --- a/src/visitors/derivative_original_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -9,7 +9,7 @@ /** * \file - * \brief \copybrief nmodl::visitor::DerivativeOriginalVisitor + * \brief \copybrief nmodl::visitor::CvodeVisitor */ #include "symtab/decl.hpp" @@ -25,17 +25,12 @@ namespace visitor { */ /** - * \class DerivativeOriginalVisitor - * \brief Make a copy of the `DERIVATIVE` block (if it exists), and insert back as - * `DERIVATIVE_ORIGINAL_FUNCTION` block. - * - * If \ref SympySolverVisitor runs successfully, it replaces the original - * solution. This block is inserted before that to prevent losing access to - * information about the block. + * \class CvodeVisitor + * \brief Visitor used for generating the necessary AST nodes for CVODE */ class CvodeVisitor: public AstVisitor { private: - /// The copy of the derivative block we are solving + /// The copy of the derivative block of a given mod file std::shared_ptr der_block = nullptr; /// true while visiting differential equation diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 9ed95d8aff..f12d5167bb 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -45,7 +45,7 @@ add_executable( visitor/kinetic_block.cpp visitor/localize.cpp visitor/localrename.cpp - visitor/derivative_original.cpp + visitor/cvode.cpp visitor/local_to_assigned.cpp visitor/lookup.cpp visitor/loop_unroll.cpp diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/cvode.cpp similarity index 96% rename from test/unit/visitor/derivative_original.cpp rename to test/unit/visitor/cvode.cpp index 5fbc60fcaa..bdc4777665 100644 --- a/test/unit/visitor/derivative_original.cpp +++ b/test/unit/visitor/cvode.cpp @@ -4,7 +4,7 @@ #include "parser/nmodl_driver.hpp" #include "test/unit/utils/test_utils.hpp" #include "visitors/checkparent_visitor.hpp" -#include "visitors/derivative_original_visitor.hpp" +#include "visitors/cvode_visitor.hpp" #include "visitors/nmodl_visitor.hpp" #include "visitors/symtab_visitor.hpp" #include "visitors/visitor_utils.hpp" From 50f38cec0843fa5db2db4768783c198e21295222 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 30 Sep 2024 12:06:12 +0200 Subject: [PATCH 10/74] Add item with Jacobian --- python/nmodl/ode.py | 25 +++++++++-- src/main.cpp | 17 ++++---- src/pybind/wrapper.cpp | 40 +++++++++++++++++- src/pybind/wrapper.hpp | 7 ++++ src/visitors/cvode_visitor.cpp | 77 +++++++++++++++++++++++++++++----- src/visitors/cvode_visitor.hpp | 12 ++++-- test/unit/visitor/cvode.cpp | 2 +- 7 files changed, 151 insertions(+), 29 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index 3fe769e596..2eab38e873 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -608,7 +608,12 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): vars = set(vars) vars.discard(dependent_var) # declare all other supplied variables - sympy_vars = {var: sp.symbols(var, real=True) for var in vars} + sympy_vars = { + var if isinstance(var, str) else str(var): ( + sp.symbols(var, real=True) if isinstance(var, str) else var + ) + for var in vars + } sympy_vars[dependent_var] = x # parse string into SymPy equation @@ -643,15 +648,27 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): # differentiate w.r.t. x diff = expr.diff(x).simplify() + # could be something generic like f'(x), in which case we use finite differences + if needs_finite_differences(diff): + diff = ( + transform_expression(diff, discretize_derivative) + .subs({finite_difference_step_variable(x): 1e-3}) + .evalf() + ) + + # the codegen method does not like undefined function calls, so we extract + # them here + custom_fcts = {str(f.func): str(f.func) for f in diff.atoms(sp.Function)} + # try to simplify expression in terms of existing variables # ignore any exceptions here, since we already have a valid solution # so if this further simplification step fails the error is not fatal try: # if expression is equal to one of the supplied vars, replace with this var # can do a simple string comparison here since a var cannot be further simplified - diff_as_string = sp.ccode(diff) + diff_as_string = sp.ccode(diff, user_functions=custom_fcts) for v in sympy_vars: - if diff_as_string == sp.ccode(sympy_vars[v]): + if diff_as_string == sp.ccode(sympy_vars[v], user_functions=custom_fcts): diff = sympy_vars[v] # or if equal to rhs of one of the supplied equations, replace with lhs @@ -672,4 +689,4 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): pass # return result as C code in NEURON format - return sp.ccode(diff.evalf()) + return sp.ccode(diff.evalf(), user_functions=custom_fcts) diff --git a/src/main.cpp b/src/main.cpp index 1ac71b752c..a4ea9e266b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -498,18 +498,19 @@ int run_nmodl(int argc, const char* argv[]) { const bool sympy_linear = node_exists(*ast, ast::AstNodeType::LINEAR_BLOCK); const bool sympy_sparse = solver_exists(*ast, "sparse"); - if (neuron_code) { - logger->info("Running cvode visitor"); - CvodeVisitor().visit_program(*ast); - SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("cvode")); - } - if (sympy_conductance || sympy_analytic || sympy_sparse || sympy_derivimplicit || - sympy_linear) { + sympy_linear || neuron_code) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() .api() .initialize_interpreter(); + + if (neuron_code) { + logger->info("Running CVODE visitor"); + CvodeVisitor().visit_program(*ast); + SymtabVisitor(update_symtab).visit_program(*ast); + ast_to_nmodl(*ast, filepath("cvode")); + } + if (sympy_conductance) { logger->info("Running sympy conductance visitor"); SympyConductanceVisitor().visit_program(*ast); diff --git a/src/pybind/wrapper.cpp b/src/pybind/wrapper.cpp index 32c390736c..ae9d414976 100644 --- a/src/pybind/wrapper.cpp +++ b/src/pybind/wrapper.cpp @@ -9,7 +9,7 @@ #include "codegen/codegen_naming.hpp" #include "pybind/pyembed.hpp" - +#include #include #include @@ -186,6 +186,41 @@ except Exception as e: return {std::move(solution), std::move(exception_message)}; } +/// \brief A blunt instrument that differentiates expression w.r.t. variable +/// \return The tuple (solution, exception) +std::tuple call_diff2c( + const std::string& expression, + const std::string& variable, + const std::unordered_map& indexed_vars) { + std::string statements; + // only indexed variables require special treatment + for (const auto& [var, prop]: indexed_vars) { + statements += fmt::format("_allvars.append(sp.IndexedBase('{}', shape=[1]))\n", var); + } + auto locals = py::dict("expression"_a = expression, "variable"_a = variable); + std::string script = fmt::format(R"( +_allvars = [] +{} +exception_message = "" +try: + solution = differentiate2c(expression, + variable, + _allvars, + ) +except Exception as e: + # if we fail, fail silently and return empty string + solution = "" + exception_message = str(e) +)", + statements); + + py::exec(nmodl::pybind_wrappers::ode_py + script, locals); + + auto solution = locals["solution"].cast(); + auto exception_message = locals["exception_message"].cast(); + + return {std::move(solution), std::move(exception_message)}; +} void initialize_interpreter_func() { pybind11::initialize_interpreter(true); @@ -203,7 +238,8 @@ NMODL_EXPORT pybind_wrap_api nmodl_init_pybind_wrapper_api() noexcept { &call_solve_nonlinear_system, &call_solve_linear_system, &call_diffeq_solver, - &call_analytic_diff}; + &call_analytic_diff, + &call_diff2c}; } } diff --git a/src/pybind/wrapper.hpp b/src/pybind/wrapper.hpp index 725f9f8113..b4ec0a2dff 100644 --- a/src/pybind/wrapper.hpp +++ b/src/pybind/wrapper.hpp @@ -9,6 +9,7 @@ #include #include +#include #include namespace nmodl { @@ -44,6 +45,11 @@ std::tuple call_analytic_diff( const std::vector& expressions, const std::set& used_names_in_block); +std::tuple call_diff2c( + const std::string& expression, + const std::string& variable, + const std::unordered_map& indexed_vars = {}); + struct pybind_wrap_api { decltype(&initialize_interpreter_func) initialize_interpreter; decltype(&finalize_interpreter_func) finalize_interpreter; @@ -51,6 +57,7 @@ struct pybind_wrap_api { decltype(&call_solve_linear_system) solve_linear_system; decltype(&call_diffeq_solver) diffeq_solver; decltype(&call_analytic_diff) analytic_diff; + decltype(&call_diff2c) diff2c; }; #ifdef _WIN32 diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index ee60c9451b..c326fd36ef 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -20,23 +20,56 @@ namespace pywrap = nmodl::pybind_wrappers; namespace nmodl { namespace visitor { +static int get_index(const ast::IndexedName& node) { + return std::stoi(to_nmodl(node.get_length())); +} + +static auto get_name_map(const ast::Expression& node, const std::string& name) { + std::unordered_map name_map; + // all of the "reserved" symbols + auto reserved_symbols = get_external_functions(); + // all indexed vars + auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME}); + for (const auto& var: indexed_vars) { + if (!name_map.count(var->get_node_name()) && var->get_node_name() != name && + std::none_of(reserved_symbols.begin(), reserved_symbols.end(), [&var](const auto item) { + return var->get_node_name() == item; + })) { + logger->debug( + "DerivativeOriginalVisitor :: adding INDEXED_VARIABLE {} to " + "node_map", + var->get_node_name()); + name_map[var->get_node_name()] = get_index( + *std::dynamic_pointer_cast(var)); + } + } + return name_map; +} void CvodeVisitor::visit_derivative_block(ast::DerivativeBlock& node) { node.visit_children(*this); - der_block = std::shared_ptr(node.clone()); + derivative_block = std::shared_ptr(node.clone()); } void CvodeVisitor::visit_cvode_block(ast::CvodeBlock& node) { - derivative_block = true; + in_cvode_block = true; node.visit_children(*this); - derivative_block = false; + in_cvode_block = false; } void CvodeVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { - differential_equation = true; + in_differential_equation = true; node.visit_children(*this); - differential_equation = false; + in_differential_equation = false; +} + + +void CvodeVisitor::visit_statement_block(ast::StatementBlock& node) { + node.visit_children(*this); + if (in_cvode_block) { + ++block_index; + } } @@ -44,7 +77,7 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { const auto& lhs = node.get_lhs(); /// we have to only solve ODEs under original derivative block where lhs is variable - if (!derivative_block || !differential_equation || !lhs->is_var_name()) { + if (!in_cvode_block || !in_differential_equation || !lhs->is_var_name()) { return; } @@ -62,16 +95,40 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { symbol->set_original_name(name->get_node_name()); program_symtab->insert(symbol); } + if (block_index == 1) { + auto rhs = node.get_rhs(); + // map of all indexed symbols (need special treatment in SymPy) + auto name_map = get_name_map(*rhs, name->get_node_name()); + auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; + auto [jacobian, + exception_message] = diff2c(to_nmodl(*rhs), name->get_node_name(), name_map); + if (!exception_message.empty()) { + logger->warn("DerivativeOriginalVisitor :: python exception: {}", + exception_message); + } + // NOTE: LHS can be anything here, the equality is to keep `create_statement` from + // complaining, we discard the LHS later + auto statement = fmt::format("{} = {} / (1 - dt * ({}))", varname, varname, jacobian); + logger->debug("DerivativeOriginalVisitor :: replacing statement {} with {}", + to_nmodl(node), + statement); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_rhs(std::shared_ptr(bin_expr->get_rhs()->clone())); + } } } void CvodeVisitor::visit_program(ast::Program& node) { program_symtab = node.get_symbol_table(); node.visit_children(*this); - if (der_block) { - auto der_node = new ast::CvodeBlock(der_block->get_name(), - der_block->get_statement_block(), - der_block->get_statement_block()); + if (derivative_block) { + auto der_node = new ast::CvodeBlock(derivative_block->get_name(), + derivative_block->get_statement_block(), + std::shared_ptr( + derivative_block->get_statement_block()->clone())); node.emplace_back_node(der_node); } diff --git a/src/visitors/cvode_visitor.hpp b/src/visitors/cvode_visitor.hpp index baeed0f84f..0bf336cfd4 100644 --- a/src/visitors/cvode_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -31,16 +31,19 @@ namespace visitor { class CvodeVisitor: public AstVisitor { private: /// The copy of the derivative block of a given mod file - std::shared_ptr der_block = nullptr; + std::shared_ptr derivative_block = nullptr; /// true while visiting differential equation - bool differential_equation = false; + bool in_differential_equation = false; /// global symbol table symtab::SymbolTable* program_symtab = nullptr; - /// visiting derivative block - bool derivative_block = false; + /// true while we are visiting a CVODE block + bool in_cvode_block = false; + + /// index of the block to modify (0 = function block, 1 = Jacobian block) + int block_index = 0; public: void visit_derivative_block(ast::DerivativeBlock& node) override; @@ -48,6 +51,7 @@ class CvodeVisitor: public AstVisitor { void visit_cvode_block(ast::CvodeBlock& node) override; void visit_diff_eq_expression(ast::DiffEqExpression& node) override; void visit_binary_expression(ast::BinaryExpression& node) override; + void visit_statement_block(ast::StatementBlock& node) override; }; /** \} */ // end of visitor_classes diff --git a/test/unit/visitor/cvode.cpp b/test/unit/visitor/cvode.cpp index bdc4777665..3a57d242d4 100644 --- a/test/unit/visitor/cvode.cpp +++ b/test/unit/visitor/cvode.cpp @@ -27,7 +27,7 @@ auto run_cvode_visitor(const std::string& text) { } -TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][cvode]") { +TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { GIVEN("DERIVATIVE block") { std::string nmodl_text = R"( NEURON { From f82fe1f531b6c9b500040a959412eddf292892fe Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 30 Sep 2024 12:39:30 +0200 Subject: [PATCH 11/74] Do not use an int but an enum-wrapped int --- src/visitors/cvode_visitor.cpp | 2 +- src/visitors/cvode_visitor.hpp | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index c326fd36ef..0ed1a97573 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -95,7 +95,7 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { symbol->set_original_name(name->get_node_name()); program_symtab->insert(symbol); } - if (block_index == 1) { + if (block_index == BlockIndex::JACOBIAN) { auto rhs = node.get_rhs(); // map of all indexed symbols (need special treatment in SymPy) auto name_map = get_name_map(*rhs, name->get_node_name()); diff --git a/src/visitors/cvode_visitor.hpp b/src/visitors/cvode_visitor.hpp index 0bf336cfd4..7dcef42839 100644 --- a/src/visitors/cvode_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -19,6 +19,16 @@ namespace nmodl { namespace visitor { +enum class BlockIndex { FUNCTION = 0, JACOBIAN = 1 }; + +inline BlockIndex& operator++(BlockIndex& index) { + if (index == BlockIndex::FUNCTION) { + index = BlockIndex::JACOBIAN; + } else { + index = BlockIndex::FUNCTION; + } + return index; +} /** * \addtogroup visitor_classes * \{ @@ -42,8 +52,8 @@ class CvodeVisitor: public AstVisitor { /// true while we are visiting a CVODE block bool in_cvode_block = false; - /// index of the block to modify (0 = function block, 1 = Jacobian block) - int block_index = 0; + /// index of the block to modify + BlockIndex block_index = BlockIndex::FUNCTION; public: void visit_derivative_block(ast::DerivativeBlock& node) override; From bd2fd36a5c2f961ff4e9bae724daa161647dbb93 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 30 Sep 2024 14:01:31 +0200 Subject: [PATCH 12/74] Add support for diffing expressions with indexed vars --- python/nmodl/ode.py | 7 ++++++- test/unit/ode/test_ode.py | 17 ++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index 3fe769e596..2e110b3842 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -608,7 +608,12 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): vars = set(vars) vars.discard(dependent_var) # declare all other supplied variables - sympy_vars = {var: sp.symbols(var, real=True) for var in vars} + sympy_vars = { + var if isinstance(var, str) else str(var): ( + sp.symbols(var, real=True) if isinstance(var, str) else var + ) + for var in vars + } sympy_vars[dependent_var] = x # parse string into SymPy equation diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 387cfb801f..0c195bc02b 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -28,7 +28,12 @@ def _equivalent( """ lhs = lhs.replace("pow(", "Pow(") rhs = rhs.replace("pow(", "Pow(") - sympy_vars = {var: sp.symbols(var, real=True) for var in vars} + sympy_vars = { + var if isinstance(var, str) else str(var): ( + sp.symbols(var, real=True) if isinstance(var, str) else var + ) + for var in vars + } for l, r in zip(lhs.split("=", 1), rhs.split("=", 1)): eq_l = sp.sympify(l, locals=sympy_vars) eq_r = sp.sympify(r, locals=sympy_vars) @@ -100,6 +105,16 @@ def test_differentiate2c(): "g", ) + assert _equivalent( + differentiate2c( + "(s[0] + s[1])*(z[0]*z[1]*z[2])*x", + "x", + {sp.IndexedBase("s", shape=[1]), sp.IndexedBase("z", shape=[1])}, + ), + "(s[0] + s[1])*(z[0]*z[1]*z[2])", + {sp.IndexedBase("s", shape=[1]), sp.IndexedBase("z", shape=[1])}, + ) + def test_integrate2c(): From b082f0d29b3a1f1d8d30e9c1d43375ac2f15219b Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 30 Sep 2024 14:29:28 +0200 Subject: [PATCH 13/74] Allow diffing implicit functions in `differentiate2c` Uses finite differences --- python/nmodl/ode.py | 18 +++++++++++++++--- test/unit/ode/test_ode.py | 10 ++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index 3fe769e596..66b3a752e2 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -643,15 +643,27 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): # differentiate w.r.t. x diff = expr.diff(x).simplify() + # could be something generic like f'(x), in which case we use finite differences + if needs_finite_differences(diff): + diff = ( + transform_expression(diff, discretize_derivative) + .subs({finite_difference_step_variable(x): 1e-3}) + .evalf() + ) + + # the codegen method does not like undefined function calls, so we extract + # them here + custom_fcts = {str(f.func): str(f.func) for f in diff.atoms(sp.Function)} + # try to simplify expression in terms of existing variables # ignore any exceptions here, since we already have a valid solution # so if this further simplification step fails the error is not fatal try: # if expression is equal to one of the supplied vars, replace with this var # can do a simple string comparison here since a var cannot be further simplified - diff_as_string = sp.ccode(diff) + diff_as_string = sp.ccode(diff, user_functions=custom_fcts) for v in sympy_vars: - if diff_as_string == sp.ccode(sympy_vars[v]): + if diff_as_string == sp.ccode(sympy_vars[v], user_functions=custom_fcts): diff = sympy_vars[v] # or if equal to rhs of one of the supplied equations, replace with lhs @@ -672,4 +684,4 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): pass # return result as C code in NEURON format - return sp.ccode(diff.evalf()) + return sp.ccode(diff.evalf(), user_functions=custom_fcts) diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 387cfb801f..0d5e7f628a 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -100,6 +100,16 @@ def test_differentiate2c(): "g", ) + assert _equivalent( + differentiate2c( + "-f(x)", + "x", + {}, + ), + "1000.0*f(x - 0.00050000000000000001) - 1000.0*f(x + 0.00050000000000000001)", + {"x"}, + ) + def test_integrate2c(): From edf33a70000bff2908267a5a45251a6be6381675 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 30 Sep 2024 16:34:44 +0200 Subject: [PATCH 14/74] Simplify condition --- python/nmodl/ode.py | 4 +--- test/unit/ode/test_ode.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index 2e110b3842..4e5b9be253 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -609,9 +609,7 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): vars.discard(dependent_var) # declare all other supplied variables sympy_vars = { - var if isinstance(var, str) else str(var): ( - sp.symbols(var, real=True) if isinstance(var, str) else var - ) + str(var): (sp.symbols(var, real=True) if isinstance(var, str) else var) for var in vars } sympy_vars[dependent_var] = x diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 0c195bc02b..33810c16da 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -29,9 +29,7 @@ def _equivalent( lhs = lhs.replace("pow(", "Pow(") rhs = rhs.replace("pow(", "Pow(") sympy_vars = { - var if isinstance(var, str) else str(var): ( - sp.symbols(var, real=True) if isinstance(var, str) else var - ) + str(var): (sp.symbols(var, real=True) if isinstance(var, str) else var) for var in vars } for l, r in zip(lhs.split("=", 1), rhs.split("=", 1)): From 565fa03c1d9647badc146e82b24117ab7f8b3272 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 1 Oct 2024 11:02:57 +0200 Subject: [PATCH 15/74] Better testing --- test/unit/ode/test_ode.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 0d5e7f628a..21ee8a88f9 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 from nmodl.ode import differentiate2c, integrate2c +import numpy as np import sympy as sp @@ -100,15 +101,29 @@ def test_differentiate2c(): "g", ) - assert _equivalent( - differentiate2c( - "-f(x)", - "x", - {}, - ), - "1000.0*f(x - 0.00050000000000000001) - 1000.0*f(x + 0.00050000000000000001)", - {"x"}, + result = differentiate2c( + "-f(x)", + "x", + {}, ) + # instead of comparing the expression as a string, we convert the string + # back to an expression and insert various functions + for function in [sp.sin, sp.exp, sp.tanh]: + for value in np.linspace(-5, 5, 100): + np.testing.assert_allclose( + float( + sp.sympify(result) + .subs(sp.Function("f"), function) + .subs({"x": value}) + .evalf() + ), + float( + -sp.Derivative(function("x")) + .as_finite_difference(1e-3) + .subs({"x": value}) + .evalf() + ), + ) def test_integrate2c(): From 6bd6aed914cfc87bdf1a2b54ec39c6bc8fc3c654 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 1 Oct 2024 15:45:22 +0200 Subject: [PATCH 16/74] Add suggestions from code review --- test/unit/ode/test_ode.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 21ee8a88f9..e3a25b06e5 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -107,23 +107,22 @@ def test_differentiate2c(): {}, ) # instead of comparing the expression as a string, we convert the string - # back to an expression and insert various functions - for function in [sp.sin, sp.exp, sp.tanh]: - for value in np.linspace(-5, 5, 100): - np.testing.assert_allclose( - float( - sp.sympify(result) - .subs(sp.Function("f"), function) - .subs({"x": value}) - .evalf() - ), - float( - -sp.Derivative(function("x")) - .as_finite_difference(1e-3) - .subs({"x": value}) - .evalf() - ), - ) + # back to an expression and compare with an explicit function + for value in np.linspace(-5, 5, 100): + np.testing.assert_allclose( + float( + sp.sympify(result) + .subs(sp.Function("f"), sp.sin) + .subs({"x": value}) + .evalf() + ), + float( + -sp.Derivative(sp.sin("x")) + .as_finite_difference(1e-3) + .subs({"x": value}) + .evalf() + ), + ) def test_integrate2c(): From 0eba407672a868fdcc24a5c66e84fa7d175ca3ba Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 2 Oct 2024 11:33:14 +0200 Subject: [PATCH 17/74] Add `stepsize` param to `differentiate2c` --- python/nmodl/ode.py | 14 ++++++++++++-- test/unit/ode/test_ode.py | 8 ++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index 66b3a752e2..e40cb47c62 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -568,7 +568,13 @@ def forwards_euler2c(diff_string, dt_var, vars, function_calls): return f"{sp.ccode(x)} = {sp.ccode(solution, user_functions=custom_fcts)}" -def differentiate2c(expression, dependent_var, vars, prev_expressions=None): +def differentiate2c( + expression, + dependent_var, + vars, + prev_expressions=None, + stepsize=1e-3, +): """Analytically differentiate supplied expression, return solution as C code. Expression should be of the form "f(x)", where "x" is @@ -595,11 +601,15 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): vars: set of all other variables used in expression, e.g. {"a", "b", "c"} prev_expressions: time-ordered list of preceeding expressions to evaluate & substitute, e.g. ["b = x + c", "a = 12*b"] + stepsize: in case an analytic expression is not possible, finite differences are used; + this argument sets the step size Returns: string containing analytic derivative of expression (including any substitutions of variables from supplied prev_expressions) w.r.t. dependent_var as C code. """ + if stepsize <= 0: + raise ValueError("arg `stepsize` must be > 0") prev_expressions = prev_expressions or [] # every symbol (a.k.a variable) that SymPy # is going to manipulate needs to be declared @@ -647,7 +657,7 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): if needs_finite_differences(diff): diff = ( transform_expression(diff, discretize_derivative) - .subs({finite_difference_step_variable(x): 1e-3}) + .subs({finite_difference_step_variable(x): stepsize}) .evalf() ) diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index e3a25b06e5..390c938f9a 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -5,6 +5,7 @@ from nmodl.ode import differentiate2c, integrate2c import numpy as np +import pytest import sympy as sp @@ -123,6 +124,13 @@ def test_differentiate2c(): .evalf() ), ) + with pytest.raises(ValueError): + differentiate2c( + "-f(x)", + "x", + {}, + stepsize=-1, + ) def test_integrate2c(): From c1e7fd3bb9657d07f5d567e435d1893e4a8d06c2 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 2 Oct 2024 12:45:44 +0200 Subject: [PATCH 18/74] Try Python 3.9 maybe? --- .github/workflows/nmodl-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/nmodl-ci.yml b/.github/workflows/nmodl-ci.yml index 09a2a4d20d..3d981241cc 100644 --- a/.github/workflows/nmodl-ci.yml +++ b/.github/workflows/nmodl-ci.yml @@ -16,7 +16,7 @@ on: env: CTEST_PARALLEL_LEVEL: 1 - PYTHON_VERSION: 3.8 + PYTHON_VERSION: 3.9 DESIRED_CMAKE_VERSION: 3.15.0 jobs: From 45629455713cb281911dd2d81bfbfd3e2686200e Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 2 Oct 2024 14:33:13 +0200 Subject: [PATCH 19/74] Add codegen --- src/codegen/codegen_helper_visitor.cpp | 5 + src/codegen/codegen_helper_visitor.hpp | 1 + src/codegen/codegen_info.hpp | 3 + src/codegen/codegen_neuron_cpp_visitor.cpp | 170 ++++++++++++++++++++- src/codegen/codegen_neuron_cpp_visitor.hpp | 3 + src/visitors/cvode_visitor.cpp | 7 +- test/usecases/CMakeLists.txt | 3 +- test/usecases/cvode/derivative.mod | 46 ++++++ test/usecases/cvode/test_cvode.py | 57 +++++++ 9 files changed, 286 insertions(+), 9 deletions(-) create mode 100644 test/usecases/cvode/derivative.mod create mode 100644 test/usecases/cvode/test_cvode.py diff --git a/src/codegen/codegen_helper_visitor.cpp b/src/codegen/codegen_helper_visitor.cpp index ad9dd393de..f098c0fb0b 100644 --- a/src/codegen/codegen_helper_visitor.cpp +++ b/src/codegen/codegen_helper_visitor.cpp @@ -636,6 +636,11 @@ void CodegenHelperVisitor::visit_nrn_state_block(const ast::NrnStateBlock& node) node.visit_children(*this); } +void CodegenHelperVisitor::visit_cvode_block(const ast::CvodeBlock& node) { + info.cvode_block = &node; + node.visit_children(*this); +} + void CodegenHelperVisitor::visit_procedure_block(const ast::ProcedureBlock& node) { info.procedures.push_back(&node); diff --git a/src/codegen/codegen_helper_visitor.hpp b/src/codegen/codegen_helper_visitor.hpp index 9258f1cf7f..52fe71030d 100644 --- a/src/codegen/codegen_helper_visitor.hpp +++ b/src/codegen/codegen_helper_visitor.hpp @@ -108,6 +108,7 @@ class CodegenHelperVisitor: public visitor::ConstAstVisitor { void visit_program(const ast::Program& node) override; void visit_factor_def(const ast::FactorDef& node) override; void visit_nrn_state_block(const ast::NrnStateBlock& node) override; + void visit_cvode_block(const ast::CvodeBlock& node) override; void visit_linear_block(const ast::LinearBlock& node) override; void visit_non_linear_block(const ast::NonLinearBlock& node) override; void visit_discrete_block(const ast::DiscreteBlock& node) override; diff --git a/src/codegen/codegen_info.hpp b/src/codegen/codegen_info.hpp index 6d39479e1a..d760f68740 100644 --- a/src/codegen/codegen_info.hpp +++ b/src/codegen/codegen_info.hpp @@ -414,6 +414,9 @@ struct CodegenInfo { /// nrn_state block const ast::NrnStateBlock* nrn_state_block = nullptr; + /// the CVODE block + const ast::CvodeBlock* cvode_block = nullptr; + /// net receive block for point process const ast::NetReceiveBlock* net_receive_node = nullptr; diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index 7e012f6b08..bcacc1471b 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -856,6 +856,14 @@ void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_in printer->fmt_line("static Symbol* _{}_sym;", ion.name); } + if (info.emit_cvode) { + printer->add_line("static Symbol** _atollist;"); + printer->push_block("static HocStateTolerance _hoc_state_tol[] ="); + // TODO: add stuff that iterates over `rangestate` in NOCMODL + printer->add_line("{0, 0}"); + printer->pop_block(";"); + } + printer->add_line("static int mech_type;"); if (info.point_process) { @@ -1280,16 +1288,22 @@ void CodegenNeuronCppVisitor::print_mechanism_register() { i)); } + if (info.emit_cvode) { + mech_register_args.push_back( + // TODO: figure out why the first parameter should be called "_cvode_ieq" + fmt::format("_nrn_mechanism_field{{\"_cvode_ieq\", \"cvodeieq\"}} /* {} */", + codegen_int_variables_size)); + } + printer->add_multi_line(fmt::format("{}", fmt::join(mech_register_args, ",\n"))); printer->decrease_indent(); printer->add_line(");"); printer->add_newline(); - printer->fmt_line("hoc_register_prop_size(mech_type, {}, {});", float_variables_size(), - int_variables_size()); + int_variables_size() + static_cast(info.emit_cvode)); for (int i = 0; i < codegen_int_variables_size; ++i) { if (i != info.semantics[i].index) { @@ -1349,6 +1363,18 @@ void CodegenNeuronCppVisitor::print_mechanism_register() { global_struct_instance()); } + if (info.emit_cvode) { + printer->fmt_line("hoc_register_dparam_semantics(mech_type, {}, \"cvodeieq\");", + codegen_int_variables_size); + printer->fmt_line( + "hoc_register_cvode(mech_type, ode_count_{}, ode_map_{}, ode_spec_{}, ode_matsol_{});", + info.mod_suffix, + info.mod_suffix, + info.mod_suffix, + info.mod_suffix); + printer->fmt_line("hoc_register_tolerance(mech_type, _hoc_state_tol, &_atollist);"); + } + printer->pop_block(); } @@ -1757,9 +1783,9 @@ void CodegenNeuronCppVisitor::print_nrn_alloc() { )CODE"); printer->chain_block("else"); } - if (info.semantic_variable_count) { + if (info.semantic_variable_count || info.emit_cvode) { printer->fmt_line("_ppvar = nrn_prop_datum_alloc(mech_type, {}, _prop);", - info.semantic_variable_count); + info.semantic_variable_count + static_cast(info.emit_cvode)); printer->add_line("_nrn_mechanism_access_dparam(_prop) = _ppvar;"); } printer->add_multi_line(R"CODE( @@ -2201,6 +2227,8 @@ void CodegenNeuronCppVisitor::print_mechanism_variables_macros() { if (info.table_count > 0) { printer->add_line("void _nrn_thread_table_reg(int, nrn_thread_table_check_t);"); } + // for CVODE + printer->add_line("extern void _cvode_abstol(Symbol**, double*, int);"); if (info.for_netcon_used) { printer->add_line("int _nrn_netcon_args(void*, double***);"); } @@ -2273,6 +2301,7 @@ void CodegenNeuronCppVisitor::print_codegen_routines() { print_nrn_destructor_declaration(); print_nrn_alloc(); print_function_prototypes(); + print_cvode_definitions(); print_point_process_function_definitions(); print_setdata_functions(); print_check_table_entrypoint(); @@ -2401,6 +2430,139 @@ void CodegenNeuronCppVisitor::print_net_receive_common_code() { printer->add_line("double t = nt->_t;"); } +void CodegenNeuronCppVisitor::print_cvode_definitions() { + if (!info.emit_cvode) { + return; + } + printer->add_newline(2); + printer->add_line("/* Functions related to CVODE codegen */"); + + /* return # of ODEs to solve */ + printer->push_block( + fmt::format("static constexpr int ode_count_{}(int _type)", info.mod_suffix)); + printer->fmt_line("return {};", info.num_equations); + printer->pop_block(); + + printer->add_newline(2); + + const ParamVector args_setup = {{"", "const _nrn_model_sorted_token&", "", "_sorted_token"}, + {"", "NrnThread*", "", "nt"}, + {"", "Memb_list*", "", "_ml_arg"}, + {"", "int", "", "_type"}}; + + ParamVector args_cvode = {{"", "_nrn_mechanism_cache_range&", "", "_lmc"}, + {"", fmt::format("{}_Instance&", info.mod_suffix), "", "inst"}, + {"", fmt::format("{}_NodeData&", info.mod_suffix), "", "node_data"}, + {"", "size_t", "", "id"}, + {"", "Datum*", "", "_ppvar"}, + {"", "Datum*", "", "_thread"}, + {"", "NrnThread*", "", "nt"}}; + + if (info.thread_callback_register) { + auto type_name = fmt::format("{}&", thread_variables_struct()); + args_cvode.emplace_back("", type_name, "", "_thread_vars"); + } + + /* The internal spec function */ + printer->fmt_push_block("static int ode_spec1_{}({})", + info.mod_suffix, + get_parameter_str(args_cvode)); // begin function definition + printer->add_line("int node_id = node_data.nodeindices[id];"); + printer->add_line("auto v = node_data.node_voltages[node_id];"); + if (info.cvode_block) { + auto block = info.cvode_block->get_function_block(); + print_statement_block(*block, false, false); + } + + printer->add_line("return 0;"); + printer->pop_block(); // end function definition + + printer->add_newline(2); + + /* Main spec function */ + printer->push_block(fmt::format("static void ode_spec_{}({})", + info.mod_suffix, + get_parameter_str(args_setup))); // begin function definition + printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); + printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); + printer->add_line("auto nodecount = _ml_arg->nodecount;"); + printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml_arg);", info.mod_suffix); + printer->add_line("auto* _thread = _ml_arg->_thread;"); + if (!codegen_thread_variables.empty()) { + printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());", + thread_variables_struct(), + info.thread_var_thread_id); + } + printer->push_block("for (int id = 0; id < nodecount; id++)"); // begin for loop + printer->add_line("int node_id = node_data.nodeindices[id];"); + printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); + printer->add_line("auto v = node_data.node_voltages[node_id];"); + printer->fmt_line("ode_spec1_{}({});", info.mod_suffix, get_arg_str(args_cvode)); + + printer->pop_block(); // end for loop + printer->pop_block(); // end function definition + + printer->add_newline(2); + + /* map */ + printer->push_block( + fmt::format("static void ode_map_{}(Prop* _prop, int equation_index, " + "neuron::container::data_handle* _pv, " + "neuron::container::data_handle* _pvdot, double* _atol, int _type)", + info.mod_suffix)); // begin function definition + printer->add_line("auto* _ppvar = _nrn_mechanism_access_dparam(_prop);"); + printer->fmt_line("_ppvar[{}].literal_value() = equation_index;", int_variables_size()); + printer->push_block(fmt::format("for (int i = 0; i < ode_count_{}(0); i++)", + info.mod_suffix)); // begin for loop + printer->add_line("_pv[i] = _nrn_mechanism_get_param_handle(_prop, _slist1[i]);"); + printer->add_line("_pvdot[i] = _nrn_mechanism_get_param_handle(_prop, _dlist1[i]);"); + printer->add_line("_cvode_abstol(_atollist, _atol, i);"); + printer->pop_block(); // end for loop + printer->pop_block(); // end function definition + + printer->add_newline(2); + + /* matsol instance (?) */ + printer->push_block(fmt::format("static void ode_matsol_instance1_{}({})", + info.mod_suffix, + get_parameter_str(args_cvode))); // begin function definition + + if (info.cvode_block) { + // for mathematical details, see eq. (4.8) in: + // https://sundials.readthedocs.io/en/latest/cvodes/Mathematics_link.html + auto block = info.cvode_block->get_diagonal_jacobian_block(); + print_statement_block(*block, false, false); + } + + printer->pop_block(); // end function definition + + printer->add_newline(2); + + /* matsol */ + printer->push_block(fmt::format("static void ode_matsol_{}({})", + info.mod_suffix, + get_parameter_str(args_setup))); // begin function definition + printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); + printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); + printer->add_line("auto nodecount = _ml_arg->nodecount;"); + printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml_arg);", info.mod_suffix); + printer->add_line("auto* _thread = _ml_arg->_thread;"); + if (!codegen_thread_variables.empty()) { + printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());", + thread_variables_struct(), + info.thread_var_thread_id); + } + printer->push_block("for (int id = 0; id < nodecount; id++)"); // begin for loop + printer->add_line("int node_id = node_data.nodeindices[id];"); + printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); + printer->add_line("auto v = node_data.node_voltages[node_id];"); + // TODO check if this can be replaced by ode_matsol1 + printer->fmt_line("ode_matsol_instance1_{}({});", info.mod_suffix, get_arg_str(args_cvode)); + + printer->pop_block(); // end for loop + printer->pop_block(); // end function definition +} + void CodegenNeuronCppVisitor::print_net_receive() { printing_net_receive = true; auto node = info.net_receive_node; diff --git a/src/codegen/codegen_neuron_cpp_visitor.hpp b/src/codegen/codegen_neuron_cpp_visitor.hpp index b1a02fc7d3..51136f1b29 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.hpp +++ b/src/codegen/codegen_neuron_cpp_visitor.hpp @@ -686,6 +686,9 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { void print_ion_variable() override; + void print_cvode_definitions(); + + /****************************************************************************************/ /* Overloaded visitor routines */ /****************************************************************************************/ diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 0ed1a97573..da053061d2 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -36,7 +36,7 @@ static auto get_name_map(const ast::Expression& node, const std::string& name) { return var->get_node_name() == item; })) { logger->debug( - "DerivativeOriginalVisitor :: adding INDEXED_VARIABLE {} to " + "CvodeVisitor :: adding INDEXED_VARIABLE {} to " "node_map", var->get_node_name()); name_map[var->get_node_name()] = get_index( @@ -103,13 +103,12 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { auto [jacobian, exception_message] = diff2c(to_nmodl(*rhs), name->get_node_name(), name_map); if (!exception_message.empty()) { - logger->warn("DerivativeOriginalVisitor :: python exception: {}", - exception_message); + logger->warn("CvodeVisitor :: python exception: {}", exception_message); } // NOTE: LHS can be anything here, the equality is to keep `create_statement` from // complaining, we discard the LHS later auto statement = fmt::format("{} = {} / (1 - dt * ({}))", varname, varname, jacobian); - logger->debug("DerivativeOriginalVisitor :: replacing statement {} with {}", + logger->debug("CvodeVisitor :: replacing statement {} with {}", to_nmodl(node), statement); auto expr_statement = std::dynamic_pointer_cast( diff --git a/test/usecases/CMakeLists.txt b/test/usecases/CMakeLists.txt index feaa81ef7a..9aff9e19a1 100644 --- a/test/usecases/CMakeLists.txt +++ b/test/usecases/CMakeLists.txt @@ -28,7 +28,8 @@ set(NMODL_USECASE_DIRS steady_state suffix table - useion) + useion + cvode) foreach(usecase ${NMODL_USECASE_DIRS}) add_test(NAME usecase_${usecase} diff --git a/test/usecases/cvode/derivative.mod b/test/usecases/cvode/derivative.mod new file mode 100644 index 0000000000..70aa20352c --- /dev/null +++ b/test/usecases/cvode/derivative.mod @@ -0,0 +1,46 @@ +NEURON { + SUFFIX scalar +} + +PARAMETER { + freq = 10 + a = 5 + v1 = -1 + v2 = 5 + v3 = 15 + v4 = 0.8 + v5 = 0.3 + r = 3 + k = 0.2 + nmodl_alpha = 1.2 + nmodl_beta = 4.5 + nmodl_gamma = 2.4 + nmodl_delta = 7.5 +} + +STATE {var1 var2 var3 var4 var5} + +INITIAL { + var1 = v1 + var2 = v2 + var3 = v3 + var4 = v4 + var5 = v5 +} + +BREAKPOINT { + SOLVE equation METHOD derivimplicit +} + + +DERIVATIVE equation { + : eq with a function on RHS + var1' = -sin(freq * t) + : simple ODE (nonzero Jacobian) + var2' = -var2 * a + : logistic ODE + var3' = r * var3 * (1 - var3 / k) + : system of 2 ODEs (predator-prey model) + var4' = nmodl_alpha * var4 - nmodl_beta * var4 * var5 + var5' = nmodl_delta * var4 * var5 - nmodl_gamma * var5 +} diff --git a/test/usecases/cvode/test_cvode.py b/test/usecases/cvode/test_cvode.py new file mode 100644 index 0000000000..521b93f673 --- /dev/null +++ b/test/usecases/cvode/test_cvode.py @@ -0,0 +1,57 @@ +import numpy as np +from neuron import gui +from neuron import h +from neuron.units import ms + + +def simulate(rtol): + nseg = 1 + mech = "scalar" + + s = h.Section() + cvode = h.CVode() + cvode.active(True) + cvode.atol(1e-10) + s.insert(mech) + s.nseg = nseg + + t_hoc = h.Vector().record(h._ref_t) + var1_hoc = h.Vector().record(getattr(s(0.5), f"_ref_var1_{mech}")) + var2_hoc = h.Vector().record(getattr(s(0.5), f"_ref_var2_{mech}")) + var3_hoc = h.Vector().record(getattr(s(0.5), f"_ref_var3_{mech}")) + + h.stdinit() + h.tstop = 2.0 * ms + h.run() + + freq = getattr(h, f"freq_{mech}") + a = getattr(h, f"a_{mech}") + v1 = getattr(h, f"v1_{mech}") + v2 = getattr(h, f"v2_{mech}") + v3 = getattr(h, f"v3_{mech}") + r = getattr(h, f"r_{mech}") + k = getattr(h, f"k_{mech}") + + t = np.array(t_hoc.as_numpy()) + var1 = np.array(var1_hoc.as_numpy()) + var2 = np.array(var2_hoc.as_numpy()) + var3 = np.array(var3_hoc.as_numpy()) + + var1_exact = (np.cos(t * freq) + v1 * freq - 1) / freq + var2_exact = v2 * np.exp(-t * a) + var3_exact = k * v3 / (v3 + (k - v3) * np.exp(-r * t)) + + np.testing.assert_allclose(var1, var1_exact, rtol=rtol) + np.testing.assert_allclose(var2, var2_exact, rtol=rtol) + np.testing.assert_allclose(var3, var3_exact, rtol=rtol) + + return t, var1, var2, var3 + + +if __name__ == "__main__": + t, *x = simulate(rtol=1e-5) + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots(nrows=len(x)) + # for a, val in zip(ax, x): + # a.plot(t, val, ls="", marker="x", markersize=0.1) + # plt.show() From dcf47abb4fdc436fef8d3ba67880f02c1c9c9ee7 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 2 Oct 2024 14:35:43 +0200 Subject: [PATCH 20/74] Spurious change --- .github/workflows/nmodl-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/nmodl-ci.yml b/.github/workflows/nmodl-ci.yml index 3d981241cc..09a2a4d20d 100644 --- a/.github/workflows/nmodl-ci.yml +++ b/.github/workflows/nmodl-ci.yml @@ -16,7 +16,7 @@ on: env: CTEST_PARALLEL_LEVEL: 1 - PYTHON_VERSION: 3.9 + PYTHON_VERSION: 3.8 DESIRED_CMAKE_VERSION: 3.15.0 jobs: From 34371ccfb22159b697714355c26dd7248b7a5c5d Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 2 Oct 2024 14:50:58 +0200 Subject: [PATCH 21/74] Stop using numpy to please MacOS dual-arch build --- test/unit/ode/test_ode.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 390c938f9a..df5f6c4f0a 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: Apache-2.0 from nmodl.ode import differentiate2c, integrate2c -import numpy as np import pytest import sympy as sp @@ -109,20 +108,22 @@ def test_differentiate2c(): ) # instead of comparing the expression as a string, we convert the string # back to an expression and compare with an explicit function - for value in np.linspace(-5, 5, 100): - np.testing.assert_allclose( + size = 100 + for index in range(size): + a, b = -5, 5 + value = (b - a) * index / size + a + pytest.approx( float( sp.sympify(result) .subs(sp.Function("f"), sp.sin) .subs({"x": value}) .evalf() - ), - float( - -sp.Derivative(sp.sin("x")) - .as_finite_difference(1e-3) - .subs({"x": value}) - .evalf() - ), + ) + ) == float( + -sp.Derivative(sp.sin("x")) + .as_finite_difference(1e-3) + .subs({"x": value}) + .evalf() ) with pytest.raises(ValueError): differentiate2c( From f725f05e023bdf9e2e52ed351868a65ea1c3e9c3 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 2 Oct 2024 15:05:10 +0200 Subject: [PATCH 22/74] Isolate symbol-making --- python/nmodl/ode.py | 10 ++++++---- test/unit/ode/test_ode.py | 12 ++++++++++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index 4e5b9be253..deedc8da63 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -247,6 +247,11 @@ def _interweave_eqs(F, J): return code +def make_symbol(var, /): + """Create SymPy symbol from a variable.""" + return sp.Symbol(var, real=True) if isinstance(var, str) else var + + def solve_lin_system( eq_strings, vars, @@ -608,10 +613,7 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): vars = set(vars) vars.discard(dependent_var) # declare all other supplied variables - sympy_vars = { - str(var): (sp.symbols(var, real=True) if isinstance(var, str) else var) - for var in vars - } + sympy_vars = {str(var): make_symbol(var) for var in vars} sympy_vars[dependent_var] = x # parse string into SymPy equation diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 33810c16da..991de4c8fb 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -3,13 +3,21 @@ # # SPDX-License-Identifier: Apache-2.0 -from nmodl.ode import differentiate2c, integrate2c +from nmodl.ode import differentiate2c, integrate2c, make_symbol import sympy as sp +def make_symbols(iterable): + return [make_symbol(arg) for arg in iterable] + + def _equivalent( - lhs, rhs, vars=["a", "b", "c", "d", "e", "f", "v", "w", "x", "y", "z", "t", "dt"] + lhs, + rhs, + vars=make_symbols( + ["a", "b", "c", "d", "e", "f", "v", "w", "x", "y", "z", "t", "dt"] + ), ): """Helper function to test equivalence of analytic expressions Analytic expressions can often be written in many different, From 6658bf4364607794bda59610832da78a6e1bf48a Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 2 Oct 2024 15:06:29 +0200 Subject: [PATCH 23/74] Forgot one --- test/unit/ode/test_ode.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 991de4c8fb..f855c7d0ee 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -36,10 +36,7 @@ def _equivalent( """ lhs = lhs.replace("pow(", "Pow(") rhs = rhs.replace("pow(", "Pow(") - sympy_vars = { - str(var): (sp.symbols(var, real=True) if isinstance(var, str) else var) - for var in vars - } + sympy_vars = {str(var): make_symbol(var) for var in vars} for l, r in zip(lhs.split("=", 1), rhs.split("=", 1)): eq_l = sp.sympify(l, locals=sympy_vars) eq_r = sp.sympify(r, locals=sympy_vars) From e0c087c7170d7aa96f6f89f398a09bb6a88ca79a Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 2 Oct 2024 16:21:13 +0200 Subject: [PATCH 24/74] Leave out KINETIC block for now Due to the CONSERVE statement (among others), the KINETIC block needs special handling when generating code for CVODE. --- src/main.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/main.cpp b/src/main.cpp index a4ea9e266b..1b059f3f1a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -434,6 +434,19 @@ int run_nmodl(int argc, const char* argv[]) { ast_to_nmodl(*ast, filepath("unroll")); SymtabVisitor(update_symtab).visit_program(*ast); } + if (neuron_code) { + nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() + .api() + .initialize_interpreter(); + logger->info("Running CVODE visitor"); + CvodeVisitor().visit_program(*ast); + SymtabVisitor(update_symtab).visit_program(*ast); + ast_to_nmodl(*ast, filepath("cvode")); + nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() + .api() + .finalize_interpreter(); + } + /// note that we can not symtab visitor in update mode as we /// replace kinetic block with derivative block of same name @@ -499,18 +512,11 @@ int run_nmodl(int argc, const char* argv[]) { const bool sympy_sparse = solver_exists(*ast, "sparse"); if (sympy_conductance || sympy_analytic || sympy_sparse || sympy_derivimplicit || - sympy_linear || neuron_code) { + sympy_linear) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() .api() .initialize_interpreter(); - if (neuron_code) { - logger->info("Running CVODE visitor"); - CvodeVisitor().visit_program(*ast); - SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("cvode")); - } - if (sympy_conductance) { logger->info("Running sympy conductance visitor"); SympyConductanceVisitor().visit_program(*ast); From 2470dff63cf804f471829233af8d8cfaa0dd821e Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 2 Oct 2024 16:29:05 +0200 Subject: [PATCH 25/74] Add missing import --- test/unit/ode/test_ode.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index d0a3a0e46c..69e345e171 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 from nmodl.ode import differentiate2c, integrate2c, make_symbol +import pytest import sympy as sp From 4fde9295ef4c00e79322d3fe14c379715efc59d4 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 7 Oct 2024 09:40:27 +0200 Subject: [PATCH 26/74] Put back Python 3.8 for now --- .github/workflows/nmodl-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/nmodl-ci.yml b/.github/workflows/nmodl-ci.yml index 3d981241cc..09a2a4d20d 100644 --- a/.github/workflows/nmodl-ci.yml +++ b/.github/workflows/nmodl-ci.yml @@ -16,7 +16,7 @@ on: env: CTEST_PARALLEL_LEVEL: 1 - PYTHON_VERSION: 3.9 + PYTHON_VERSION: 3.8 DESIRED_CMAKE_VERSION: 3.15.0 jobs: From a08df258e71a8105f26f294bc0b23978c70bd8bd Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 7 Oct 2024 13:37:27 +0200 Subject: [PATCH 27/74] Remove remaining occurrences of `DerivativeOriginalVisitor` --- src/visitors/cvode_visitor.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 0ed1a97573..da053061d2 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -36,7 +36,7 @@ static auto get_name_map(const ast::Expression& node, const std::string& name) { return var->get_node_name() == item; })) { logger->debug( - "DerivativeOriginalVisitor :: adding INDEXED_VARIABLE {} to " + "CvodeVisitor :: adding INDEXED_VARIABLE {} to " "node_map", var->get_node_name()); name_map[var->get_node_name()] = get_index( @@ -103,13 +103,12 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { auto [jacobian, exception_message] = diff2c(to_nmodl(*rhs), name->get_node_name(), name_map); if (!exception_message.empty()) { - logger->warn("DerivativeOriginalVisitor :: python exception: {}", - exception_message); + logger->warn("CvodeVisitor :: python exception: {}", exception_message); } // NOTE: LHS can be anything here, the equality is to keep `create_statement` from // complaining, we discard the LHS later auto statement = fmt::format("{} = {} / (1 - dt * ({}))", varname, varname, jacobian); - logger->debug("DerivativeOriginalVisitor :: replacing statement {} with {}", + logger->debug("CvodeVisitor :: replacing statement {} with {}", to_nmodl(node), statement); auto expr_statement = std::dynamic_pointer_cast( From 9fee9a8d42092724bf69c125fd85688a75e88686 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 8 Oct 2024 11:50:04 +0200 Subject: [PATCH 28/74] WIP on CONSERVE --- src/visitors/cvode_visitor.cpp | 40 ++++++++++++++++++++++++++++++++++ src/visitors/cvode_visitor.hpp | 4 ++++ 2 files changed, 44 insertions(+) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index da053061d2..7cb7439e6a 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -24,6 +24,20 @@ static int get_index(const ast::IndexedName& node) { return std::stoi(to_nmodl(node.get_length())); } +void CvodeVisitor::visit_conserve(ast::Conserve& node) { + logger->debug("CvodeVisitor :: CONSERVE statement: {}", to_nmodl(node)); + std::string conserve_equation_statevar; + if (node.get_react()->is_react_var_name()) { + conserve_equation_statevar = node.get_react()->get_node_name(); + } + auto conserve_equation_str = to_nmodl(*node.get_expr()); + logger->debug("CvodeVisitor :: --> replace ODE for state var {} with equation {}", + conserve_equation_statevar, + conserve_equation_str); + conserve_equations[conserve_equation_statevar] = conserve_equation_str; +} + + static auto get_name_map(const ast::Expression& node, const std::string& name) { std::unordered_map name_map; // all of the "reserved" symbols @@ -95,6 +109,28 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { symbol->set_original_name(name->get_node_name()); program_symtab->insert(symbol); } + // case: there is a variable being CONSERVEd, but it's not the current one + if (!conserve_equations.empty() && !conserve_equations.count(name->get_node_name())) { + auto rhs = node.get_rhs(); + auto nodes = collect_nodes(*node.get_rhs(), {ast::AstNodeType::VAR_NAME}); + for (auto& n: nodes) { + if (conserve_equations.count(n->get_node_name())) { + auto statement = fmt::format("{} = {}", n->get_node_name(), conserve_equations[n->get_node_name()]); + logger->debug("CvodeVisitor :: replacing CONSERVEd variable {} with {} in {}", + n->get_node_name(), + conserve_equations[n->get_node_name()], + to_nmodl(*node.get_rhs())); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + auto thing = std::shared_ptr(bin_expr->get_rhs()->clone()); + n = std::move(std::dynamic_pointer_cast(thing)); + std::cout << to_nmodl(*n) << std::endl; + } + } + } + std::cout << to_nmodl(node) << std::endl; if (block_index == BlockIndex::JACOBIAN) { auto rhs = node.get_rhs(); // map of all indexed symbols (need special treatment in SymPy) @@ -131,6 +167,10 @@ void CvodeVisitor::visit_program(ast::Program& node) { node.emplace_back_node(der_node); } + for (const auto& [key, value]: conserve_equations) { + std::cout << key << ", " << value << std::endl; + } + // re-visit the AST since we now inserted the CVODE block node.visit_children(*this); } diff --git a/src/visitors/cvode_visitor.hpp b/src/visitors/cvode_visitor.hpp index 7dcef42839..bcd561e973 100644 --- a/src/visitors/cvode_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -55,6 +55,9 @@ class CvodeVisitor: public AstVisitor { /// index of the block to modify BlockIndex block_index = BlockIndex::FUNCTION; + /// map of state vars to conserve equations + std::unordered_map conserve_equations; + public: void visit_derivative_block(ast::DerivativeBlock& node) override; void visit_program(ast::Program& node) override; @@ -62,6 +65,7 @@ class CvodeVisitor: public AstVisitor { void visit_diff_eq_expression(ast::DiffEqExpression& node) override; void visit_binary_expression(ast::BinaryExpression& node) override; void visit_statement_block(ast::StatementBlock& node) override; + void visit_conserve(ast::Conserve& node) override; }; /** \} */ // end of visitor_classes From d98fcc00d5e21947ebd1aa98a68afc9a57a7d8a5 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 9 Oct 2024 16:54:24 +0200 Subject: [PATCH 29/74] Ignore CONSERVE equations They are just hints to the NMODL compiler, but they are not at all necessary to use when solving the ODEs. --- src/visitors/cvode_visitor.cpp | 42 ++++++---------------------------- src/visitors/cvode_visitor.hpp | 5 ++-- 2 files changed, 10 insertions(+), 37 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 7cb7439e6a..0f87c2242a 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -25,16 +25,11 @@ static int get_index(const ast::IndexedName& node) { } void CvodeVisitor::visit_conserve(ast::Conserve& node) { - logger->debug("CvodeVisitor :: CONSERVE statement: {}", to_nmodl(node)); - std::string conserve_equation_statevar; - if (node.get_react()->is_react_var_name()) { - conserve_equation_statevar = node.get_react()->get_node_name(); + if (in_cvode_block) { + logger->warn("CvodeVisitor :: CONSERVE statement {} will be ignored in CVODE codegen", + to_nmodl(node)); + conserve_equations.emplace(&node); } - auto conserve_equation_str = to_nmodl(*node.get_expr()); - logger->debug("CvodeVisitor :: --> replace ODE for state var {} with equation {}", - conserve_equation_statevar, - conserve_equation_str); - conserve_equations[conserve_equation_statevar] = conserve_equation_str; } @@ -109,28 +104,6 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { symbol->set_original_name(name->get_node_name()); program_symtab->insert(symbol); } - // case: there is a variable being CONSERVEd, but it's not the current one - if (!conserve_equations.empty() && !conserve_equations.count(name->get_node_name())) { - auto rhs = node.get_rhs(); - auto nodes = collect_nodes(*node.get_rhs(), {ast::AstNodeType::VAR_NAME}); - for (auto& n: nodes) { - if (conserve_equations.count(n->get_node_name())) { - auto statement = fmt::format("{} = {}", n->get_node_name(), conserve_equations[n->get_node_name()]); - logger->debug("CvodeVisitor :: replacing CONSERVEd variable {} with {} in {}", - n->get_node_name(), - conserve_equations[n->get_node_name()], - to_nmodl(*node.get_rhs())); - auto expr_statement = std::dynamic_pointer_cast( - create_statement(statement)); - const auto bin_expr = std::dynamic_pointer_cast( - expr_statement->get_expression()); - auto thing = std::shared_ptr(bin_expr->get_rhs()->clone()); - n = std::move(std::dynamic_pointer_cast(thing)); - std::cout << to_nmodl(*n) << std::endl; - } - } - } - std::cout << to_nmodl(node) << std::endl; if (block_index == BlockIndex::JACOBIAN) { auto rhs = node.get_rhs(); // map of all indexed symbols (need special treatment in SymPy) @@ -167,12 +140,11 @@ void CvodeVisitor::visit_program(ast::Program& node) { node.emplace_back_node(der_node); } - for (const auto& [key, value]: conserve_equations) { - std::cout << key << ", " << value << std::endl; - } - // re-visit the AST since we now inserted the CVODE block node.visit_children(*this); + if (!conserve_equations.empty()) { + node.erase_node(conserve_equations); + } } } // namespace visitor diff --git a/src/visitors/cvode_visitor.hpp b/src/visitors/cvode_visitor.hpp index bcd561e973..f324b2b075 100644 --- a/src/visitors/cvode_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -15,6 +15,7 @@ #include "symtab/decl.hpp" #include "visitors/ast_visitor.hpp" #include +#include namespace nmodl { namespace visitor { @@ -55,8 +56,8 @@ class CvodeVisitor: public AstVisitor { /// index of the block to modify BlockIndex block_index = BlockIndex::FUNCTION; - /// map of state vars to conserve equations - std::unordered_map conserve_equations; + /// list of conserve equations encountered + std::unordered_set conserve_equations; public: void visit_derivative_block(ast::DerivativeBlock& node) override; From 321cdb395e874bca49df4475b5a7dd4610cbb847 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 10:35:22 +0200 Subject: [PATCH 30/74] Add documentation --- docs/contents/cvode.rst | 74 +++++++++++++++++++++++++++++++++++++++++ docs/index.rst | 1 + 2 files changed, 75 insertions(+) create mode 100644 docs/contents/cvode.rst diff --git a/docs/contents/cvode.rst b/docs/contents/cvode.rst new file mode 100644 index 0000000000..87ecdb309c --- /dev/null +++ b/docs/contents/cvode.rst @@ -0,0 +1,74 @@ +Variable timestep integration (CVODE) +===================================== + +As opposed to fixed timestep integration, variable timestep integration (CVODE +in NEURON parlance) uses the SUNDIALS package to solve a ``DERIVATIVE`` or +``KINETIC`` block using a variable timestep. This allows for faster computation +times if the function in question does not vary too wildly. + +Implementation in NMODL +----------------------- + +The code generation for CVODE is activated only if exactly one of the following +is satisfied: + +1. there is one ``KINETIC`` block in the mod file +2. there is one ``DERIVATIVE`` block in the mod file +3. a ``PROCEDURE`` block is solved with the ``after_cvode``, ``cvode_t``, or + ``cvode_t_v`` methods + +In NMODL, all ``KINETIC`` blocks are internally first converted to +``DERIVATIVE`` blocks. The ``DERIVATIVE`` block is then converted to a +``CVODE`` block, which contains two parts; the first part contains the update +step for linear systems, while the second part contains the update step for +non-linear systems (see `CVODES documentation`_, eqs. (4.8) and (4.9)). Given +a ``DERIVATIVE`` block of the form: + +.. _CVODES documentation: https://sundials.readthedocs.io/en/latest/cvodes/Mathematics_link.html + +.. code-block:: + + DERIVATIVE state { + x_i' = f(x_1, ..., x_n) + } + +the structure of the ``CVODE`` block is then roughly: + +.. code-block:: + + CVODE state { + Dx_i = f_i(x_1, ..., x_n) + }{ + Dx_i = Dx_i / (1 - dt * J_ii(f)) + } + +where ``J_ii(f)`` is the diagonal part of the Jacobian, i.e. + +.. math:: + + J_{ii}(f) = \frac{ \partial f_i(x_1, \ldots, x_n) }{\partial x_i} + +As an example, consider the following ``DERIVATIVE`` +block: + +.. code-block:: + + DERIVATIVE state { + X' = - X + } + +Where ``X`` is a ``STATE`` variable with some initial value, specified in the +``INITIAL`` block. The corresponding ``CVODE`` block is then: + +.. code-block:: + + CVODE state { + DX = - X + }{ + DX = DX / (1 - dt * (-1)) + } + + +**NOTE**: in case there are ``CONSERVE`` statements in ``KINETIC`` blocks, as +they are merely hints to NMODL, and have no impact on the results, they are +removed from ``CVODE`` blocks before the codegen stage. diff --git a/docs/index.rst b/docs/index.rst index 9c4b0105ee..15125ef4a6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -24,6 +24,7 @@ About NMODL contents/pointers contents/cable_equations contents/globals + contents/cvode .. toctree:: :maxdepth: 3 From 2984e46f25ff44a5af09c5f77262982cbb85803f Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 10:47:13 +0200 Subject: [PATCH 31/74] Really delete CONSERVE statements this time --- src/visitors/cvode_visitor.cpp | 7 +++++-- src/visitors/cvode_visitor.hpp | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 0f87c2242a..302b541d5d 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -26,7 +26,7 @@ static int get_index(const ast::IndexedName& node) { void CvodeVisitor::visit_conserve(ast::Conserve& node) { if (in_cvode_block) { - logger->warn("CvodeVisitor :: CONSERVE statement {} will be ignored in CVODE codegen", + logger->warn("CvodeVisitor :: statement {} will be ignored in CVODE codegen", to_nmodl(node)); conserve_equations.emplace(&node); } @@ -143,7 +143,10 @@ void CvodeVisitor::visit_program(ast::Program& node) { // re-visit the AST since we now inserted the CVODE block node.visit_children(*this); if (!conserve_equations.empty()) { - node.erase_node(conserve_equations); + auto blocks = collect_nodes(node, {ast::AstNodeType::CVODE_BLOCK}); + auto block = std::dynamic_pointer_cast(blocks[0]); + block->get_function_block()->erase_statement(conserve_equations); + block->get_diagonal_jacobian_block()->erase_statement(conserve_equations); } } diff --git a/src/visitors/cvode_visitor.hpp b/src/visitors/cvode_visitor.hpp index f324b2b075..716600be20 100644 --- a/src/visitors/cvode_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -57,7 +57,7 @@ class CvodeVisitor: public AstVisitor { BlockIndex block_index = BlockIndex::FUNCTION; /// list of conserve equations encountered - std::unordered_set conserve_equations; + std::unordered_set conserve_equations; public: void visit_derivative_block(ast::DerivativeBlock& node) override; From 313330b29c398605147b8b3b94f61fec29830a95 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 10:50:15 +0200 Subject: [PATCH 32/74] Add test for CONSERVE statement --- test/unit/visitor/cvode.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/unit/visitor/cvode.cpp b/test/unit/visitor/cvode.cpp index 3a57d242d4..a9f11a4219 100644 --- a/test/unit/visitor/cvode.cpp +++ b/test/unit/visitor/cvode.cpp @@ -37,6 +37,7 @@ TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { STATE {x z} DERIVATIVE equation { + CONSERVE x + z = 5 x' = -x + z * z z' = z * x } @@ -49,6 +50,10 @@ TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); REQUIRE(primed_vars.empty()); } + THEN("No CONSERVE statements are present in the CVODE block") { + auto conserved_stmts = collect_nodes(*block[0], {ast::AstNodeType::CONSERVE}); + REQUIRE(conserved_stmts.empty()); + } } } } From 03b40e842a190249034e148051a1c41f07486264 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 11:45:56 +0200 Subject: [PATCH 33/74] Fix variable naming --- test/unit/visitor/cvode.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/unit/visitor/cvode.cpp b/test/unit/visitor/cvode.cpp index a9f11a4219..8ed1f595f6 100644 --- a/test/unit/visitor/cvode.cpp +++ b/test/unit/visitor/cvode.cpp @@ -44,14 +44,14 @@ TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { )"; auto ast = run_cvode_visitor(nmodl_text); THEN("CVODE block is added") { - auto block = collect_nodes(*ast, {ast::AstNodeType::CVODE_BLOCK}); - REQUIRE(!block.empty()); + auto blocks = collect_nodes(*ast, {ast::AstNodeType::CVODE_BLOCK}); + REQUIRE(blocks.size() == 1); THEN("No primed variables exist in the CVODE block") { - auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); + auto primed_vars = collect_nodes(*blocks[0], {ast::AstNodeType::PRIME_NAME}); REQUIRE(primed_vars.empty()); } THEN("No CONSERVE statements are present in the CVODE block") { - auto conserved_stmts = collect_nodes(*block[0], {ast::AstNodeType::CONSERVE}); + auto conserved_stmts = collect_nodes(*blocks[0], {ast::AstNodeType::CONSERVE}); REQUIRE(conserved_stmts.empty()); } } From 8cd6a39b31c1fcdcb4fdd9a7496f541c54ff44c5 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 12:51:19 +0200 Subject: [PATCH 34/74] Put back the right one --- src/main.cpp | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/main.cpp b/src/main.cpp index 2901c5f650..97045dad7d 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -430,19 +430,6 @@ int run_nmodl(int argc, const char* argv[]) { ast_to_nmodl(*ast, filepath("unroll")); SymtabVisitor(update_symtab).visit_program(*ast); } - if (neuron_code) { - nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() - .api() - .initialize_interpreter(); - logger->info("Running CVODE visitor"); - CvodeVisitor().visit_program(*ast); - SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("cvode")); - nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() - .api() - .finalize_interpreter(); - } - /// note that we can not symtab visitor in update mode as we /// replace kinetic block with derivative block of same name @@ -508,11 +495,18 @@ int run_nmodl(int argc, const char* argv[]) { const bool sympy_sparse = solver_exists(*ast, "sparse"); if (sympy_conductance || sympy_analytic || sympy_sparse || sympy_derivimplicit || - sympy_linear) { + sympy_linear || neuron_code) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() .api() .initialize_interpreter(); + if (neuron_code) { + logger->info("Running CVODE visitor"); + CvodeVisitor().visit_program(*ast); + SymtabVisitor(update_symtab).visit_program(*ast); + ast_to_nmodl(*ast, filepath("cvode")); + } + if (sympy_conductance) { logger->info("Running sympy conductance visitor"); SympyConductanceVisitor().visit_program(*ast); From fc4fe9d994ca7fb9b1ea4e60b789a02f96326606 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 12:54:20 +0200 Subject: [PATCH 35/74] I don't need this --- test/unit/ode/test_ode.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 6af88f763a..6eae706998 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -9,16 +9,8 @@ import sympy as sp -def make_symbols(iterable): - return [make_symbol(arg) for arg in iterable] - - def _equivalent( - lhs, - rhs, - vars=make_symbols( - ["a", "b", "c", "d", "e", "f", "v", "w", "x", "y", "z", "t", "dt"] - ), + lhs, rhs, vars=["a", "b", "c", "d", "e", "f", "v", "w", "x", "y", "z", "t", "dt"] ): """Helper function to test equivalence of analytic expressions Analytic expressions can often be written in many different, From bf3b0e68e8096603c42c1ac366d05670e1b9ffc2 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 13:31:06 +0200 Subject: [PATCH 36/74] Fixups --- src/codegen/codegen_neuron_cpp_visitor.cpp | 101 ++++++++++++--------- src/codegen/codegen_neuron_cpp_visitor.hpp | 18 ++++ 2 files changed, 76 insertions(+), 43 deletions(-) diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index e9904607de..5ae082f699 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -859,7 +859,6 @@ void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_in if (info.emit_cvode) { printer->add_line("static Symbol** _atollist;"); printer->push_block("static HocStateTolerance _hoc_state_tol[] ="); - // TODO: add stuff that iterates over `rangestate` in NOCMODL printer->add_line("{0, 0}"); printer->pop_block(";"); } @@ -1290,7 +1289,6 @@ void CodegenNeuronCppVisitor::print_mechanism_register() { if (info.emit_cvode) { mech_register_args.push_back( - // TODO: figure out why the first parameter should be called "_cvode_ieq" fmt::format("_nrn_mechanism_field{{\"_cvode_ieq\", \"cvodeieq\"}} /* {} */", codegen_int_variables_size)); } @@ -1367,7 +1365,8 @@ void CodegenNeuronCppVisitor::print_mechanism_register() { printer->fmt_line("hoc_register_dparam_semantics(mech_type, {}, \"cvodeieq\");", codegen_int_variables_size); printer->fmt_line( - "hoc_register_cvode(mech_type, ode_count_{}, ode_map_{}, ode_spec_{}, ode_matsol_{});", + "hoc_register_cvode(mech_type, ode_count_{}, ode_setup_tolerance_{}, " + "ode_setup_nonstiff_{}, ode_setup_stiff_{});", info.mod_suffix, info.mod_suffix, info.mod_suffix, @@ -2432,6 +2431,29 @@ void CodegenNeuronCppVisitor::print_net_receive_common_code() { printer->add_line("double t = nt->_t;"); } +CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::cvode_setup_parameters() { + return {{"", "const _nrn_model_sorted_token&", "", "_sorted_token"}, + {"", "NrnThread*", "", "nt"}, + {"", "Memb_list*", "", "_ml_arg"}, + {"", "int", "", "_type"}}; +} + +CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::cvode_update_parameters() { + ParamVector args = {{"", "_nrn_mechanism_cache_range&", "", "_lmc"}, + {"", fmt::format("{}_Instance&", info.mod_suffix), "", "inst"}, + {"", fmt::format("{}_NodeData&", info.mod_suffix), "", "node_data"}, + {"", "size_t", "", "id"}, + {"", "Datum*", "", "_ppvar"}, + {"", "Datum*", "", "_thread"}, + {"", "NrnThread*", "", "nt"}}; + + if (info.thread_callback_register) { + auto type_name = fmt::format("{}&", thread_variables_struct()); + args.emplace_back("", type_name, "", "_thread_vars"); + } + return args; +} + void CodegenNeuronCppVisitor::print_cvode_definitions() { if (!info.emit_cvode) { return; @@ -2447,28 +2469,13 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { printer->add_newline(2); - const ParamVector args_setup = {{"", "const _nrn_model_sorted_token&", "", "_sorted_token"}, - {"", "NrnThread*", "", "nt"}, - {"", "Memb_list*", "", "_ml_arg"}, - {"", "int", "", "_type"}}; - - ParamVector args_cvode = {{"", "_nrn_mechanism_cache_range&", "", "_lmc"}, - {"", fmt::format("{}_Instance&", info.mod_suffix), "", "inst"}, - {"", fmt::format("{}_NodeData&", info.mod_suffix), "", "node_data"}, - {"", "size_t", "", "id"}, - {"", "Datum*", "", "_ppvar"}, - {"", "Datum*", "", "_thread"}, - {"", "NrnThread*", "", "nt"}}; + auto update_nonstiff_name = fmt::format("ode_update_nonstiff_{}", info.mod_suffix); - if (info.thread_callback_register) { - auto type_name = fmt::format("{}&", thread_variables_struct()); - args_cvode.emplace_back("", type_name, "", "_thread_vars"); - } - - /* The internal spec function */ - printer->fmt_push_block("static int ode_spec1_{}({})", - info.mod_suffix, - get_parameter_str(args_cvode)); // begin function definition + /* The update function for non-stiff systems */ + printer->fmt_push_block("static int {}({})", + update_nonstiff_name, + get_parameter_str(cvode_update_parameters())); // begin function + // definition printer->add_line("int node_id = node_data.nodeindices[id];"); printer->add_line("auto v = node_data.node_voltages[node_id];"); if (info.cvode_block) { @@ -2481,10 +2488,13 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { printer->add_newline(2); - /* Main spec function */ - printer->push_block(fmt::format("static void ode_spec_{}({})", - info.mod_suffix, - get_parameter_str(args_setup))); // begin function definition + auto setup_nonstiff_name = fmt::format("ode_setup_nonstiff_{}", info.mod_suffix); + + /* The setup function for non-stiff systems */ + printer->push_block( + fmt::format("static void {}({})", + setup_nonstiff_name, + get_parameter_str(cvode_setup_parameters()))); // begin function definition printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); printer->add_line("auto nodecount = _ml_arg->nodecount;"); @@ -2499,16 +2509,16 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { printer->add_line("int node_id = node_data.nodeindices[id];"); printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); printer->add_line("auto v = node_data.node_voltages[node_id];"); - printer->fmt_line("ode_spec1_{}({});", info.mod_suffix, get_arg_str(args_cvode)); + printer->fmt_line("{}({});", update_nonstiff_name, get_arg_str(cvode_update_parameters())); printer->pop_block(); // end for loop printer->pop_block(); // end function definition printer->add_newline(2); - /* map */ + /* The function for setup of tolerance */ printer->push_block( - fmt::format("static void ode_map_{}(Prop* _prop, int equation_index, " + fmt::format("static void ode_setup_tolerance_{}(Prop* _prop, int equation_index, " "neuron::container::data_handle* _pv, " "neuron::container::data_handle* _pvdot, double* _atol, int _type)", info.mod_suffix)); // begin function definition @@ -2524,14 +2534,15 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { printer->add_newline(2); - /* matsol instance (?) */ - printer->push_block(fmt::format("static void ode_matsol_instance1_{}({})", - info.mod_suffix, - get_parameter_str(args_cvode))); // begin function definition + auto update_stiff_name = fmt::format("ode_update_stiff_{}", info.mod_suffix); + + /* The update function for stiff systems */ + printer->push_block( + fmt::format("static void {}({})", + update_stiff_name, + get_parameter_str(cvode_update_parameters()))); // begin function definition if (info.cvode_block) { - // for mathematical details, see eq. (4.8) in: - // https://sundials.readthedocs.io/en/latest/cvodes/Mathematics_link.html auto block = info.cvode_block->get_diagonal_jacobian_block(); print_statement_block(*block, false, false); } @@ -2540,26 +2551,30 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { printer->add_newline(2); - /* matsol */ - printer->push_block(fmt::format("static void ode_matsol_{}({})", - info.mod_suffix, - get_parameter_str(args_setup))); // begin function definition + auto setup_stiff_name = fmt::format("ode_setup_stiff_{}", info.mod_suffix); + + /* The setup function for stiff systems */ + printer->push_block( + fmt::format("static void {}({})", + setup_stiff_name, + get_parameter_str(cvode_setup_parameters()))); // begin function definition printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); printer->add_line("auto nodecount = _ml_arg->nodecount;"); printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml_arg);", info.mod_suffix); printer->add_line("auto* _thread = _ml_arg->_thread;"); + if (!codegen_thread_variables.empty()) { printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());", thread_variables_struct(), info.thread_var_thread_id); } + printer->push_block("for (int id = 0; id < nodecount; id++)"); // begin for loop printer->add_line("int node_id = node_data.nodeindices[id];"); printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); printer->add_line("auto v = node_data.node_voltages[node_id];"); - // TODO check if this can be replaced by ode_matsol1 - printer->fmt_line("ode_matsol_instance1_{}({});", info.mod_suffix, get_arg_str(args_cvode)); + printer->fmt_line("{}({});", update_stiff_name, get_arg_str(cvode_update_parameters())); printer->pop_block(); // end for loop printer->pop_block(); // end function definition diff --git a/src/codegen/codegen_neuron_cpp_visitor.hpp b/src/codegen/codegen_neuron_cpp_visitor.hpp index 5078fd92cf..3cfb45b466 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.hpp +++ b/src/codegen/codegen_neuron_cpp_visitor.hpp @@ -701,6 +701,24 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { void print_ion_variable() override; + /** + * Get the parameters for functions that setup (initialize) CVODE + * + */ + ParamVector cvode_setup_parameters(); + + + /** + * Get the parameters for functions that update state at given timestep in CVODE + * + */ + ParamVector cvode_update_parameters(); + + + /** + * Print all callbacks for CVODE + * + */ void print_cvode_definitions(); From 836ec7498d91675785ff8f29d48c1e3f8356451d Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 13:32:39 +0200 Subject: [PATCH 37/74] Update docstring --- docs/contents/cvode.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/contents/cvode.rst b/docs/contents/cvode.rst index 87ecdb309c..d70378aaf6 100644 --- a/docs/contents/cvode.rst +++ b/docs/contents/cvode.rst @@ -20,9 +20,10 @@ is satisfied: In NMODL, all ``KINETIC`` blocks are internally first converted to ``DERIVATIVE`` blocks. The ``DERIVATIVE`` block is then converted to a ``CVODE`` block, which contains two parts; the first part contains the update -step for linear systems, while the second part contains the update step for -non-linear systems (see `CVODES documentation`_, eqs. (4.8) and (4.9)). Given -a ``DERIVATIVE`` block of the form: +step for non-stiff systems (functional iteration), while the second part +contains the update step for stiff systems (additional step using the +Jacobian). For more information, see `CVODES documentation`_, eqs. (4.8) and +(4.9)). Given a ``DERIVATIVE`` block of the form: .. _CVODES documentation: https://sundials.readthedocs.io/en/latest/cvodes/Mathematics_link.html From 9f6b751b5c70d90742a406284dceb2d2ab1738fd Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 13:46:34 +0200 Subject: [PATCH 38/74] Fix typo --- docs/contents/cvode.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/contents/cvode.rst b/docs/contents/cvode.rst index d70378aaf6..367e160d48 100644 --- a/docs/contents/cvode.rst +++ b/docs/contents/cvode.rst @@ -23,7 +23,7 @@ In NMODL, all ``KINETIC`` blocks are internally first converted to step for non-stiff systems (functional iteration), while the second part contains the update step for stiff systems (additional step using the Jacobian). For more information, see `CVODES documentation`_, eqs. (4.8) and -(4.9)). Given a ``DERIVATIVE`` block of the form: +(4.9). Given a ``DERIVATIVE`` block of the form: .. _CVODES documentation: https://sundials.readthedocs.io/en/latest/cvodes/Mathematics_link.html From 1348ab946456a3341ca9a8253a71717170214045 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 13:58:54 +0200 Subject: [PATCH 39/74] Update docstring --- src/pybind/wrapper.cpp | 2 -- src/pybind/wrapper.hpp | 6 ++++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pybind/wrapper.cpp b/src/pybind/wrapper.cpp index ae9d414976..65fe2b6d6b 100644 --- a/src/pybind/wrapper.cpp +++ b/src/pybind/wrapper.cpp @@ -186,8 +186,6 @@ except Exception as e: return {std::move(solution), std::move(exception_message)}; } -/// \brief A blunt instrument that differentiates expression w.r.t. variable -/// \return The tuple (solution, exception) std::tuple call_diff2c( const std::string& expression, const std::string& variable, diff --git a/src/pybind/wrapper.hpp b/src/pybind/wrapper.hpp index b4ec0a2dff..694b0143d7 100644 --- a/src/pybind/wrapper.hpp +++ b/src/pybind/wrapper.hpp @@ -45,6 +45,12 @@ std::tuple call_analytic_diff( const std::vector& expressions, const std::set& used_names_in_block); + +/// \brief Differentiates an expression with respect to a variable +/// \param expression The expression we want to differentiate +/// \param variable The name of the independent variable we are differentiating against +/// \param index_vars A map of array (indexable) variables (and their associated indices) that +/// appear in \ref expression \return The tuple (solution, exception) std::tuple call_diff2c( const std::string& expression, const std::string& variable, From 2ae9db2f105a6fca0a4bb2583d05debdf89ad163 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 16:19:04 +0200 Subject: [PATCH 40/74] Remove unused code --- test/usecases/cvode/derivative.mod | 11 +---------- test/usecases/cvode/test_cvode.py | 5 ----- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/test/usecases/cvode/derivative.mod b/test/usecases/cvode/derivative.mod index 70aa20352c..d3715352ff 100644 --- a/test/usecases/cvode/derivative.mod +++ b/test/usecases/cvode/derivative.mod @@ -12,20 +12,14 @@ PARAMETER { v5 = 0.3 r = 3 k = 0.2 - nmodl_alpha = 1.2 - nmodl_beta = 4.5 - nmodl_gamma = 2.4 - nmodl_delta = 7.5 } -STATE {var1 var2 var3 var4 var5} +STATE {var1 var2 var3} INITIAL { var1 = v1 var2 = v2 var3 = v3 - var4 = v4 - var5 = v5 } BREAKPOINT { @@ -40,7 +34,4 @@ DERIVATIVE equation { var2' = -var2 * a : logistic ODE var3' = r * var3 * (1 - var3 / k) - : system of 2 ODEs (predator-prey model) - var4' = nmodl_alpha * var4 - nmodl_beta * var4 * var5 - var5' = nmodl_delta * var4 * var5 - nmodl_gamma * var5 } diff --git a/test/usecases/cvode/test_cvode.py b/test/usecases/cvode/test_cvode.py index 521b93f673..c5ae18cc3e 100644 --- a/test/usecases/cvode/test_cvode.py +++ b/test/usecases/cvode/test_cvode.py @@ -50,8 +50,3 @@ def simulate(rtol): if __name__ == "__main__": t, *x = simulate(rtol=1e-5) - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots(nrows=len(x)) - # for a, val in zip(ax, x): - # a.plot(t, val, ls="", marker="x", markersize=0.1) - # plt.show() From ee9c187afd155d3d0193ecfbdf7c1349b5cf7785 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 14 Oct 2024 13:02:34 +0200 Subject: [PATCH 41/74] Add option for diffing IndexedName --- python/nmodl/ode.py | 2 +- src/pybind/wrapper.cpp | 19 +++++++++++---- src/pybind/wrapper.hpp | 3 ++- src/visitors/cvode_visitor.cpp | 43 ++++++++++++++++++++++++---------- 4 files changed, 48 insertions(+), 19 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index cd6b2b27ae..8219169db8 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -619,7 +619,7 @@ def differentiate2c( # every symbol (a.k.a variable) that SymPy # is going to manipulate needs to be declared # explicitly - x = sp.symbols(dependent_var, real=True) + x = make_symbol(dependent_var) vars = set(vars) vars.discard(dependent_var) # declare all other supplied variables diff --git a/src/pybind/wrapper.cpp b/src/pybind/wrapper.cpp index 65fe2b6d6b..8385954ed7 100644 --- a/src/pybind/wrapper.cpp +++ b/src/pybind/wrapper.cpp @@ -10,6 +10,7 @@ #include "codegen/codegen_naming.hpp" #include "pybind/pyembed.hpp" #include +#include #include #include @@ -188,17 +189,26 @@ except Exception as e: std::tuple call_diff2c( const std::string& expression, - const std::string& variable, + const std::pair>& variable, const std::unordered_map& indexed_vars) { std::string statements; // only indexed variables require special treatment for (const auto& [var, prop]: indexed_vars) { statements += fmt::format("_allvars.append(sp.IndexedBase('{}', shape=[1]))\n", var); } - auto locals = py::dict("expression"_a = expression, "variable"_a = variable); - std::string script = fmt::format(R"( + auto [name, property] = variable; + if (property.has_value()) { + name = fmt::format("sp.IndexedBase('{}', shape=[1])", name); + statements += fmt::format("_allvars.append({})", name); + } else { + name = fmt::format("'{}'", name); + } + auto locals = py::dict("expression"_a = expression); + std::string script = + fmt::format(R"( _allvars = [] {} +variable = {} exception_message = "" try: solution = differentiate2c(expression, @@ -210,7 +220,8 @@ except Exception as e: solution = "" exception_message = str(e) )", - statements); + statements, + property.has_value() ? fmt::format("{}[{}]", name, property.value()) : name); py::exec(nmodl::pybind_wrappers::ode_py + script, locals); diff --git a/src/pybind/wrapper.hpp b/src/pybind/wrapper.hpp index 694b0143d7..aad85aef25 100644 --- a/src/pybind/wrapper.hpp +++ b/src/pybind/wrapper.hpp @@ -7,6 +7,7 @@ #pragma once +#include #include #include #include @@ -53,7 +54,7 @@ std::tuple call_analytic_diff( /// appear in \ref expression \return The tuple (solution, exception) std::tuple call_diff2c( const std::string& expression, - const std::string& variable, + const std::pair>& variable, const std::unordered_map& indexed_vars = {}); struct pybind_wrap_api { diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 302b541d5d..f3979474fb 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -24,6 +24,16 @@ static int get_index(const ast::IndexedName& node) { return std::stoi(to_nmodl(node.get_length())); } +static std::pair> parse_independent_var( + std::shared_ptr node) { + auto variable = std::make_pair(node->get_node_name(), std::optional()); + if (node->is_indexed_name()) { + variable.second = std::optional( + get_index(*std::dynamic_pointer_cast(node))); + } + return variable; +} + void CvodeVisitor::visit_conserve(ast::Conserve& node) { if (in_cvode_block) { logger->warn("CvodeVisitor :: statement {} will be ignored in CVODE codegen", @@ -92,25 +102,32 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { auto name = std::dynamic_pointer_cast(lhs)->get_name(); - if (name->is_prime_name()) { - auto varname = "D" + name->get_node_name(); - logger->debug("CvodeVisitor :: replacing {} with {} on LHS of {}", - name->get_node_name(), - varname, - to_nmodl(node)); - node.set_lhs(std::make_shared(new ast::String(varname))); - if (program_symtab->lookup(varname) == nullptr) { - auto symbol = std::make_shared(varname, ModToken()); - symbol->set_original_name(name->get_node_name()); - program_symtab->insert(symbol); + if (name->is_prime_name() || name->is_indexed_name()) { + std::string varname; + if (name->is_prime_name()) { + varname = "D" + name->get_node_name(); + node.set_lhs(std::make_shared(new ast::String(varname))); + if (program_symtab->lookup(varname) == nullptr) { + auto symbol = std::make_shared(varname, ModToken()); + symbol->set_original_name(name->get_node_name()); + program_symtab->insert(symbol); + } + } else { + varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\''); + auto statement = fmt::format("{} = {}", varname, varname); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); } if (block_index == BlockIndex::JACOBIAN) { auto rhs = node.get_rhs(); // map of all indexed symbols (need special treatment in SymPy) auto name_map = get_name_map(*rhs, name->get_node_name()); auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; - auto [jacobian, - exception_message] = diff2c(to_nmodl(*rhs), name->get_node_name(), name_map); + auto [jacobian, exception_message] = + diff2c(to_nmodl(*rhs), parse_independent_var(name), name_map); if (!exception_message.empty()) { logger->warn("CvodeVisitor :: python exception: {}", exception_message); } From 54a480e0da3be17961d7cae4172f775f58e10793 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 14 Oct 2024 16:58:50 +0200 Subject: [PATCH 42/74] Refactor --- src/visitors/cvode_visitor.cpp | 216 +++++++++++++++++++-------------- src/visitors/cvode_visitor.hpp | 35 ------ 2 files changed, 126 insertions(+), 125 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index f3979474fb..8d40764ea2 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -34,16 +34,8 @@ static std::pair> parse_independent_var( return variable; } -void CvodeVisitor::visit_conserve(ast::Conserve& node) { - if (in_cvode_block) { - logger->warn("CvodeVisitor :: statement {} will be ignored in CVODE codegen", - to_nmodl(node)); - conserve_equations.emplace(&node); - } -} - - -static auto get_name_map(const ast::Expression& node, const std::string& name) { +static std::unordered_map get_name_map(const ast::Expression& node, + const std::string& name) { std::unordered_map name_map; // all of the "reserved" symbols auto reserved_symbols = get_external_functions(); @@ -65,105 +57,149 @@ static auto get_name_map(const ast::Expression& node, const std::string& name) { return name_map; } -void CvodeVisitor::visit_derivative_block(ast::DerivativeBlock& node) { - node.visit_children(*this); - derivative_block = std::shared_ptr(node.clone()); -} +static std::string cvode_set_lhs(ast::BinaryExpression& node) { + const auto& lhs = node.get_lhs(); + auto name = std::dynamic_pointer_cast(lhs)->get_name(); -void CvodeVisitor::visit_cvode_block(ast::CvodeBlock& node) { - in_cvode_block = true; - node.visit_children(*this); - in_cvode_block = false; + std::string varname; + if (name->is_prime_name()) { + varname = "D" + name->get_node_name(); + node.set_lhs(std::make_shared(new ast::String(varname))); + } else if (name->is_indexed_name()) { + auto nodes = collect_nodes(*name, {ast::AstNodeType::PRIME_NAME}); + // make sure the LHS isn't just a plain indexed var + if (!nodes.empty()) { + varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\''); + auto statement = fmt::format("{} = {}", varname, varname); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); + } + } + return varname; } -void CvodeVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { - in_differential_equation = true; - node.visit_children(*this); - in_differential_equation = false; -} +class CvodeHelperVisitor: public AstVisitor { + protected: + symtab::SymbolTable* program_symtab = nullptr; + bool in_differential_equation = false; + std::unordered_set conserve_equations; + public: + inline void visit_diff_eq_expression(ast::DiffEqExpression& node) { + in_differential_equation = true; + node.visit_children(*this); + in_differential_equation = false; + } +}; -void CvodeVisitor::visit_statement_block(ast::StatementBlock& node) { - node.visit_children(*this); - if (in_cvode_block) { - ++block_index; +class NonStiffVisitor: public CvodeHelperVisitor { + public: + NonStiffVisitor(symtab::SymbolTable* symtab) { + program_symtab = symtab; } -} + inline void visit_binary_expression(ast::BinaryExpression& node) { + const auto& lhs = node.get_lhs(); -void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { - const auto& lhs = node.get_lhs(); + if (!in_differential_equation || !lhs->is_var_name()) { + return; + } + + auto name = std::dynamic_pointer_cast(lhs)->get_name(); + auto varname = cvode_set_lhs(node); - /// we have to only solve ODEs under original derivative block where lhs is variable - if (!in_cvode_block || !in_differential_equation || !lhs->is_var_name()) { - return; + if (program_symtab->lookup(varname) == nullptr) { + auto symbol = std::make_shared(varname, ModToken()); + symbol->set_original_name(name->get_node_name()); + program_symtab->insert(symbol); + } } +}; - auto name = std::dynamic_pointer_cast(lhs)->get_name(); +class StiffVisitor: public CvodeHelperVisitor { + public: + StiffVisitor(symtab::SymbolTable* symtab) { + program_symtab = symtab; + } - if (name->is_prime_name() || name->is_indexed_name()) { - std::string varname; - if (name->is_prime_name()) { - varname = "D" + name->get_node_name(); - node.set_lhs(std::make_shared(new ast::String(varname))); - if (program_symtab->lookup(varname) == nullptr) { - auto symbol = std::make_shared(varname, ModToken()); - symbol->set_original_name(name->get_node_name()); - program_symtab->insert(symbol); - } - } else { - varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\''); - auto statement = fmt::format("{} = {}", varname, varname); - auto expr_statement = std::dynamic_pointer_cast( - create_statement(statement)); - const auto bin_expr = std::dynamic_pointer_cast( - expr_statement->get_expression()); - node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); + inline void visit_binary_expression(ast::BinaryExpression& node) { + const auto& lhs = node.get_lhs(); + + if (!in_differential_equation || !lhs->is_var_name()) { + return; } - if (block_index == BlockIndex::JACOBIAN) { - auto rhs = node.get_rhs(); - // map of all indexed symbols (need special treatment in SymPy) - auto name_map = get_name_map(*rhs, name->get_node_name()); - auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; - auto [jacobian, exception_message] = - diff2c(to_nmodl(*rhs), parse_independent_var(name), name_map); - if (!exception_message.empty()) { - logger->warn("CvodeVisitor :: python exception: {}", exception_message); - } - // NOTE: LHS can be anything here, the equality is to keep `create_statement` from - // complaining, we discard the LHS later - auto statement = fmt::format("{} = {} / (1 - dt * ({}))", varname, varname, jacobian); - logger->debug("CvodeVisitor :: replacing statement {} with {}", - to_nmodl(node), - statement); - auto expr_statement = std::dynamic_pointer_cast( - create_statement(statement)); - const auto bin_expr = std::dynamic_pointer_cast( - expr_statement->get_expression()); - node.set_rhs(std::shared_ptr(bin_expr->get_rhs()->clone())); + + auto name = std::dynamic_pointer_cast(lhs)->get_name(); + auto varname = cvode_set_lhs(node); + + if (program_symtab->lookup(varname) == nullptr) { + auto symbol = std::make_shared(varname, ModToken()); + symbol->set_original_name(name->get_node_name()); + program_symtab->insert(symbol); + } + + auto rhs = node.get_rhs(); + // map of all indexed symbols (need special treatment in SymPy) + auto name_map = get_name_map(*rhs, name->get_node_name()); + auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; + auto [jacobian, + exception_message] = diff2c(to_nmodl(*rhs), parse_independent_var(name), name_map); + if (!exception_message.empty()) { + logger->warn("CvodeVisitor :: python exception: {}", exception_message); } + // NOTE: LHS can be anything here, the equality is to keep `create_statement` from + // complaining, we discard the LHS later + auto statement = fmt::format("{} = {} / (1 - dt * ({}))", varname, varname, jacobian); + logger->debug("CvodeVisitor :: replacing statement {} with {}", to_nmodl(node), statement); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_rhs(std::shared_ptr(bin_expr->get_rhs()->clone())); } -} +}; + void CvodeVisitor::visit_program(ast::Program& node) { - program_symtab = node.get_symbol_table(); - node.visit_children(*this); - if (derivative_block) { - auto der_node = new ast::CvodeBlock(derivative_block->get_name(), - derivative_block->get_statement_block(), - std::shared_ptr( - derivative_block->get_statement_block()->clone())); - node.emplace_back_node(der_node); - } + auto der_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK}); + if (!der_blocks.empty()) { + auto der_block = std::dynamic_pointer_cast(der_blocks[0]); + + auto non_stiff_block = der_block->get_statement_block()->clone(); + { + auto conserve_equations = collect_nodes(*non_stiff_block, {ast::AstNodeType::CONSERVE}); + if (!conserve_equations.empty()) { + std::unordered_set eqs; + for (const auto& item: conserve_equations) { + eqs.insert(std::dynamic_pointer_cast(item).get()); + } + non_stiff_block->erase_statement(eqs); + } + } + + auto stiff_block = der_block->get_statement_block()->clone(); + { + auto conserve_equations = collect_nodes(*stiff_block, {ast::AstNodeType::CONSERVE}); + if (!conserve_equations.empty()) { + std::unordered_set eqs; + for (const auto& item: conserve_equations) { + eqs.insert(std::dynamic_pointer_cast(item).get()); + } + stiff_block->erase_statement(eqs); + } + } + - // re-visit the AST since we now inserted the CVODE block - node.visit_children(*this); - if (!conserve_equations.empty()) { - auto blocks = collect_nodes(node, {ast::AstNodeType::CVODE_BLOCK}); - auto block = std::dynamic_pointer_cast(blocks[0]); - block->get_function_block()->erase_statement(conserve_equations); - block->get_diagonal_jacobian_block()->erase_statement(conserve_equations); + NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block); + StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block); + node.emplace_back_node( + new ast::CvodeBlock(der_block->get_name(), + std::shared_ptr(non_stiff_block), + std::shared_ptr(stiff_block))); } } diff --git a/src/visitors/cvode_visitor.hpp b/src/visitors/cvode_visitor.hpp index 716600be20..c8d37f6a61 100644 --- a/src/visitors/cvode_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -20,16 +20,6 @@ namespace nmodl { namespace visitor { -enum class BlockIndex { FUNCTION = 0, JACOBIAN = 1 }; - -inline BlockIndex& operator++(BlockIndex& index) { - if (index == BlockIndex::FUNCTION) { - index = BlockIndex::JACOBIAN; - } else { - index = BlockIndex::FUNCTION; - } - return index; -} /** * \addtogroup visitor_classes * \{ @@ -40,33 +30,8 @@ inline BlockIndex& operator++(BlockIndex& index) { * \brief Visitor used for generating the necessary AST nodes for CVODE */ class CvodeVisitor: public AstVisitor { - private: - /// The copy of the derivative block of a given mod file - std::shared_ptr derivative_block = nullptr; - - /// true while visiting differential equation - bool in_differential_equation = false; - - /// global symbol table - symtab::SymbolTable* program_symtab = nullptr; - - /// true while we are visiting a CVODE block - bool in_cvode_block = false; - - /// index of the block to modify - BlockIndex block_index = BlockIndex::FUNCTION; - - /// list of conserve equations encountered - std::unordered_set conserve_equations; - public: - void visit_derivative_block(ast::DerivativeBlock& node) override; void visit_program(ast::Program& node) override; - void visit_cvode_block(ast::CvodeBlock& node) override; - void visit_diff_eq_expression(ast::DiffEqExpression& node) override; - void visit_binary_expression(ast::BinaryExpression& node) override; - void visit_statement_block(ast::StatementBlock& node) override; - void visit_conserve(ast::Conserve& node) override; }; /** \} */ // end of visitor_classes From cefc1596beff987cb7a7bd9313773963983c63c5 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 14 Oct 2024 21:29:26 +0200 Subject: [PATCH 43/74] Enable sympy if NEURON codegen --- src/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main.cpp b/src/main.cpp index a32640276d..89c2813f23 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -523,7 +523,7 @@ int run_nmodl(int argc, const char* argv[]) { } - if (sympy_conductance || sympy_analytic) { + if (sympy_conductance || sympy_analytic || neuron_code) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() .api() .initialize_interpreter(); From a3e1c6c23a177abd92d34ac731d8abb03d42adca Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 14 Oct 2024 21:49:57 +0200 Subject: [PATCH 44/74] Mark constructors as explicit --- src/visitors/cvode_visitor.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 8d40764ea2..d64ba7d214 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -98,7 +98,7 @@ class CvodeHelperVisitor: public AstVisitor { class NonStiffVisitor: public CvodeHelperVisitor { public: - NonStiffVisitor(symtab::SymbolTable* symtab) { + explicit NonStiffVisitor(symtab::SymbolTable* symtab) { program_symtab = symtab; } @@ -122,7 +122,7 @@ class NonStiffVisitor: public CvodeHelperVisitor { class StiffVisitor: public CvodeHelperVisitor { public: - StiffVisitor(symtab::SymbolTable* symtab) { + explicit StiffVisitor(symtab::SymbolTable* symtab) { program_symtab = symtab; } From 659b018f7f39384851ca934d4c2b9c7840316517 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 15 Oct 2024 14:16:53 +0200 Subject: [PATCH 45/74] Update tests --- test/unit/visitor/cvode.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/unit/visitor/cvode.cpp b/test/unit/visitor/cvode.cpp index 8ed1f595f6..e42df2ec45 100644 --- a/test/unit/visitor/cvode.cpp +++ b/test/unit/visitor/cvode.cpp @@ -34,12 +34,14 @@ TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { SUFFIX example } - STATE {x z} + STATE {X Y[2] Z} DERIVATIVE equation { - CONSERVE x + z = 5 - x' = -x + z * z - z' = z * x + CONSERVE X + Z = 5 + X' = -X + Z * Z + Z' = Z * X + Y'[1] = -Y[0] + Y'[0] = -Y[1] } )"; auto ast = run_cvode_visitor(nmodl_text); From bd376c3ab5bcc431d909a8a9f22562658c87db0a Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 15 Oct 2024 15:58:50 +0200 Subject: [PATCH 46/74] Remove code duplication --- src/visitors/cvode_visitor.cpp | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index d64ba7d214..e93b2b3fd0 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -24,6 +24,17 @@ static int get_index(const ast::IndexedName& node) { return std::stoi(to_nmodl(node.get_length())); } +static void remove_conserve_statements(ast::StatementBlock& node) { + auto conserve_equations = collect_nodes(node, {ast::AstNodeType::CONSERVE}); + if (!conserve_equations.empty()) { + std::unordered_set eqs; + for (const auto& item: conserve_equations) { + eqs.insert(std::dynamic_pointer_cast(item).get()); + } + node.erase_statement(eqs); + } +} + static std::pair> parse_independent_var( std::shared_ptr node) { auto variable = std::make_pair(node->get_node_name(), std::optional()); @@ -170,29 +181,10 @@ void CvodeVisitor::visit_program(ast::Program& node) { auto der_block = std::dynamic_pointer_cast(der_blocks[0]); auto non_stiff_block = der_block->get_statement_block()->clone(); - { - auto conserve_equations = collect_nodes(*non_stiff_block, {ast::AstNodeType::CONSERVE}); - if (!conserve_equations.empty()) { - std::unordered_set eqs; - for (const auto& item: conserve_equations) { - eqs.insert(std::dynamic_pointer_cast(item).get()); - } - non_stiff_block->erase_statement(eqs); - } - } + remove_conserve_statements(*non_stiff_block); auto stiff_block = der_block->get_statement_block()->clone(); - { - auto conserve_equations = collect_nodes(*stiff_block, {ast::AstNodeType::CONSERVE}); - if (!conserve_equations.empty()) { - std::unordered_set eqs; - for (const auto& item: conserve_equations) { - eqs.insert(std::dynamic_pointer_cast(item).get()); - } - stiff_block->erase_statement(eqs); - } - } - + remove_conserve_statements(*stiff_block); NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block); StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block); From a60577b9957944567240b91b70f11cd8af5387fd Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 15 Oct 2024 16:27:35 +0200 Subject: [PATCH 47/74] Remove unused class field --- src/visitors/cvode_visitor.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index e93b2b3fd0..c4c7e08248 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -98,7 +98,6 @@ class CvodeHelperVisitor: public AstVisitor { protected: symtab::SymbolTable* program_symtab = nullptr; bool in_differential_equation = false; - std::unordered_set conserve_equations; public: inline void visit_diff_eq_expression(ast::DiffEqExpression& node) { in_differential_equation = true; From 8bc7d18e00344cdaecbae0936fca3b6edd7129d8 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 16 Oct 2024 15:29:37 +0200 Subject: [PATCH 48/74] Only enable sympy if DERIVATIVE block exists --- src/main.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main.cpp b/src/main.cpp index 89c2813f23..b63b7b0c70 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -517,13 +517,15 @@ int run_nmodl(int argc, const char* argv[]) { enable_sympy(solver_exists(*ast, "derivimplicit"), "'SOLVE ... METHOD derivimplicit'"); enable_sympy(node_exists(*ast, ast::AstNodeType::LINEAR_BLOCK), "'LINEAR' block"); + enable_sympy(node_exists(*ast, ast::AstNodeType::DERIVATIVE_BLOCK), + "'DERIVATIVE' block"); enable_sympy(node_exists(*ast, ast::AstNodeType::NON_LINEAR_BLOCK), "'NONLINEAR' block"); enable_sympy(solver_exists(*ast, "sparse"), "'SOLVE ... METHOD sparse'"); } - if (sympy_conductance || sympy_analytic || neuron_code) { + if (sympy_conductance || sympy_analytic) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() .api() .initialize_interpreter(); From 40cb10bd7893c2f01da4e74393f4ffbfdd27189c Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 16 Oct 2024 15:46:20 +0200 Subject: [PATCH 49/74] Rename CVODE subblocks with more apt names --- src/language/codegen.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index 67c9efe214..73f113c166 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -96,11 +96,11 @@ type: Name node_name: true suffix: {value: " "} - - function_block: - brief: "Block with statements of the form Dvar = f(var)" + - nonstiff_block: + brief: "Block with statements of the form Dvar = f(var), used for updating non-stiff systems" type: StatementBlock - - diagonal_jacobian_block: - brief: "Block with statements of the form Dvar = Dvar / (1 - dt * J(f))" + - stiff_block: + brief: "Block with statements of the form Dvar = Dvar / (1 - dt * J(f)), used for updating stiff systems" type: StatementBlock brief: "Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks" - LongitudinalDiffusionBlock: From b29c277e2298737674b58bdd898f0e88abe9a821 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 16 Oct 2024 15:48:54 +0200 Subject: [PATCH 50/74] Rename function calls --- src/codegen/codegen_neuron_cpp_visitor.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index 497e2960da..a6fcfb2774 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -2559,7 +2559,7 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { printer->add_line("int node_id = node_data.nodeindices[id];"); printer->add_line("auto v = node_data.node_voltages[node_id];"); if (info.cvode_block) { - auto block = info.cvode_block->get_function_block(); + auto block = info.cvode_block->get_stiff_block(); print_statement_block(*block, false, false); } @@ -2623,7 +2623,7 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { get_parameter_str(cvode_update_parameters()))); // begin function definition if (info.cvode_block) { - auto block = info.cvode_block->get_diagonal_jacobian_block(); + auto block = info.cvode_block->get_non_stiff_block(); print_statement_block(*block, false, false); } From 9b58feb2ff9b47ea532f21e7c204e6f9f1d771c7 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 16 Oct 2024 15:49:31 +0200 Subject: [PATCH 51/74] `nonstiff` -> `non_stiff` --- src/language/codegen.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index 73f113c166..4659965be9 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -96,7 +96,7 @@ type: Name node_name: true suffix: {value: " "} - - nonstiff_block: + - non_stiff_block: brief: "Block with statements of the form Dvar = f(var), used for updating non-stiff systems" type: StatementBlock - stiff_block: From 052a8b7f2b6448728e7615dacfa5609f5f80a665 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 11:13:47 +0200 Subject: [PATCH 52/74] Fix typo in codegen --- src/codegen/codegen_neuron_cpp_visitor.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index a6fcfb2774..f409e6de19 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -2559,7 +2559,7 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { printer->add_line("int node_id = node_data.nodeindices[id];"); printer->add_line("auto v = node_data.node_voltages[node_id];"); if (info.cvode_block) { - auto block = info.cvode_block->get_stiff_block(); + auto block = info.cvode_block->get_non_stiff_block(); print_statement_block(*block, false, false); } @@ -2623,7 +2623,7 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { get_parameter_str(cvode_update_parameters()))); // begin function definition if (info.cvode_block) { - auto block = info.cvode_block->get_non_stiff_block(); + auto block = info.cvode_block->get_stiff_block(); print_statement_block(*block, false, false); } From 6bee9e035380a65497468d75a31251f224b43891 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 14:15:03 +0200 Subject: [PATCH 53/74] Fix issues with wrong # of ODEs --- src/codegen/codegen_helper_visitor.cpp | 2 +- src/language/codegen.yaml | 5 +++++ src/visitors/cvode_visitor.cpp | 10 ++++++---- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/codegen/codegen_helper_visitor.cpp b/src/codegen/codegen_helper_visitor.cpp index 6915ff11c4..da0c92f40c 100644 --- a/src/codegen/codegen_helper_visitor.cpp +++ b/src/codegen/codegen_helper_visitor.cpp @@ -632,6 +632,7 @@ void CodegenHelperVisitor::visit_nrn_state_block(const ast::NrnStateBlock& node) void CodegenHelperVisitor::visit_cvode_block(const ast::CvodeBlock& node) { info.cvode_block = &node; + info.num_equations = node.get_n_odes()->get_value(); node.visit_children(*this); } @@ -745,7 +746,6 @@ void CodegenHelperVisitor::visit_statement_block(const ast::StatementBlock& node return sym->get_name() == symbol->get_name(); }) == info.prime_variables_by_order.end()) { info.prime_variables_by_order.push_back(symbol); - info.num_equations++; } } } diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index 4659965be9..c7ce97c5b8 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -96,6 +96,11 @@ type: Name node_name: true suffix: {value: " "} + - n_odes: + brief: "number of ODEs to solve" + type: Integer + prefix: {value: "["} + suffix: {value: "]"} - non_stiff_block: brief: "Block with statements of the form Dvar = f(var), used for updating non-stiff systems" type: StatementBlock diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index c4c7e08248..54dc3ca721 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -187,10 +187,12 @@ void CvodeVisitor::visit_program(ast::Program& node) { NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block); StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block); - node.emplace_back_node( - new ast::CvodeBlock(der_block->get_name(), - std::shared_ptr(non_stiff_block), - std::shared_ptr(stiff_block))); + auto prime_vars = collect_nodes(*der_block, {ast::AstNodeType::PRIME_NAME}); + node.emplace_back_node(new ast::CvodeBlock( + der_block->get_name(), + std::shared_ptr(new ast::Integer(prime_vars.size(), nullptr)), + std::shared_ptr(non_stiff_block), + std::shared_ptr(stiff_block))); } } From af1fbf08efd37821a25c5e5ad6d60e6a2dba9af6 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 14:26:51 +0200 Subject: [PATCH 54/74] Revert back part of master for compatibility --- src/codegen/codegen_helper_visitor.cpp | 2 +- src/codegen/codegen_neuron_cpp_visitor.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/codegen/codegen_helper_visitor.cpp b/src/codegen/codegen_helper_visitor.cpp index da0c92f40c..6915ff11c4 100644 --- a/src/codegen/codegen_helper_visitor.cpp +++ b/src/codegen/codegen_helper_visitor.cpp @@ -632,7 +632,6 @@ void CodegenHelperVisitor::visit_nrn_state_block(const ast::NrnStateBlock& node) void CodegenHelperVisitor::visit_cvode_block(const ast::CvodeBlock& node) { info.cvode_block = &node; - info.num_equations = node.get_n_odes()->get_value(); node.visit_children(*this); } @@ -746,6 +745,7 @@ void CodegenHelperVisitor::visit_statement_block(const ast::StatementBlock& node return sym->get_name() == symbol->get_name(); }) == info.prime_variables_by_order.end()) { info.prime_variables_by_order.push_back(symbol); + info.num_equations++; } } } diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index f409e6de19..4cb51f9020 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -2544,7 +2544,7 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { /* return # of ODEs to solve */ printer->push_block( fmt::format("static constexpr int ode_count_{}(int _type)", info.mod_suffix)); - printer->fmt_line("return {};", info.num_equations); + printer->fmt_line("return {};", info.cvode_block->get_n_odes()->get_value()); printer->pop_block(); printer->add_newline(2); From a6ea5abaa979d70b0f3b850e1829d1754e027347 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 14:28:44 +0200 Subject: [PATCH 55/74] Get # of ODEs to solve --- src/language/codegen.yaml | 5 +++++ src/visitors/cvode_visitor.cpp | 10 ++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index 4659965be9..c7ce97c5b8 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -96,6 +96,11 @@ type: Name node_name: true suffix: {value: " "} + - n_odes: + brief: "number of ODEs to solve" + type: Integer + prefix: {value: "["} + suffix: {value: "]"} - non_stiff_block: brief: "Block with statements of the form Dvar = f(var), used for updating non-stiff systems" type: StatementBlock diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index c4c7e08248..54dc3ca721 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -187,10 +187,12 @@ void CvodeVisitor::visit_program(ast::Program& node) { NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block); StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block); - node.emplace_back_node( - new ast::CvodeBlock(der_block->get_name(), - std::shared_ptr(non_stiff_block), - std::shared_ptr(stiff_block))); + auto prime_vars = collect_nodes(*der_block, {ast::AstNodeType::PRIME_NAME}); + node.emplace_back_node(new ast::CvodeBlock( + der_block->get_name(), + std::shared_ptr(new ast::Integer(prime_vars.size(), nullptr)), + std::shared_ptr(non_stiff_block), + std::shared_ptr(stiff_block))); } } From 494dbe766112424c4d5e16d0d62ad2125396ac26 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 14:29:39 +0200 Subject: [PATCH 56/74] Reorder tests --- test/usecases/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/usecases/CMakeLists.txt b/test/usecases/CMakeLists.txt index c1c0688c84..8efd4372e9 100644 --- a/test/usecases/CMakeLists.txt +++ b/test/usecases/CMakeLists.txt @@ -3,6 +3,7 @@ set(NMODL_USECASE_DIRS builtin_functions constant constructor + cvode electrode_current external function @@ -30,8 +31,7 @@ set(NMODL_USECASE_DIRS steady_state suffix table - useion - cvode) + useion) foreach(usecase ${NMODL_USECASE_DIRS}) add_test(NAME usecase_${usecase} From 5d94f7a88499d3f6c0fa35cfba407ec98314f6ee Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 16:02:11 +0200 Subject: [PATCH 57/74] der_block(s) -> derivative_block(s) --- src/visitors/cvode_visitor.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 54dc3ca721..41845ead33 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -175,21 +175,22 @@ class StiffVisitor: public CvodeHelperVisitor { void CvodeVisitor::visit_program(ast::Program& node) { - auto der_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK}); - if (!der_blocks.empty()) { - auto der_block = std::dynamic_pointer_cast(der_blocks[0]); + auto derivative_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK}); + if (!derivative_blocks.empty()) { + auto derivative_block = std::dynamic_pointer_cast( + derivative_blocks[0]); - auto non_stiff_block = der_block->get_statement_block()->clone(); + auto non_stiff_block = derivative_block->get_statement_block()->clone(); remove_conserve_statements(*non_stiff_block); - auto stiff_block = der_block->get_statement_block()->clone(); + auto stiff_block = derivative_block->get_statement_block()->clone(); remove_conserve_statements(*stiff_block); NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block); StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block); - auto prime_vars = collect_nodes(*der_block, {ast::AstNodeType::PRIME_NAME}); + auto prime_vars = collect_nodes(*derivative_block, {ast::AstNodeType::PRIME_NAME}); node.emplace_back_node(new ast::CvodeBlock( - der_block->get_name(), + derivative_block->get_name(), std::shared_ptr(new ast::Integer(prime_vars.size(), nullptr)), std::shared_ptr(non_stiff_block), std::shared_ptr(stiff_block))); From bf9db7acef87ab100583dcd2de83ab3c666469ca Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 16:30:25 +0200 Subject: [PATCH 58/74] get_name_map -> get_indexed_variables Also use a set since we don't care about the actual index for the RHS --- src/pybind/wrapper.cpp | 4 ++-- src/pybind/wrapper.hpp | 8 ++++---- src/visitors/cvode_visitor.cpp | 33 ++++++++++++++++----------------- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/src/pybind/wrapper.cpp b/src/pybind/wrapper.cpp index 8385954ed7..d59b579d97 100644 --- a/src/pybind/wrapper.cpp +++ b/src/pybind/wrapper.cpp @@ -190,10 +190,10 @@ except Exception as e: std::tuple call_diff2c( const std::string& expression, const std::pair>& variable, - const std::unordered_map& indexed_vars) { + const std::unordered_set& indexed_vars) { std::string statements; // only indexed variables require special treatment - for (const auto& [var, prop]: indexed_vars) { + for (const auto& var: indexed_vars) { statements += fmt::format("_allvars.append(sp.IndexedBase('{}', shape=[1]))\n", var); } auto [name, property] = variable; diff --git a/src/pybind/wrapper.hpp b/src/pybind/wrapper.hpp index aad85aef25..e93cca51f5 100644 --- a/src/pybind/wrapper.hpp +++ b/src/pybind/wrapper.hpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include namespace nmodl { @@ -50,12 +50,12 @@ std::tuple call_analytic_diff( /// \brief Differentiates an expression with respect to a variable /// \param expression The expression we want to differentiate /// \param variable The name of the independent variable we are differentiating against -/// \param index_vars A map of array (indexable) variables (and their associated indices) that -/// appear in \ref expression \return The tuple (solution, exception) +/// \param index_vars A set of array (indexable) variables that appear in \ref expression +/// \return The tuple (solution, exception) std::tuple call_diff2c( const std::string& expression, const std::pair>& variable, - const std::unordered_map& indexed_vars = {}); + const std::unordered_set& indexed_vars = {}); struct pybind_wrap_api { decltype(&initialize_interpreter_func) initialize_interpreter; diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 41845ead33..04eea17093 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -45,27 +45,26 @@ static std::pair> parse_independent_var( return variable; } -static std::unordered_map get_name_map(const ast::Expression& node, - const std::string& name) { - std::unordered_map name_map; - // all of the "reserved" symbols +/// set of all indexed variables not equal to ``name`` +static std::unordered_set get_indexed_variables(const ast::Expression& node, + const std::string& name) { + std::unordered_set indexed_variables; + // all of the "reserved" vars auto reserved_symbols = get_external_functions(); // all indexed vars auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME}); for (const auto& var: indexed_vars) { - if (!name_map.count(var->get_node_name()) && var->get_node_name() != name && - std::none_of(reserved_symbols.begin(), reserved_symbols.end(), [&var](const auto item) { - return var->get_node_name() == item; - })) { - logger->debug( - "CvodeVisitor :: adding INDEXED_VARIABLE {} to " - "node_map", - var->get_node_name()); - name_map[var->get_node_name()] = get_index( - *std::dynamic_pointer_cast(var)); + const auto& varname = var->get_node_name(); + // skip if it's a reserved var + auto varname_not_reserved = + std::none_of(reserved_symbols.begin(), + reserved_symbols.end(), + [&varname](const auto item) { return varname == item; }); + if (indexed_variables.count(varname) == 0 && varname != name && varname_not_reserved) { + indexed_variables.insert(varname); } } - return name_map; + return indexed_variables; } static std::string cvode_set_lhs(ast::BinaryExpression& node) { @@ -153,8 +152,8 @@ class StiffVisitor: public CvodeHelperVisitor { } auto rhs = node.get_rhs(); - // map of all indexed symbols (need special treatment in SymPy) - auto name_map = get_name_map(*rhs, name->get_node_name()); + // all indexed variables (need special treatment in SymPy) + auto name_map = get_indexed_variables(*rhs, name->get_node_name()); auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; auto [jacobian, exception_message] = diff2c(to_nmodl(*rhs), parse_independent_var(name), name_map); From fd0c619c580e1fcf2e38d3ec836b64bf756eeedb Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 23:03:20 +0200 Subject: [PATCH 59/74] Add check for multiple DERIVATIVE blocks --- src/visitors/cvode_visitor.cpp | 54 ++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 04eea17093..75e20c43dc 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -175,25 +175,43 @@ class StiffVisitor: public CvodeHelperVisitor { void CvodeVisitor::visit_program(ast::Program& node) { auto derivative_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK}); - if (!derivative_blocks.empty()) { - auto derivative_block = std::dynamic_pointer_cast( - derivative_blocks[0]); - - auto non_stiff_block = derivative_block->get_statement_block()->clone(); - remove_conserve_statements(*non_stiff_block); - - auto stiff_block = derivative_block->get_statement_block()->clone(); - remove_conserve_statements(*stiff_block); - - NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block); - StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block); - auto prime_vars = collect_nodes(*derivative_block, {ast::AstNodeType::PRIME_NAME}); - node.emplace_back_node(new ast::CvodeBlock( - derivative_block->get_name(), - std::shared_ptr(new ast::Integer(prime_vars.size(), nullptr)), - std::shared_ptr(non_stiff_block), - std::shared_ptr(stiff_block))); + if (derivative_blocks.empty()) { + return; } + + // steady state adds a DERIVATIVE block with a `_steadystate` suffix + auto not_steadystate = [](const auto& item) { + auto name = std::dynamic_pointer_cast(item)->get_node_name(); + return !stringutils::ends_with(name, "_steadystate"); + }; + decltype(derivative_blocks) derivative_blocks_copy; + std::copy_if(derivative_blocks.begin(), + derivative_blocks.end(), + std::back_inserter(derivative_blocks_copy), + not_steadystate); + if (derivative_blocks_copy.size() > 1) { + auto message = "CvodeVisitor :: cannot have multiple DERIVATIVE blocks"; + logger->error(message); + throw std::runtime_error(message); + } + + auto derivative_block = std::dynamic_pointer_cast( + derivative_blocks_copy[0]); + + auto non_stiff_block = derivative_block->get_statement_block()->clone(); + remove_conserve_statements(*non_stiff_block); + + auto stiff_block = derivative_block->get_statement_block()->clone(); + remove_conserve_statements(*stiff_block); + + NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block); + StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block); + auto prime_vars = collect_nodes(*derivative_block, {ast::AstNodeType::PRIME_NAME}); + node.emplace_back_node(new ast::CvodeBlock( + derivative_block->get_name(), + std::shared_ptr(new ast::Integer(prime_vars.size(), nullptr)), + std::shared_ptr(non_stiff_block), + std::shared_ptr(stiff_block))); } } // namespace visitor From 5f4a00aba712faeb6ec2535a887bc328ec799fb1 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 23:03:35 +0200 Subject: [PATCH 60/74] Update tests for CVODE --- test/unit/visitor/cvode.cpp | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/test/unit/visitor/cvode.cpp b/test/unit/visitor/cvode.cpp index e42df2ec45..96d09549ab 100644 --- a/test/unit/visitor/cvode.cpp +++ b/test/unit/visitor/cvode.cpp @@ -28,8 +28,16 @@ auto run_cvode_visitor(const std::string& text) { TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { + GIVEN("No DERIVATIVE block") { + auto nmodl_text = "NEURON { SUFFIX example }"; + auto ast = run_cvode_visitor(nmodl_text); + THEN("No CVODE block is added") { + auto blocks = collect_nodes(*ast, {ast::AstNodeType::CVODE_BLOCK}); + REQUIRE(blocks.empty()); + } + } GIVEN("DERIVATIVE block") { - std::string nmodl_text = R"( + auto nmodl_text = R"( NEURON { SUFFIX example } @@ -58,4 +66,24 @@ TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { } } } + GIVEN("Multiple DERIVATIVE blocks") { + auto nmodl_text = R"( + NEURON { + SUFFIX example + } + + STATE {X} + + DERIVATIVE equation { + X' = -X + } + + DERIVATIVE equation2 { + X' = -X * X + } +)"; + THEN("An error is raised") { + REQUIRE_THROWS_AS(run_cvode_visitor(nmodl_text), std::runtime_error); + } + } } From ea07cbce810a06ead6d1d2498b5dae1516e8b965 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 23:28:36 +0200 Subject: [PATCH 61/74] Update docs --- docs/contents/cvode.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/contents/cvode.rst b/docs/contents/cvode.rst index 367e160d48..6179d824c2 100644 --- a/docs/contents/cvode.rst +++ b/docs/contents/cvode.rst @@ -37,20 +37,20 @@ the structure of the ``CVODE`` block is then roughly: .. code-block:: - CVODE state { + CVODE state[n] { Dx_i = f_i(x_1, ..., x_n) }{ Dx_i = Dx_i / (1 - dt * J_ii(f)) } -where ``J_ii(f)`` is the diagonal part of the Jacobian, i.e. +where ``N`` is the total number of ODEs to solve, and ``J_ii(f)`` is the +diagonal part of the Jacobian, i.e. .. math:: J_{ii}(f) = \frac{ \partial f_i(x_1, \ldots, x_n) }{\partial x_i} -As an example, consider the following ``DERIVATIVE`` -block: +As an example, consider the following ``DERIVATIVE`` block: .. code-block:: @@ -63,7 +63,7 @@ Where ``X`` is a ``STATE`` variable with some initial value, specified in the .. code-block:: - CVODE state { + CVODE state[1] { DX = - X }{ DX = DX / (1 - dt * (-1)) From e30c27cbfbf68d41253fad467d002f23bad97e59 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 23:28:36 +0200 Subject: [PATCH 62/74] Update docs --- docs/contents/cvode.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/contents/cvode.rst b/docs/contents/cvode.rst index 367e160d48..6179d824c2 100644 --- a/docs/contents/cvode.rst +++ b/docs/contents/cvode.rst @@ -37,20 +37,20 @@ the structure of the ``CVODE`` block is then roughly: .. code-block:: - CVODE state { + CVODE state[n] { Dx_i = f_i(x_1, ..., x_n) }{ Dx_i = Dx_i / (1 - dt * J_ii(f)) } -where ``J_ii(f)`` is the diagonal part of the Jacobian, i.e. +where ``N`` is the total number of ODEs to solve, and ``J_ii(f)`` is the +diagonal part of the Jacobian, i.e. .. math:: J_{ii}(f) = \frac{ \partial f_i(x_1, \ldots, x_n) }{\partial x_i} -As an example, consider the following ``DERIVATIVE`` -block: +As an example, consider the following ``DERIVATIVE`` block: .. code-block:: @@ -63,7 +63,7 @@ Where ``X`` is a ``STATE`` variable with some initial value, specified in the .. code-block:: - CVODE state { + CVODE state[1] { DX = - X }{ DX = DX / (1 - dt * (-1)) From 80dbe900f1475cb9432d111298cd268696fc4a3b Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 22 Oct 2024 12:18:32 +0200 Subject: [PATCH 63/74] Better naming --- src/codegen/codegen_naming.hpp | 18 ++++ src/codegen/codegen_neuron_cpp_visitor.cpp | 106 ++++++++++----------- 2 files changed, 69 insertions(+), 55 deletions(-) diff --git a/src/codegen/codegen_naming.hpp b/src/codegen/codegen_naming.hpp index cd0403c223..620e5d3592 100644 --- a/src/codegen/codegen_naming.hpp +++ b/src/codegen/codegen_naming.hpp @@ -182,6 +182,24 @@ static constexpr char THREAD_ARGS_PROTO[] = "_threadargsproto_"; /// prefix for ion variable static constexpr char ION_VARNAME_PREFIX[] = "ion_"; +/// name of CVODE method for counting # of ODEs +static constexpr char CVODE_COUNT_NAME[] = "ode_count"; + +/// name of CVODE method for updating non-stiff systems +static constexpr char CVODE_UPDATE_NON_STIFF_NAME[] = "ode_update_nonstiff"; + +/// name of CVODE method for updating stiff systems +static constexpr char CVODE_UPDATE_STIFF_NAME[] = "ode_update_stiff"; + +/// name of CVODE method for setting up non-stiff systems +static constexpr char CVODE_SETUP_NON_STIFF_NAME[] = "ode_setup_nonstiff"; + +/// name of CVODE method for setting up stiff systems +static constexpr char CVODE_SETUP_STIFF_NAME[] = "ode_setup_stiff"; + +/// name of CVODE method for setting up tolerances +static constexpr char CVODE_SETUP_TOLERANCES_NAME[] = "ode_setup_tolerances"; + /// commonly used variables in verbatim block and how they /// should be mapped to new code generation backends // clang-format off diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index 4cb51f9020..6f4b3aa1b8 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -1443,13 +1443,15 @@ void CodegenNeuronCppVisitor::print_mechanism_register() { if (info.emit_cvode) { printer->fmt_line("hoc_register_dparam_semantics(mech_type, {}, \"cvodeieq\");", codegen_int_variables_size); - printer->fmt_line( - "hoc_register_cvode(mech_type, ode_count_{}, ode_setup_tolerance_{}, " - "ode_setup_nonstiff_{}, ode_setup_stiff_{});", - info.mod_suffix, - info.mod_suffix, - info.mod_suffix, - info.mod_suffix); + printer->fmt_line("hoc_register_cvode(mech_type, {}_{}, {}_{}, {}_{}, {}_{});", + naming::CVODE_COUNT_NAME, + info.mod_suffix, + naming::CVODE_SETUP_TOLERANCES_NAME, + info.mod_suffix, + naming::CVODE_SETUP_NON_STIFF_NAME, + info.mod_suffix, + naming::CVODE_SETUP_STIFF_NAME, + info.mod_suffix); printer->fmt_line("hoc_register_tolerance(mech_type, _hoc_state_tol, &_atollist);"); } @@ -2542,39 +2544,33 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { printer->add_line("/* Functions related to CVODE codegen */"); /* return # of ODEs to solve */ - printer->push_block( - fmt::format("static constexpr int ode_count_{}(int _type)", info.mod_suffix)); + printer->fmt_push_block("static constexpr int {}_{}(int _type)", + naming::CVODE_COUNT_NAME, + info.mod_suffix); printer->fmt_line("return {};", info.cvode_block->get_n_odes()->get_value()); printer->pop_block(); printer->add_newline(2); - auto update_nonstiff_name = fmt::format("ode_update_nonstiff_{}", info.mod_suffix); - /* The update function for non-stiff systems */ - printer->fmt_push_block("static int {}({})", - update_nonstiff_name, - get_parameter_str(cvode_update_parameters())); // begin function - // definition + printer->fmt_push_block("static int {}_{}({})", + naming::CVODE_UPDATE_NON_STIFF_NAME, + info.mod_suffix, + get_parameter_str(cvode_update_parameters())); // begin fn printer->add_line("int node_id = node_data.nodeindices[id];"); printer->add_line("auto v = node_data.node_voltages[node_id];"); - if (info.cvode_block) { - auto block = info.cvode_block->get_non_stiff_block(); - print_statement_block(*block, false, false); - } + print_statement_block(*info.cvode_block->get_non_stiff_block(), false, false); printer->add_line("return 0;"); - printer->pop_block(); // end function definition + printer->pop_block(); // end fn printer->add_newline(2); - auto setup_nonstiff_name = fmt::format("ode_setup_nonstiff_{}", info.mod_suffix); - /* The setup function for non-stiff systems */ - printer->push_block( - fmt::format("static void {}({})", - setup_nonstiff_name, - get_parameter_str(cvode_setup_parameters()))); // begin function definition + printer->fmt_push_block("static void {}_{}({})", + naming::CVODE_SETUP_NON_STIFF_NAME, + info.mod_suffix, + get_parameter_str(cvode_setup_parameters())); // begin fn printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); printer->add_line("auto nodecount = _ml_arg->nodecount;"); @@ -2589,55 +2585,52 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { printer->add_line("int node_id = node_data.nodeindices[id];"); printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); printer->add_line("auto v = node_data.node_voltages[node_id];"); - printer->fmt_line("{}({});", update_nonstiff_name, get_arg_str(cvode_update_parameters())); + printer->fmt_line("{}_{}({});", + naming::CVODE_UPDATE_NON_STIFF_NAME, + info.mod_suffix, + get_arg_str(cvode_update_parameters())); printer->pop_block(); // end for loop - printer->pop_block(); // end function definition + printer->pop_block(); // end fn printer->add_newline(2); /* The function for setup of tolerance */ - printer->push_block( - fmt::format("static void ode_setup_tolerance_{}(Prop* _prop, int equation_index, " - "neuron::container::data_handle* _pv, " - "neuron::container::data_handle* _pvdot, double* _atol, int _type)", - info.mod_suffix)); // begin function definition + printer->fmt_push_block( + "static void {}_{}(Prop* _prop, int equation_index, " + "neuron::container::data_handle* _pv, " + "neuron::container::data_handle* _pvdot, double* _atol, int _type)", + naming::CVODE_SETUP_TOLERANCES_NAME, + info.mod_suffix); // begin fn printer->add_line("auto* _ppvar = _nrn_mechanism_access_dparam(_prop);"); printer->fmt_line("_ppvar[{}].literal_value() = equation_index;", int_variables_size()); - printer->push_block(fmt::format("for (int i = 0; i < ode_count_{}(0); i++)", - info.mod_suffix)); // begin for loop + printer->fmt_push_block("for (int i = 0; i < ode_count_{}(0); i++)", + info.mod_suffix); // begin for loop printer->add_line("_pv[i] = _nrn_mechanism_get_param_handle(_prop, _slist1[i]);"); printer->add_line("_pvdot[i] = _nrn_mechanism_get_param_handle(_prop, _dlist1[i]);"); printer->add_line("_cvode_abstol(_atollist, _atol, i);"); printer->pop_block(); // end for loop - printer->pop_block(); // end function definition + printer->pop_block(); // end fn printer->add_newline(2); - auto update_stiff_name = fmt::format("ode_update_stiff_{}", info.mod_suffix); - /* The update function for stiff systems */ - printer->push_block( - fmt::format("static void {}({})", - update_stiff_name, - get_parameter_str(cvode_update_parameters()))); // begin function definition + printer->fmt_push_block("static void {}_{}({})", + naming::CVODE_UPDATE_STIFF_NAME, + info.mod_suffix, + get_parameter_str(cvode_update_parameters())); // begin fn - if (info.cvode_block) { - auto block = info.cvode_block->get_stiff_block(); - print_statement_block(*block, false, false); - } + print_statement_block(*info.cvode_block->get_stiff_block(), false, false); - printer->pop_block(); // end function definition + printer->pop_block(); // end fn printer->add_newline(2); - auto setup_stiff_name = fmt::format("ode_setup_stiff_{}", info.mod_suffix); - /* The setup function for stiff systems */ - printer->push_block( - fmt::format("static void {}({})", - setup_stiff_name, - get_parameter_str(cvode_setup_parameters()))); // begin function definition + printer->fmt_push_block("static void {}_{}({})", + naming::CVODE_SETUP_STIFF_NAME, + info.mod_suffix, + get_parameter_str(cvode_setup_parameters())); // begin fn printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); printer->add_line("auto nodecount = _ml_arg->nodecount;"); @@ -2654,10 +2647,13 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { printer->add_line("int node_id = node_data.nodeindices[id];"); printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); printer->add_line("auto v = node_data.node_voltages[node_id];"); - printer->fmt_line("{}({});", update_stiff_name, get_arg_str(cvode_update_parameters())); + printer->fmt_line("{}_{}({});", + naming::CVODE_UPDATE_STIFF_NAME, + info.mod_suffix, + get_arg_str(cvode_update_parameters())); printer->pop_block(); // end for loop - printer->pop_block(); // end function definition + printer->pop_block(); // end fn } void CodegenNeuronCppVisitor::print_net_receive() { From ddeef2b38305b1f2bd09fe147589e9d66374dc6b Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 22 Oct 2024 12:41:41 +0200 Subject: [PATCH 64/74] Fix for #1529 --- src/codegen/codegen_neuron_cpp_visitor.cpp | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index 47f1401ce5..5062f2296f 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -2613,8 +2613,10 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { naming::CVODE_UPDATE_NON_STIFF_NAME, info.mod_suffix, get_parameter_str(cvode_update_parameters())); // begin fn - printer->add_line("int node_id = node_data.nodeindices[id];"); - printer->add_line("auto v = node_data.node_voltages[node_id];"); + printer->add_line( + "auto v = node_data.node_voltages ? " + "node_data.node_voltages[node_data.nodeindices[id]] : 0.0;"); + print_statement_block(*info.cvode_block->get_non_stiff_block(), false, false); printer->add_line("return 0;"); @@ -2628,7 +2630,7 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { info.mod_suffix, get_parameter_str(cvode_setup_parameters())); // begin fn printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); - printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); + printer->fmt_line("auto inst = make_instance_{}(&_lmc);", info.mod_suffix); printer->add_line("auto nodecount = _ml_arg->nodecount;"); printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml_arg);", info.mod_suffix); printer->add_line("auto* _thread = _ml_arg->_thread;"); @@ -2638,9 +2640,11 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { info.thread_var_thread_id); } printer->push_block("for (int id = 0; id < nodecount; id++)"); // begin for loop - printer->add_line("int node_id = node_data.nodeindices[id];"); printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); - printer->add_line("auto v = node_data.node_voltages[node_id];"); + printer->add_line( + "auto v = node_data.node_voltages ? " + "node_data.node_voltages[node_data.nodeindices[id]] : 0.0;"); + printer->fmt_line("{}_{}({});", naming::CVODE_UPDATE_NON_STIFF_NAME, info.mod_suffix, @@ -2688,7 +2692,7 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { info.mod_suffix, get_parameter_str(cvode_setup_parameters())); // begin fn printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); - printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); + printer->fmt_line("auto inst = make_instance_{}(&_lmc);", info.mod_suffix); printer->add_line("auto nodecount = _ml_arg->nodecount;"); printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml_arg);", info.mod_suffix); printer->add_line("auto* _thread = _ml_arg->_thread;"); @@ -2700,9 +2704,10 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { } printer->push_block("for (int id = 0; id < nodecount; id++)"); // begin for loop - printer->add_line("int node_id = node_data.nodeindices[id];"); printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); - printer->add_line("auto v = node_data.node_voltages[node_id];"); + printer->add_line( + "auto v = node_data.node_voltages ? " + "node_data.node_voltages[node_data.nodeindices[id]] : 0.0;"); printer->fmt_line("{}_{}({});", naming::CVODE_UPDATE_STIFF_NAME, info.mod_suffix, From e1f715bc0325c8e90f80e692aa8b77e804df7ad6 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 23 Oct 2024 12:43:27 +0200 Subject: [PATCH 65/74] Add voltage to `update_stiff` --- src/codegen/codegen_neuron_cpp_visitor.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index 5062f2296f..6f11180ce3 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -2680,6 +2680,10 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { info.mod_suffix, get_parameter_str(cvode_update_parameters())); // begin fn + printer->add_line( + "auto v = node_data.node_voltages ? " + "node_data.node_voltages[node_data.nodeindices[id]] : 0.0;"); + print_statement_block(*info.cvode_block->get_stiff_block(), false, false); printer->pop_block(); // end fn From 167d38f03aa0ea90cd81d7da97ec950a5ca8aa95 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 24 Oct 2024 13:43:24 +0200 Subject: [PATCH 66/74] Address comments from review --- src/visitors/cvode_visitor.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 75e20c43dc..ea86b14bfa 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -153,10 +153,10 @@ class StiffVisitor: public CvodeHelperVisitor { auto rhs = node.get_rhs(); // all indexed variables (need special treatment in SymPy) - auto name_map = get_indexed_variables(*rhs, name->get_node_name()); + auto indexed_variables = get_indexed_variables(*rhs, name->get_node_name()); auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; - auto [jacobian, - exception_message] = diff2c(to_nmodl(*rhs), parse_independent_var(name), name_map); + auto [jacobian, exception_message] = + diff2c(to_nmodl(*rhs), parse_independent_var(name), indexed_variables); if (!exception_message.empty()) { logger->warn("CvodeVisitor :: python exception: {}", exception_message); } @@ -172,11 +172,10 @@ class StiffVisitor: public CvodeHelperVisitor { } }; - -void CvodeVisitor::visit_program(ast::Program& node) { +static std::shared_ptr get_derivative_block(ast::Program& node) { auto derivative_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK}); if (derivative_blocks.empty()) { - return; + return nullptr; } // steady state adds a DERIVATIVE block with a `_steadystate` suffix @@ -195,8 +194,15 @@ void CvodeVisitor::visit_program(ast::Program& node) { throw std::runtime_error(message); } - auto derivative_block = std::dynamic_pointer_cast( - derivative_blocks_copy[0]); + return std::dynamic_pointer_cast(derivative_blocks_copy[0]); +} + + +void CvodeVisitor::visit_program(ast::Program& node) { + auto derivative_block = get_derivative_block(node); + if (derivative_block == nullptr) { + return; + } auto non_stiff_block = derivative_block->get_statement_block()->clone(); remove_conserve_statements(*non_stiff_block); From de609a8ebba455731c16b20ea0e09f500c8557ac Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 28 Oct 2024 12:43:37 +0100 Subject: [PATCH 67/74] Codegen --- src/codegen/codegen_neuron_cpp_visitor.cpp | 186 ++++++++++++--------- src/codegen/codegen_neuron_cpp_visitor.hpp | 14 ++ 2 files changed, 121 insertions(+), 79 deletions(-) diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index bab36908d2..9e9acbe92b 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -1507,15 +1507,11 @@ void CodegenNeuronCppVisitor::print_mechanism_register_regular() { if (info.emit_cvode) { printer->fmt_line("hoc_register_dparam_semantics(mech_type, {}, \"cvodeieq\");", codegen_int_variables_size); - printer->fmt_line("hoc_register_cvode(mech_type, {}_{}, {}_{}, {}_{}, {}_{});", - naming::CVODE_COUNT_NAME, - info.mod_suffix, - naming::CVODE_SETUP_TOLERANCES_NAME, - info.mod_suffix, - naming::CVODE_SETUP_NON_STIFF_NAME, - info.mod_suffix, - naming::CVODE_SETUP_STIFF_NAME, - info.mod_suffix); + printer->fmt_line("hoc_register_cvode(mech_type, {}, {}, {}, {});", + method_name(naming::CVODE_COUNT_NAME), + method_name(naming::CVODE_SETUP_TOLERANCES_NAME), + method_name(naming::CVODE_SETUP_NON_STIFF_NAME), + method_name(naming::CVODE_SETUP_STIFF_NAME)); printer->fmt_line("hoc_register_tolerance(mech_type, _hoc_state_tol, &_atollist);"); } } @@ -2644,27 +2640,63 @@ CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::cvode_update_parameters( return args; } -void CodegenNeuronCppVisitor::print_cvode_definitions() { - if (!info.emit_cvode) { - return; - } - printer->add_newline(2); - printer->add_line("/* Functions related to CVODE codegen */"); - /* return # of ODEs to solve */ - printer->fmt_push_block("static constexpr int {}_{}(int _type)", - naming::CVODE_COUNT_NAME, - info.mod_suffix); +std::string CodegenNeuronCppVisitor::method_name(const std::string& name) { + return fmt::format("{}_{}", name, info.mod_suffix); +} + + +/* print the function returning the # of ODEs to solve */ +void CodegenNeuronCppVisitor::print_cvode_count() { + printer->fmt_push_block("static constexpr int {}(int _type)", + method_name(naming::CVODE_COUNT_NAME)); printer->fmt_line("return {};", info.cvode_block->get_n_odes()->get_value()); printer->pop_block(); + printer->add_newline(2); +} + +/* print the function for setup of tolerances */ +void CodegenNeuronCppVisitor::print_cvode_tolerances() { + CodegenNeuronCppVisitor::ParamVector tolerances_parameters = { + {"", "Prop*", "", "_prop"}, + {"", "int", "", "equation_index"}, + {"", "neuron::container::data_handle*", "", "_pv"}, + {"", "neuron::container::data_handle*", "", "_pvdot"}, + {"", "double*", "", "_atol"}, + {"", "int", "", "_type"}}; + + auto get_param_name = [](const auto& item) { return std::get<3>(item); }; + + auto prop_name = get_param_name(tolerances_parameters[0]); + auto eqindex_name = get_param_name(tolerances_parameters[1]); + auto pv_name = get_param_name(tolerances_parameters[2]); + auto pvdot_name = get_param_name(tolerances_parameters[3]); + auto atol_name = get_param_name(tolerances_parameters[4]); + printer->fmt_push_block("static void {}({})", + method_name(naming::CVODE_SETUP_TOLERANCES_NAME), + get_parameter_str(tolerances_parameters)); + printer->fmt_line("auto* _ppvar = _nrn_mechanism_access_dparam({});", prop_name); + printer->fmt_line("_ppvar[{}].literal_value() = {};", int_variables_size(), eqindex_name); + printer->fmt_push_block("for (int i = 0; i < {}(0); i++)", + method_name(naming::CVODE_COUNT_NAME)); + printer->fmt_line("{}[i] = _nrn_mechanism_get_param_handle({}, _slist1[i]);", + pv_name, + prop_name); + printer->fmt_line("{}[i] = _nrn_mechanism_get_param_handle({}, _dlist1[i]);", + pvdot_name, + prop_name); + printer->fmt_line("_cvode_abstol(_atollist, {}, i);", atol_name); + printer->pop_block(); + printer->pop_block(); printer->add_newline(2); +} - /* The update function for non-stiff systems */ - printer->fmt_push_block("static int {}_{}({})", - naming::CVODE_UPDATE_NON_STIFF_NAME, - info.mod_suffix, - get_parameter_str(cvode_update_parameters())); // begin fn +/* print the update function for non-stiff systems */ +void CodegenNeuronCppVisitor::print_cvode_non_stiff_update() { + printer->fmt_push_block("static int {}({})", + method_name(naming::CVODE_UPDATE_NON_STIFF_NAME), + get_parameter_str(cvode_update_parameters())); printer->add_line( "auto v = node_data.node_voltages ? " "node_data.node_voltages[node_data.nodeindices[id]] : 0.0;"); @@ -2672,65 +2704,45 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { print_statement_block(*info.cvode_block->get_non_stiff_block(), false, false); printer->add_line("return 0;"); - printer->pop_block(); // end fn - + printer->pop_block(); printer->add_newline(2); +} - /* The setup function for non-stiff systems */ - printer->fmt_push_block("static void {}_{}({})", - naming::CVODE_SETUP_NON_STIFF_NAME, - info.mod_suffix, - get_parameter_str(cvode_setup_parameters())); // begin fn +/* print the setup function for non-stiff systems */ +void CodegenNeuronCppVisitor::print_cvode_non_stiff_setup() { + printer->fmt_push_block("static void {}({})", + method_name(naming::CVODE_SETUP_NON_STIFF_NAME), + get_parameter_str(cvode_setup_parameters())); printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); - printer->fmt_line("auto inst = make_instance_{}(&_lmc);", info.mod_suffix); + printer->fmt_line("auto inst = {}(&_lmc);", method_name("make_instance")); printer->add_line("auto nodecount = _ml_arg->nodecount;"); - printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml_arg);", info.mod_suffix); + printer->fmt_line("auto node_data = {}(*nt, *_ml_arg);", method_name("make_node_data")); printer->add_line("auto* _thread = _ml_arg->_thread;"); if (!codegen_thread_variables.empty()) { printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());", thread_variables_struct(), info.thread_var_thread_id); } - printer->push_block("for (int id = 0; id < nodecount; id++)"); // begin for loop + printer->push_block("for (int id = 0; id < nodecount; id++)"); printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); printer->add_line( "auto v = node_data.node_voltages ? " "node_data.node_voltages[node_data.nodeindices[id]] : 0.0;"); - printer->fmt_line("{}_{}({});", - naming::CVODE_UPDATE_NON_STIFF_NAME, - info.mod_suffix, + printer->fmt_line("{}({});", + method_name(naming::CVODE_UPDATE_NON_STIFF_NAME), get_arg_str(cvode_update_parameters())); - printer->pop_block(); // end for loop - printer->pop_block(); // end fn - - printer->add_newline(2); - - /* The function for setup of tolerance */ - printer->fmt_push_block( - "static void {}_{}(Prop* _prop, int equation_index, " - "neuron::container::data_handle* _pv, " - "neuron::container::data_handle* _pvdot, double* _atol, int _type)", - naming::CVODE_SETUP_TOLERANCES_NAME, - info.mod_suffix); // begin fn - printer->add_line("auto* _ppvar = _nrn_mechanism_access_dparam(_prop);"); - printer->fmt_line("_ppvar[{}].literal_value() = equation_index;", int_variables_size()); - printer->fmt_push_block("for (int i = 0; i < ode_count_{}(0); i++)", - info.mod_suffix); // begin for loop - printer->add_line("_pv[i] = _nrn_mechanism_get_param_handle(_prop, _slist1[i]);"); - printer->add_line("_pvdot[i] = _nrn_mechanism_get_param_handle(_prop, _dlist1[i]);"); - printer->add_line("_cvode_abstol(_atollist, _atol, i);"); - printer->pop_block(); // end for loop - printer->pop_block(); // end fn - + printer->pop_block(); + printer->pop_block(); printer->add_newline(2); +} - /* The update function for stiff systems */ - printer->fmt_push_block("static void {}_{}({})", - naming::CVODE_UPDATE_STIFF_NAME, - info.mod_suffix, - get_parameter_str(cvode_update_parameters())); // begin fn +/* print the update function for stiff systems */ +void CodegenNeuronCppVisitor::print_cvode_stiff_update() { + printer->fmt_push_block("static void {}({})", + method_name(naming::CVODE_UPDATE_STIFF_NAME), + get_parameter_str(cvode_update_parameters())); printer->add_line( "auto v = node_data.node_voltages ? " @@ -2738,19 +2750,19 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { print_statement_block(*info.cvode_block->get_stiff_block(), false, false); - printer->pop_block(); // end fn - + printer->pop_block(); printer->add_newline(2); +} - /* The setup function for stiff systems */ - printer->fmt_push_block("static void {}_{}({})", - naming::CVODE_SETUP_STIFF_NAME, - info.mod_suffix, - get_parameter_str(cvode_setup_parameters())); // begin fn +/* print the setup function for stiff systems */ +void CodegenNeuronCppVisitor::print_cvode_stiff_setup() { + printer->fmt_push_block("static void {}({})", + method_name(naming::CVODE_SETUP_STIFF_NAME), + get_parameter_str(cvode_setup_parameters())); printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); - printer->fmt_line("auto inst = make_instance_{}(&_lmc);", info.mod_suffix); + printer->fmt_line("auto inst = {}(&_lmc);", method_name("make_instance")); printer->add_line("auto nodecount = _ml_arg->nodecount;"); - printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml_arg);", info.mod_suffix); + printer->fmt_line("auto node_data = {}(*nt, *_ml_arg);", method_name("make_node_data")); printer->add_line("auto* _thread = _ml_arg->_thread;"); if (!codegen_thread_variables.empty()) { @@ -2759,18 +2771,34 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { info.thread_var_thread_id); } - printer->push_block("for (int id = 0; id < nodecount; id++)"); // begin for loop + printer->push_block("for (int id = 0; id < nodecount; id++)"); printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); printer->add_line( "auto v = node_data.node_voltages ? " "node_data.node_voltages[node_data.nodeindices[id]] : 0.0;"); - printer->fmt_line("{}_{}({});", - naming::CVODE_UPDATE_STIFF_NAME, - info.mod_suffix, + printer->fmt_line("{}({});", + method_name(naming::CVODE_UPDATE_STIFF_NAME), get_arg_str(cvode_update_parameters())); - printer->pop_block(); // end for loop - printer->pop_block(); // end fn + printer->pop_block(); + printer->pop_block(); + + printer->add_newline(2); +} + +void CodegenNeuronCppVisitor::print_cvode_definitions() { + if (!info.emit_cvode) { + return; + } + + printer->add_newline(2); + printer->add_line("/* Functions related to CVODE codegen */"); + print_cvode_count(); + print_cvode_non_stiff_update(); + print_cvode_non_stiff_setup(); + print_cvode_tolerances(); + print_cvode_stiff_update(); + print_cvode_stiff_setup(); } void CodegenNeuronCppVisitor::print_net_receive() { diff --git a/src/codegen/codegen_neuron_cpp_visitor.hpp b/src/codegen/codegen_neuron_cpp_visitor.hpp index 92650ffc00..e9e08b1622 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.hpp +++ b/src/codegen/codegen_neuron_cpp_visitor.hpp @@ -783,6 +783,20 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { */ void print_cvode_definitions(); + void print_cvode_count(); + + void print_cvode_tolerances(); + + void print_cvode_non_stiff_update(); + + void print_cvode_non_stiff_setup(); + + void print_cvode_stiff_update(); + + void print_cvode_stiff_setup(); + + std::string method_name(const std::string& name); + /****************************************************************************************/ /* Overloaded visitor routines */ From aa34bc96ca1ab2f31b17a5afd87f22db88fabcb2 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 28 Oct 2024 13:23:17 +0100 Subject: [PATCH 68/74] Refactor with less repetition --- src/codegen/codegen_neuron_cpp_visitor.cpp | 91 +++++----------------- src/codegen/codegen_neuron_cpp_visitor.hpp | 8 +- 2 files changed, 20 insertions(+), 79 deletions(-) diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index 24160bc696..1ed2a62958 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -2661,7 +2661,7 @@ void CodegenNeuronCppVisitor::print_cvode_count() { /* print the function for setup of tolerances */ void CodegenNeuronCppVisitor::print_cvode_tolerances() { - CodegenNeuronCppVisitor::ParamVector tolerances_parameters = { + const CodegenNeuronCppVisitor::ParamVector tolerances_parameters = { {"", "Prop*", "", "_prop"}, {"", "int", "", "equation_index"}, {"", "neuron::container::data_handle*", "", "_pv"}, @@ -2696,93 +2696,37 @@ void CodegenNeuronCppVisitor::print_cvode_tolerances() { printer->add_newline(2); } -/* print the update function for non-stiff systems */ -void CodegenNeuronCppVisitor::print_cvode_non_stiff_update() { +/* print the CVODE update function (called ``name``) from ``block`` */ +void CodegenNeuronCppVisitor::print_cvode_update(const std::string& name, + const ast::StatementBlock& block) { printer->fmt_push_block("static int {}({})", - method_name(naming::CVODE_UPDATE_NON_STIFF_NAME), + method_name(name), get_parameter_str(cvode_update_parameters())); printer->add_line( "auto v = node_data.node_voltages ? " "node_data.node_voltages[node_data.nodeindices[id]] : 0.0;"); - print_statement_block(*info.cvode_block->get_non_stiff_block(), false, false); + print_statement_block(block, false, false); printer->add_line("return 0;"); printer->pop_block(); printer->add_newline(2); } -/* print the setup function for non-stiff systems */ -void CodegenNeuronCppVisitor::print_cvode_non_stiff_setup() { +/* print the setup function (that calls an update function) */ +void CodegenNeuronCppVisitor::print_cvode_setup(const std::string& setup_name, + const std::string& update_name) { printer->fmt_push_block("static void {}({})", - method_name(naming::CVODE_SETUP_NON_STIFF_NAME), + method_name(setup_name), get_parameter_str(cvode_setup_parameters())); - printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); - printer->fmt_line("auto inst = {}(&_lmc);", method_name("make_instance")); - printer->add_line("auto nodecount = _ml_arg->nodecount;"); - printer->fmt_line("auto node_data = {}(*nt, *_ml_arg);", method_name("make_node_data")); - printer->add_line("auto* _thread = _ml_arg->_thread;"); - if (!codegen_thread_variables.empty()) { - printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());", - thread_variables_struct(), - info.thread_var_thread_id); - } - printer->push_block("for (int id = 0; id < nodecount; id++)"); - printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); - printer->add_line( - "auto v = node_data.node_voltages ? " - "node_data.node_voltages[node_data.nodeindices[id]] : 0.0;"); - - printer->fmt_line("{}({});", - method_name(naming::CVODE_UPDATE_NON_STIFF_NAME), - get_arg_str(cvode_update_parameters())); - - printer->pop_block(); - printer->pop_block(); - printer->add_newline(2); -} - -/* print the update function for stiff systems */ -void CodegenNeuronCppVisitor::print_cvode_stiff_update() { - printer->fmt_push_block("static void {}({})", - method_name(naming::CVODE_UPDATE_STIFF_NAME), - get_parameter_str(cvode_update_parameters())); - - printer->add_line( - "auto v = node_data.node_voltages ? " - "node_data.node_voltages[node_data.nodeindices[id]] : 0.0;"); - - print_statement_block(*info.cvode_block->get_stiff_block(), false, false); - - printer->pop_block(); - printer->add_newline(2); -} - -/* print the setup function for stiff systems */ -void CodegenNeuronCppVisitor::print_cvode_stiff_setup() { - printer->fmt_push_block("static void {}({})", - method_name(naming::CVODE_SETUP_STIFF_NAME), - get_parameter_str(cvode_setup_parameters())); - printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};"); - printer->fmt_line("auto inst = {}(&_lmc);", method_name("make_instance")); - printer->add_line("auto nodecount = _ml_arg->nodecount;"); - printer->fmt_line("auto node_data = {}(*nt, *_ml_arg);", method_name("make_node_data")); - printer->add_line("auto* _thread = _ml_arg->_thread;"); - - if (!codegen_thread_variables.empty()) { - printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());", - thread_variables_struct(), - info.thread_var_thread_id); - } - + print_entrypoint_setup_code_from_memb_list(); + printer->fmt_line("auto nodecount = _ml_arg->nodecount;"); printer->push_block("for (int id = 0; id < nodecount; id++)"); printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); printer->add_line( "auto v = node_data.node_voltages ? " "node_data.node_voltages[node_data.nodeindices[id]] : 0.0;"); - printer->fmt_line("{}({});", - method_name(naming::CVODE_UPDATE_STIFF_NAME), - get_arg_str(cvode_update_parameters())); + printer->fmt_line("{}({});", method_name(update_name), get_arg_str(cvode_update_parameters())); printer->pop_block(); printer->pop_block(); @@ -2798,11 +2742,12 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() { printer->add_newline(2); printer->add_line("/* Functions related to CVODE codegen */"); print_cvode_count(); - print_cvode_non_stiff_update(); - print_cvode_non_stiff_setup(); print_cvode_tolerances(); - print_cvode_stiff_update(); - print_cvode_stiff_setup(); + print_cvode_update(naming::CVODE_UPDATE_NON_STIFF_NAME, + *info.cvode_block->get_non_stiff_block()); + print_cvode_update(naming::CVODE_UPDATE_STIFF_NAME, *info.cvode_block->get_stiff_block()); + print_cvode_setup(naming::CVODE_SETUP_NON_STIFF_NAME, naming::CVODE_UPDATE_NON_STIFF_NAME); + print_cvode_setup(naming::CVODE_SETUP_STIFF_NAME, naming::CVODE_UPDATE_STIFF_NAME); } void CodegenNeuronCppVisitor::print_net_receive() { diff --git a/src/codegen/codegen_neuron_cpp_visitor.hpp b/src/codegen/codegen_neuron_cpp_visitor.hpp index 3059c66ec7..5619cf9713 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.hpp +++ b/src/codegen/codegen_neuron_cpp_visitor.hpp @@ -788,13 +788,9 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { void print_cvode_tolerances(); - void print_cvode_non_stiff_update(); + void print_cvode_update(const std::string& name, const ast::StatementBlock& node); - void print_cvode_non_stiff_setup(); - - void print_cvode_stiff_update(); - - void print_cvode_stiff_setup(); + void print_cvode_setup(const std::string& setup_name, const std::string& update_name); std::string method_name(const std::string& name); From 01c6514b63f99f976a93a3a03d34549b29a0af77 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 28 Oct 2024 13:27:43 +0100 Subject: [PATCH 69/74] Rename pytest test --- test/usecases/cvode/test_cvode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/usecases/cvode/test_cvode.py b/test/usecases/cvode/test_cvode.py index c5ae18cc3e..49f148b14a 100644 --- a/test/usecases/cvode/test_cvode.py +++ b/test/usecases/cvode/test_cvode.py @@ -4,7 +4,7 @@ from neuron.units import ms -def simulate(rtol): +def test_cvode(rtol): nseg = 1 mech = "scalar" @@ -49,4 +49,4 @@ def simulate(rtol): if __name__ == "__main__": - t, *x = simulate(rtol=1e-5) + t, *x = test_cvode(rtol=1e-5) From 361172c524aca908f2041f5431ad219478e4e151 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 28 Oct 2024 14:04:15 +0100 Subject: [PATCH 70/74] Set `cvode_ieq` as the variable name for CVODE Must have semantics of type `cvodeieq` (see nrn/src/init.cpp) --- src/codegen/codegen_naming.hpp | 3 +++ src/codegen/codegen_neuron_cpp_visitor.cpp | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/codegen/codegen_naming.hpp b/src/codegen/codegen_naming.hpp index 620e5d3592..397edc49e1 100644 --- a/src/codegen/codegen_naming.hpp +++ b/src/codegen/codegen_naming.hpp @@ -200,6 +200,9 @@ static constexpr char CVODE_SETUP_STIFF_NAME[] = "ode_setup_stiff"; /// name of CVODE method for setting up tolerances static constexpr char CVODE_SETUP_TOLERANCES_NAME[] = "ode_setup_tolerances"; +/// name of the CVODE variable (can be arbitrary) +static constexpr char CVODE_VARIABLE_NAME[] = "cvode_ieq"; + /// commonly used variables in verbatim block and how they /// should be mapped to new code generation backends // clang-format off diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index 1ed2a62958..d772a57004 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -1431,7 +1431,8 @@ void CodegenNeuronCppVisitor::print_mechanism_register_regular() { if (info.emit_cvode) { mech_register_args.push_back( - fmt::format("_nrn_mechanism_field{{\"_cvode_ieq\", \"cvodeieq\"}} /* {} */", + fmt::format("_nrn_mechanism_field{{\"{}\", \"cvodeieq\"}} /* {} */", + naming::CVODE_VARIABLE_NAME, codegen_int_variables_size)); } From 25ff744a859de3e4c5dfdeb83c2c4e9079fc4506 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 28 Oct 2024 14:56:59 +0100 Subject: [PATCH 71/74] Remove `method_name` It already exists in `CodegenCppVisitor` --- src/codegen/codegen_neuron_cpp_visitor.cpp | 5 ----- src/codegen/codegen_neuron_cpp_visitor.hpp | 2 -- 2 files changed, 7 deletions(-) diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index d772a57004..a22a82a661 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -2646,11 +2646,6 @@ CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::cvode_update_parameters( } -std::string CodegenNeuronCppVisitor::method_name(const std::string& name) { - return fmt::format("{}_{}", name, info.mod_suffix); -} - - /* print the function returning the # of ODEs to solve */ void CodegenNeuronCppVisitor::print_cvode_count() { printer->fmt_push_block("static constexpr int {}(int _type)", diff --git a/src/codegen/codegen_neuron_cpp_visitor.hpp b/src/codegen/codegen_neuron_cpp_visitor.hpp index 5619cf9713..a4984b4ffb 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.hpp +++ b/src/codegen/codegen_neuron_cpp_visitor.hpp @@ -792,8 +792,6 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { void print_cvode_setup(const std::string& setup_name, const std::string& update_name); - std::string method_name(const std::string& name); - /****************************************************************************************/ /* Overloaded visitor routines */ From 3acc4b33e1995191191071c673448a8182e83c56 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 28 Oct 2024 15:05:02 +0100 Subject: [PATCH 72/74] Move method descriptions to header --- src/codegen/codegen_neuron_cpp_visitor.cpp | 4 ---- src/codegen/codegen_neuron_cpp_visitor.hpp | 7 ++++++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index a22a82a661..ae4d49d7eb 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -2646,7 +2646,6 @@ CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::cvode_update_parameters( } -/* print the function returning the # of ODEs to solve */ void CodegenNeuronCppVisitor::print_cvode_count() { printer->fmt_push_block("static constexpr int {}(int _type)", method_name(naming::CVODE_COUNT_NAME)); @@ -2655,7 +2654,6 @@ void CodegenNeuronCppVisitor::print_cvode_count() { printer->add_newline(2); } -/* print the function for setup of tolerances */ void CodegenNeuronCppVisitor::print_cvode_tolerances() { const CodegenNeuronCppVisitor::ParamVector tolerances_parameters = { {"", "Prop*", "", "_prop"}, @@ -2692,7 +2690,6 @@ void CodegenNeuronCppVisitor::print_cvode_tolerances() { printer->add_newline(2); } -/* print the CVODE update function (called ``name``) from ``block`` */ void CodegenNeuronCppVisitor::print_cvode_update(const std::string& name, const ast::StatementBlock& block) { printer->fmt_push_block("static int {}({})", @@ -2709,7 +2706,6 @@ void CodegenNeuronCppVisitor::print_cvode_update(const std::string& name, printer->add_newline(2); } -/* print the setup function (that calls an update function) */ void CodegenNeuronCppVisitor::print_cvode_setup(const std::string& setup_name, const std::string& update_name) { printer->fmt_push_block("static void {}({})", diff --git a/src/codegen/codegen_neuron_cpp_visitor.hpp b/src/codegen/codegen_neuron_cpp_visitor.hpp index a4984b4ffb..ea278b9799 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.hpp +++ b/src/codegen/codegen_neuron_cpp_visitor.hpp @@ -784,12 +784,17 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { */ void print_cvode_definitions(); + /* print the CVODE function returning the # of ODEs to solve */ void print_cvode_count(); + /* print the CVODE function for setup of tolerances */ void print_cvode_tolerances(); - void print_cvode_update(const std::string& name, const ast::StatementBlock& node); + /* print the CVODE update function ``name`` from ``block`` */ + void print_cvode_update(const std::string& name, const ast::StatementBlock& block); + /* print the CVODE setup function ``setup_name`` that calls the CVODE update function + * ``update_name`` */ void print_cvode_setup(const std::string& setup_name, const std::string& update_name); From 546a695a04b905ed446d5dfdfe9e8e7532934996 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 29 Oct 2024 10:34:10 +0100 Subject: [PATCH 73/74] Update docstring format --- src/codegen/codegen_neuron_cpp_visitor.hpp | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/codegen/codegen_neuron_cpp_visitor.hpp b/src/codegen/codegen_neuron_cpp_visitor.hpp index ea278b9799..99fbd8154d 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.hpp +++ b/src/codegen/codegen_neuron_cpp_visitor.hpp @@ -784,17 +784,25 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { */ void print_cvode_definitions(); - /* print the CVODE function returning the # of ODEs to solve */ + /** + * Print the CVODE function returning the # of ODEs to solve + */ void print_cvode_count(); - /* print the CVODE function for setup of tolerances */ + /** + * Print the CVODE function for setup of tolerances + */ void print_cvode_tolerances(); - /* print the CVODE update function ``name`` from ``block`` */ + /** + * Print the CVODE update function \c name from \c block + */ void print_cvode_update(const std::string& name, const ast::StatementBlock& block); - /* print the CVODE setup function ``setup_name`` that calls the CVODE update function - * ``update_name`` */ + /** + * Print the CVODE setup function \c setup_name that calls the CVODE update function + * \c update_name + */ void print_cvode_setup(const std::string& setup_name, const std::string& update_name); From 9a769da83ed254cb3c9de64a9d7a6fe30c2712e7 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 29 Oct 2024 11:01:10 +0100 Subject: [PATCH 74/74] Fix wording --- src/codegen/codegen_neuron_cpp_visitor.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/codegen/codegen_neuron_cpp_visitor.hpp b/src/codegen/codegen_neuron_cpp_visitor.hpp index 99fbd8154d..24d71c22fb 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.hpp +++ b/src/codegen/codegen_neuron_cpp_visitor.hpp @@ -785,7 +785,7 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { void print_cvode_definitions(); /** - * Print the CVODE function returning the # of ODEs to solve + * Print the CVODE function returning the number of ODEs to solve */ void print_cvode_count(); @@ -795,7 +795,7 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { void print_cvode_tolerances(); /** - * Print the CVODE update function \c name from \c block + * Print the CVODE update function \c name contained in \c block */ void print_cvode_update(const std::string& name, const ast::StatementBlock& block);