From b183af996d38bc595b3c547aedeb000dbf114b0f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 10 Dec 2024 15:34:12 +0530 Subject: [PATCH] feat: generalize `CheckInit` to DDEs --- src/initialization.jl | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/initialization.jl b/src/initialization.jl index 8b45bb6a6..1a1632c27 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -111,6 +111,26 @@ function _evaluate_f(integrator, f, isinplace::Val{false}, args...) return f(args...) end +""" +Utility function to evaluate the RHS, adding extra arguments (such as history function for +DDEs) wherever necessary. +""" +function evaluate_f(integrator::DEIntegrator, prob, f, isinplace, u, p, t) + return _evaluate_f(integrator, f, isinplace, u, p, t) +end + +function evaluate_f(integrator::DEIntegrator, prob::AbtsractDAEProblem, f, isinplace, u, p, t) + return _evaluate_f(integrator, f, isinplace, integrator.du, u, p, t) +end + +function evaluate_f(integrator::AbstractDDEIntegrator, prob, f, isinplace, u, p, t) + return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t) +end + +function evaluate_f(integrator::AbstractSDDEIntegrator, prob, f, isinplace, u, p, t) + return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t) +end + """ $(TYPEDSIGNATURES) @@ -147,7 +167,7 @@ function get_initial_values( algebraic_eqs = [all(iszero, x) for x in eachrow(M)] (iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true update_coefficients!(M, u0, p, t) - tmp = _evaluate_f(integrator, f, isinplace, u0, p, t) + tmp = evaluate_f(integrator, prob, f, isinplace, u0, p, t) tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) normresid = isdefined(integrator.opts, :internalnorm) ? @@ -165,7 +185,7 @@ function get_initial_values( p = parameter_values(integrator) t = current_time(integrator) - resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t) + resid = evaluate_f(integrator, prob, f, isinplace, u0, p, t) normresid = isdefined(integrator.opts, :internalnorm) ? integrator.opts.internalnorm(resid, t) : norm(resid)