Skip to content

Commit

Permalink
Fix custom function handling in sympy solve visitor (#1563)
Browse files Browse the repository at this point in the history
  • Loading branch information
JCGoran authored Dec 3, 2024
1 parent a851a1b commit f4b5bc7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,10 +543,12 @@ def integrate2c(diff_string, dt_var, vars, use_pade_approx=False):
if _a1 == 0 and _a2 == 0:
solution = _a0

custom_fcts = {str(f.func): str(f.func) for f in solution.atoms(sp.Function)}

# return result as C code in NEURON format:
# - in the lhs x_0 refers to the state var at time (t+dt)
# - in the rhs x_0 refers to the state var at time t
return f"{sp.ccode(x)} = {sp.ccode(solution.evalf())}"
return f"{sp.ccode(x)} = {sp.ccode(solution.evalf(), user_functions=custom_fcts)}"


def forwards_euler2c(diff_string, dt_var, vars, function_calls):
Expand Down
2 changes: 2 additions & 0 deletions test/unit/ode/test_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ def test_integrate2c():
("a", "x + a*dt"),
("a*x", "x*exp(a*dt)"),
("a*x+b", "(-b + (a*x + b)*exp(a*dt))/a"),
# assume custom_function is defined in mod file
("custom_function(a)*x", "x*exp(custom_function(a)*dt)"),
]
for eq, sol in test_cases:
assert _equivalent(
Expand Down

0 comments on commit f4b5bc7

Please sign in to comment.