From efdc4a09b2a035f6efb6b49dcaa6bfe5fbde03e4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Nov 2023 14:09:16 -0500 Subject: [PATCH] Update codebase --- .JuliaFormatter.toml | 4 +- LICENSE.md | 1 - Project.toml | 15 +-- README.md | 9 ++ src/SteadyStateDiffEq.jl | 6 +- src/algorithms.jl | 79 +++++++------- src/solve.jl | 221 +++++++++++++++++++-------------------- test/autodiff.jl | 4 +- test/core.jl | 111 ++++++++++++++++++++ test/runtests.jl | 143 ++----------------------- 10 files changed, 290 insertions(+), 303 deletions(-) create mode 100644 test/core.jl diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 453925c..320e0c0 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1 +1,3 @@ -style = "sciml" \ No newline at end of file +style = "sciml" +format_markdown = true +annotate_untyped_fields_with_any = false diff --git a/LICENSE.md b/LICENSE.md index 0bea68e..d9703c3 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -19,4 +19,3 @@ The DiffEqSteadyState.jl package is licensed under the MIT "Expat" License: > LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, > OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE > SOFTWARE. -> diff --git a/Project.toml b/Project.toml index 6ccd51d..39567d4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,29 +1,30 @@ name = "SteadyStateDiffEq" uuid = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" -version = "1.16.1" +version = "2.0.0" [deps] +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" [compat] -DiffEqBase = "6.126" -DiffEqCallbacks = "2.9" -NLsolve = "4.2" +DiffEqBase = "6.140" +DiffEqCallbacks = "2.34" Reexport = "0.2, 1.0" -SciMLBase = "1.90, 2" +SciMLBase = "2" julia = "1.6" [extras] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" +NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["OrdinaryDiffEq", "ForwardDiff", "ModelingToolkit", "Test", "Sundials"] +test = ["OrdinaryDiffEq", "NonlinearSolve", "ForwardDiff", "ModelingToolkit", "Test", "SafeTestsets", "Sundials"] diff --git a/README.md b/README.md index 9680e01..88d385a 100644 --- a/README.md +++ b/README.md @@ -9,3 +9,12 @@ SteadyStateDiffEq.jl is a component package in the DifferentialEquations ecosyst It holds the steady state solvers for differential equations. While completely independent and usable on its own, users interested in using this functionality should check out [DifferentialEquations.jl](https://github.com/JuliaDiffEq/DifferentialEquations.jl). + +## Breaking Changes in v2 + + 1. `NLsolve.jl` dependency has been dropped. `SSRootfind` requires a nonlinear solver to be + specified. + 2. `DynamicSS` no longer stores `abstol` and `reltol`. To use separate tolerances for the + odesolve and the termination, specify `odesolve_kwargs` in `solve`. + 3. The deprecated termination conditions are dropped, see [NonlinearSolve.jl Docs](https://docs.sciml.ai/NonlinearSolve/stable/basics/TerminationCondition/) + for details on this. diff --git a/src/SteadyStateDiffEq.jl b/src/SteadyStateDiffEq.jl index 05609d5..9b7855b 100644 --- a/src/SteadyStateDiffEq.jl +++ b/src/SteadyStateDiffEq.jl @@ -3,13 +3,11 @@ module SteadyStateDiffEq using Reexport @reexport using DiffEqBase -using NLsolve, DiffEqCallbacks -using LinearAlgebra -using SciMLBase +using DiffEqCallbacks, ConcreteStructs, LinearAlgebra, SciMLBase include("algorithms.jl") include("solve.jl") export SSRootfind, DynamicSS -end # module +end diff --git a/src/algorithms.jl b/src/algorithms.jl index 751cc8c..6957e17 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -1,18 +1,24 @@ abstract type SteadyStateDiffEqAlgorithm <: DiffEqBase.AbstractSteadyStateAlgorithm end -struct SSRootfind{F} <: SteadyStateDiffEqAlgorithm - nlsolve::F -end -function SSRootfind(; - nlsolve = (f, u0, abstol) -> (NLsolve.nlsolve(f, u0, - ftol = abstol))) - SSRootfind(nlsolve) +""" + SSRootfind(alg = nothing) + +Use a Nonlinear Solver to find the steady state. Requires that a nonlinear solver is +given as the first argument. + +!!! note + + The default `alg` of `nothing` works only if `NonlinearSolve.jl` is installed and + loaded. +""" +@concrete struct SSRootfind <: SteadyStateDiffEqAlgorithm + alg end +SSRootfind() = SSRootfind(nothing) + """ - DynamicSS(alg; abstol = 1e-8, reltol = 1e-6, tspan = Inf, - termination_condition = SteadyStateTerminationCriteria(:default; abstol, - reltol)) + DynamicSS(alg = nothing; tspan = Inf) Requires that an ODE algorithm is given as the first argument. The absolute and relative tolerances specify the termination conditions on the derivative's closeness to @@ -24,45 +30,46 @@ Example usage: ```julia using SteadyStateDiffEq, OrdinaryDiffEq -sol = solve(prob,DynamicSS(Tsit5())) +sol = solve(prob, DynamicSS(Tsit5())) using Sundials -sol = solve(prob,DynamicSS(CVODE_BDF()),dt=1.0) +sol = solve(prob, DynamicSS(CVODE_BDF()); dt = 1.0) ``` !!! note - If you use `CVODE_BDF` you may need to give a starting `dt` via `dt=....`.* + The default `alg` of `nothing` works only if `DifferentialEquations.jl` is installed and + loaded. + +!!! note + + If you use `CVODE_BDF` you may need to give a starting `dt` via `dt = ....`.* """ -struct DynamicSS{A, AT, RT, TS, TC <: NLSolveTerminationCondition} <: - SteadyStateDiffEqAlgorithm - alg::A - abstol::AT - reltol::RT - tspan::TS - termination_condition::TC +@concrete struct DynamicSS <: SteadyStateDiffEqAlgorithm + alg + tspan end -function DynamicSS(alg; abstol = 1e-8, reltol = 1e-6, tspan = Inf, - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.SteadyStateDefault; - abstol, - reltol)) - DynamicSS(alg, abstol, reltol, tspan, termination_condition) -end +DynamicSS(alg = nothing; tspan = Inf) = DynamicSS(alg, tspan) -# Backward compatibility: -DynamicSS(alg, abstol, reltol) = DynamicSS(alg; abstol = abstol, reltol = reltol) +function DiffEqBase.prepare_alg(alg::DynamicSS, u0, p, f) + return DynamicSS(DiffEqBase.prepare_alg(alg.alg, u0, p, f), alg.tspan) +end ## SciMLBase Trait Definitions +SciMLBase.isadaptive(::SSRootfind) = false -SciMLBase.isadaptive(alg::SteadyStateDiffEqAlgorithm) = true +for aType in (:SSRootfind, :DynamicSS), + op in (:isadaptive, :isautodifferentiable, :allows_arbitrary_number_types, + :allowscomplex) -SciMLBase.isautodifferentiable(alg::SSRootfind) = true -SciMLBase.allows_arbitrary_number_types(alg::SSRootfind) = true -SciMLBase.allowscomplex(alg::SSRootfind) = true + op == :isadaptive && aType == :SSRootfind && continue -SciMLBase.isautodifferentiable(alg::DynamicSS) = SciMLBase.isautodifferentiable(alg.alg) -function SciMLBase.allows_arbitrary_number_types(alg::DynamicSS) - SciMLBase.allows_arbitrary_number_types(alg.alg) + @eval function SciMLBase.$(op)(alg::$aType) + internal_alg = alg.alg + # Internal Alg is nothing means we will handle everything correctly downstream + internal_alg === nothing && return true + !hasmethod(SciMLBase.$(op), Tuple{typeof(internal_alg)}) && return false + return SciMLBase.$(op)(internal_alg) + end end -SciMLBase.allowscomplex(alg::DynamicSS) = SciMLBase.allowscomplex(alg.alg) diff --git a/src/solve.jl b/src/solve.jl index d5524c7..0537768 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,144 +1,133 @@ -function DiffEqBase.prepare_alg(alg::DynamicSS) - DynamicSS(DiffEqBase.prepare_alg(alg.alg), alg.abstol, alg.reltol, alg.tspan) +function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, alg::SSRootfind, + args...; kwargs...) + nlprob = NonlinearProblem(prob) + nlsol = DiffEqBase.__solve(nlprob, alg.alg, args...; kwargs...) + return SciMLBase.build_solution(prob, SSRootfind(nlsol.alg), nlsol.u, nlsol.resid; + nlsol.retcode, nlsol.stats, nlsol.left, nlsol.right, original = nlsol) end -function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, - alg::SteadyStateDiffEqAlgorithm, args...; - abstol = 1e-8, kwargs...) - @warn """ - This method is deprecated in favor of using NonlinearSolve.jl. Note that an ODEProblem - can be converted into a steady state NonlinearProblem via - `NonlinearProblem(prob::ODEProblem)`. The algorithm `NLSolveJL` as part of the - SciMLNLSolve.jl set of nonlinear solvers for NonlinearSolve.jl is equivalent to - SteadyStateDiffEq.jl's default `SSRootfind` (with a few improvements). - - See [the documentation of NonlinearSolve.jl](https://docs.sciml.ai/NonlinearSolve/stable/) - for more details. - """ - - if prob.u0 isa Number - u0 = [prob.u0] - else - u0 = vec(deepcopy(prob.u0)) - end +__get_tspan(u0, alg::DynamicSS) = __get_tspan(u0, alg.tspan) +__get_tspan(u0, tspan::Tuple) = tspan +function __get_tspan(u0, tspan::Number) + return convert.(DiffEqBase.value(real(eltype(u0))), + (DiffEqBase.value(zero(tspan)), tspan)) +end - sizeu = size(prob.u0) - p = prob.p - - if prob isa SteadyStateProblem - if !isinplace(prob) && - (prob.u0 isa AbstractVector || prob.u0 isa Number) - f! = (du, u) -> (du[:] = prob.f(u, p, Inf); nothing) - elseif !isinplace(prob) && prob.u0 isa AbstractArray - f! = (du, u) -> (du[:] = vec(prob.f(reshape(u, sizeu), p, Inf)); nothing) - elseif prob.u0 isa AbstractVector - f! = (du, u) -> (prob.f(du, u, p, Inf); nothing) - else # Then it's an in-place function on an abstract array - f! = (du, u) -> (prob.f(reshape(du, sizeu), - reshape(u, sizeu), p, Inf); - du = vec(du); - nothing) - end +function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, alg::DynamicSS, + args...; abstol = 1e-8, reltol = 1e-6, odesolve_kwargs = (;), + save_idxs = nothing, termination_condition = SteadyStateDiffEqTerminationMode(), + kwargs...) + tspan = __get_tspan(prob.u0, alg) + + f = if prob isa SteadyStateProblem + prob.f elseif prob isa NonlinearProblem - if !isinplace(prob) && - (prob.u0 isa AbstractVector || prob.u0 isa Number) - f! = (du, u) -> (du[:] = prob.f(u, p); nothing) - elseif !isinplace(prob) && prob.u0 isa AbstractArray - f! = (du, u) -> (du[:] = vec(prob.f(reshape(u, sizeu), p)); nothing) - elseif prob.u0 isa AbstractVector - f! = (du, u) -> (prob.f(du, u, p); nothing) - else # Then it's an in-place function on an abstract array - f! = (du, u) -> (prob.f(reshape(du, sizeu), - reshape(u, sizeu), p); - du = vec(du); - nothing) + if isinplace(prob) + (du, u, p, t) -> prob.f(du, u, p) + else + (u, p, t) -> prob.f(u, p) end end - # du = similar(u) - # f = (u) -> (f!(du,u); du) # out-of-place version - - if alg isa SSRootfind - original = alg.nlsolve(f!, u0, abstol) - if original isa NLsolve.SolverResults - u = reshape(original.zero, sizeu) - resid = similar(u) - f!(resid, u) - DiffEqBase.build_solution(prob, alg, u, resid; retcode = ReturnCode.Success, - original = original) - else - u = reshape(original, sizeu) - resid = similar(u) - f!(resid, u) - DiffEqBase.build_solution(prob, alg, u, resid; retcode = ReturnCode.Success) - end + if isinplace(prob) + du = similar(prob.u0) + f(du, prob.u0, prob.p, 0.0) else - error("Algorithm not recognized") + du = f(prob.u0, prob.p, 0.0) end -end -function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, - alg::DynamicSS, args...; save_everystep = false, - save_start = false, save_idxs = nothing, kwargs...) - tspan = alg.tspan isa Tuple ? alg.tspan : - convert.(DiffEqBase.value(real(eltype(prob.u0))), - (DiffEqBase.value(zero(alg.tspan)), alg.tspan)) - if prob isa SteadyStateProblem - f = prob.f - elseif prob isa NonlinearProblem - if isinplace(prob) - f = (du, u, p, t) -> prob.f(du, u, p) - else - f = (u, p, t) -> prob.f(u, p) - end + tc_cache = init(du, prob.u0, termination_condition, last(tspan); abstol, reltol) + abstol = DiffEqBase.get_abstol(tc_cache) + reltol = DiffEqBase.get_reltol(tc_cache) + + function terminate_function(u, t, integrator) + return tc_cache(get_du(integrator), integrator.u, integrator.uprev, t) end - mode = DiffEqBase.get_termination_mode(alg.termination_condition) + callback = TerminateSteadyState(abstol, reltol, terminate_function; + wrap_test = Val(false)) - storage = if mode ∈ DiffEqBase.SAFE_TERMINATION_MODES - NLSolveSafeTerminationResult() - else - nothing - end - callback = TerminateSteadyState(alg.termination_condition.abstol, - alg.termination_condition.reltol, - alg.termination_condition(storage)) + haskey(kwargs, :callback) && (callback = CallbackSet(callback, kwargs[:callback])) + haskey(odesolve_kwargs, :callback) && + (callback = CallbackSet(callback, odesolve_kwargs[:callback])) + + # Construct and solve the ODEProblem + odeprob = ODEProblem{isinplace(f)}(f, prob.u0, tspan, prob.p) + odesol = DiffEqBase.__solve(odeprob, alg.alg, args...; abstol, reltol, kwargs..., + odesolve_kwargs..., callback, save_end = true) - if haskey(kwargs, :callback) - callback = CallbackSet(callback, kwargs[:callback]) + resid, u, retcode = __get_result_from_sol(termination_condition, tc_cache, odesol) + + if save_idxs !== nothing + u = u[save_idxs] + resid = resid[save_idxs] end - _prob = ODEProblem(f, prob.u0, tspan, prob.p) - sol = solve(_prob, alg.alg, args...; kwargs..., save_everystep, save_start, callback) + return SciMLBase.build_solution(prob, DynamicSS(odesol.alg, alg.tspan), u, resid; + retcode, original = odesol) +end - idx, idx_prev = if storage === nothing || - !hasproperty(storage, :best_objective_value_iteration) - # weird hack but can't really help if save_everystep is turned off (also not - # relevant unless the user sets the mode to NLSolveDefault) - length(sol.u), (save_everystep ? length(sol.u) - 1 : 1) +function __get_result_from_sol(::DiffEqBase.AbstractNonlinearTerminationMode, tc_cache, + odesol) + u, t = last(odesol.u), last(odesol.t) + du = odesol(t, Val{1}) + return (du, u, + ifelse(odesol.retcode == ReturnCode.Terminated, ReturnCode.Success, + ReturnCode.Failure)) +end + +function __get_result_from_sol(::DiffEqBase.AbstractSafeNonlinearTerminationMode, tc_cache, + odesol) + u, t = last(odesol.u), last(odesol.t) + du = odesol(t, Val{1}) + + if tc_cache.retcode == DiffEqBase.NonlinearSafeTerminationReturnCode.Success + retcode_tc = ReturnCode.Success + elseif tc_cache.retcode == + DiffEqBase.NonlinearSafeTerminationReturnCode.PatienceTermination + retcode_tc = ReturnCode.ConvergenceFailure + elseif tc_cache.retcode == + DiffEqBase.NonlinearSafeTerminationReturnCode.ProtectiveTermination + retcode_tc = ReturnCode.Unstable else - max(storage.best_objective_value_iteration, 1), 1 # idx_prev is irrelevant + retcode_tc = ReturnCode.Default end - u, t, uprev = sol.u[idx], sol.t[idx], sol.u[idx_prev] - if isinplace(prob) - du = similar(sol.u[end]) - f(du, u, prob.p, t) + retcode = if odesol.retcode == ReturnCode.Terminated + ifelse(retcode_tc != ReturnCode.Default, retcode_tc, ReturnCode.Success) + elseif odesol.retcode == ReturnCode.Success + ReturnCode.Failure else - du = f(u, prob.p, t) + odesol.retcode end - retcode = sol.retcode == ReturnCode.Terminated && - DiffEqBase._has_converged(du, u, uprev, alg.termination_condition) ? - ReturnCode.Success : - ReturnCode.Failure + return du, u, retcode +end - if save_idxs !== nothing - u = sol.u[idx][save_idxs] - du = du[save_idxs] +function __get_result_from_sol(::DiffEqBase.AbstractSafeBestNonlinearTerminationMode, + tc_cache, odesol) + u, t = tc_cache.u, only(DiffEqBase.get_saved_values(tc_cache)) + du = odesol(t, Val{1}) + + if tc_cache.retcode == DiffEqBase.NonlinearSafeTerminationReturnCode.Success + retcode_tc = ReturnCode.Success + elseif tc_cache.retcode == + DiffEqBase.NonlinearSafeTerminationReturnCode.PatienceTermination + retcode_tc = ReturnCode.ConvergenceFailure + elseif tc_cache.retcode == + DiffEqBase.NonlinearSafeTerminationReturnCode.ProtectiveTermination + retcode_tc = ReturnCode.Unstable + else + retcode_tc = ReturnCode.Default + end + + retcode = if odesol.retcode == ReturnCode.Terminated + ifelse(retcode_tc != ReturnCode.Default, retcode_tc, ReturnCode.Success) + elseif odesol.retcode == ReturnCode.Success + ReturnCode.Failure else - u = sol.u[idx] + odesol.retcode end - return DiffEqBase.build_solution(prob, alg, u, du; retcode, sol.stats) + return du, u, retcode end diff --git a/test/autodiff.jl b/test/autodiff.jl index 0515b6e..9af298b 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -1,6 +1,4 @@ -using ModelingToolkit -using OrdinaryDiffEq -using ForwardDiff +using ModelingToolkit, OrdinaryDiffEq, ForwardDiff, Test, SteadyStateDiffEq using ForwardDiff: Dual @variables begin diff --git a/test/core.jl b/test/core.jl new file mode 100644 index 0000000..217810b --- /dev/null +++ b/test/core.jl @@ -0,0 +1,111 @@ +using SteadyStateDiffEq, + DiffEqBase, NonlinearSolve, Sundials, OrdinaryDiffEq, DiffEqCallbacks, Test + +function f(du, u, p, t) + du[1] = 2 - 2u[1] + du[2] = u[1] - 4u[2] +end + +u0 = zeros(2) +prob = SteadyStateProblem(f, u0) + +@testset "NonlinearSolve: $(nameof(typeof(alg)))" for alg in (nothing, + NewtonRaphson(; autodiff = AutoFiniteDiff()), KINSOL()) + sol = solve(prob, SSRootfind(alg)) + @test SciMLBase.successful_retcode(sol.retcode) + + du = zeros(2) + f(du, sol.u, nothing, 0) + @test maximum(du) < 1e-11 +end + +@testset "OrdinaryDiffEq" begin + du = zeros(2) + p = nothing + + sol = solve(prob, DynamicSS(Rodas5()); abstol = 1e-9, reltol = 1e-9) + @test SciMLBase.successful_retcode(sol.retcode) + + f(du, sol.u, p, 0) + @test du≈[0, 0] atol=1e-7 + + sol = solve(prob, DynamicSS(Rodas5(), tspan = 1e-3)) + @test sol.retcode != ReturnCode.Success + + sol = solve(prob, DynamicSS(CVODE_BDF()), dt = 1.0) + @test SciMLBase.successful_retcode(sol.retcode) + + # scalar save_idxs + scalar_sol = solve(prob, DynamicSS(CVODE_BDF()), dt = 1.0, save_idxs = 1) + @test scalar_sol[1] ≈ sol[1] + + f(du, sol.u, p, 0) + @test du≈[0, 0] atol=1e-6 +end + +# Float32 +u0 = [0.0f0, 0.0f0] + +function foop(u, p, t) + @test eltype(u) == eltype(u0) + dx = 2 - 2u[1] + dy = u[1] - 4u[2] + [dx, dy] +end + +function fiip(du, u, p, t) + @test eltype(u) == eltype(u0) + du[1] = 2 - 2u[1] + du[2] = u[1] - 4u[2] +end + +tspan = (0.0f0, 1.0f0) +proboop = SteadyStateProblem(foop, u0) +prob = SteadyStateProblem(fiip, u0) + +sol = solve(proboop, DynamicSS(Tsit5(), tspan = 1.0f-3)) +@test typeof(u0) == typeof(sol.u) +proboop = SteadyStateProblem(ODEProblem(foop, u0, tspan)) +sol2 = solve(proboop, DynamicSS(Tsit5()); abstol = 1e-4) +@test typeof(u0) == typeof(sol2.u) + +sol = solve(prob, DynamicSS(Tsit5(), tspan = 1.0f-3)) +@test typeof(u0) == typeof(sol.u) +prob = SteadyStateProblem(ODEProblem(fiip, u0, tspan)) +sol2 = solve(prob, DynamicSS(Tsit5()); abstol = 1e-4) +@test typeof(u0) == typeof(sol2.u) + +for termination_condition in [ + SteadyStateDiffEqTerminationMode(), SimpleNonlinearSolveTerminationMode(), + NormTerminationMode(), RelTerminationMode(), RelNormTerminationMode(), + AbsTerminationMode(), AbsNormTerminationMode(), RelSafeTerminationMode(), + AbsSafeTerminationMode(), RelSafeBestTerminationMode(), AbsSafeBestTerminationMode(), +] + sol_tc = solve(prob, DynamicSS(Tsit5()); termination_condition) + @show sol_tc.retcode, termination_condition + @test SciMLBase.successful_retcode(sol_tc.retcode) + @test sol_tc.u ≈ sol2.u +end + +# Complex u +u0 = [1.0im] + +function fcomplex(du, u, p, t) + du[1] = (0.1im - 1) * u[1] +end + +prob = SteadyStateProblem(ODEProblem(fcomplex, u0, (0.0, 1.0))) +sol = solve(prob, DynamicSS(Tsit5())) +@test SciMLBase.successful_retcode(sol.retcode) +@test abs(sol.u[end]) < 1e-8 + +# Callbacks +u0 = zeros(2) +prob = SteadyStateProblem(f, u0) +saved_values = SavedValues(Float64, Vector{Float64}) +cb = SavingCallback((u, t, integrator) -> copy(u), saved_values, save_everystep = true, + save_start = true) +sol = solve(prob, DynamicSS(Rodas5()); callback = cb, save_everystep = true, + save_start = true) +@test SciMLBase.successful_retcode(sol.retcode) +@test isapprox(saved_values.saveval[end], sol.u) diff --git a/test/runtests.jl b/test/runtests.jl index 93e4efb..51571de 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,137 +1,10 @@ -using SteadyStateDiffEq, DiffEqBase, NLsolve, Sundials -using Test +using SafeTestsets, Test -function f(du, u, p, t) - du[1] = 2 - 2u[1] - du[2] = u[1] - 4u[2] +@testset "SteadyStateDiffEq.jl" begin + @safetestset "Core Tests" begin + include("core.jl") + end + @safetestset "Autodiff Tests" begin + include("autodiff.jl") + end end -u0 = zeros(2) -prob = SteadyStateProblem(f, u0) -abstol = 1e-8 -sol = solve(prob, SSRootfind()) -@test sol.retcode == ReturnCode.Success -p = nothing - -du = zeros(2) -f(du, sol.u, nothing, 0) -@test maximum(du) < 1e-11 - -prob = ODEProblem(f, u0, (0.0, 1.0)) -prob = SteadyStateProblem(prob) -sol = solve(prob, - SSRootfind(nlsolve = (f, u0, abstol) -> (NLsolve.nlsolve(f, u0, - autodiff = :forward, - method = :newton, - iterations = Int(1e6), - ftol = abstol)))) -@test sol.retcode == ReturnCode.Success -@test sol.original isa NLsolve.SolverResults - -f(du, sol.u, nothing, 0) -@test du == [0, 0] - -# Use Sundials -sol = solve(prob, SSRootfind(nlsolve = (f, u0, abstol) -> (res = Sundials.kinsol(f, u0)))) -@test sol.retcode == ReturnCode.Success -f(du, sol.u, nothing, 0) -@test du == [0, 0] - -using OrdinaryDiffEq -sol = solve(prob, DynamicSS(Rodas5())) -@test sol.retcode == ReturnCode.Success - -f(du, sol.u, p, 0) -@test du≈[0, 0] atol=1e-7 - -sol = solve(prob, DynamicSS(Rodas5(), tspan = 1e-3)) -@test sol.retcode != ReturnCode.Success - -sol = solve(prob, DynamicSS(CVODE_BDF()), dt = 1.0) -@test sol.retcode == ReturnCode.Success - -# scalar save_idxs -scalar_sol = solve(prob, DynamicSS(CVODE_BDF()), dt = 1.0, save_idxs = 1) -@test scalar_sol[1] ≈ sol[1] - -f(du, sol.u, p, 0) -@test du≈[0, 0] atol=1e-6 - -# Float32 -u0 = [0.0f0, 0.0f0] - -function foop(u, p, t) - @test eltype(u) == eltype(u0) - dx = 2 - 2u[1] - dy = u[1] - 4u[2] - [dx, dy] -end - -function fiip(du, u, p, t) - @test eltype(u) == eltype(u0) - du[1] = 2 - 2u[1] - du[2] = u[1] - 4u[2] -end - -tspan = (0.0f0, 1.0f0) -proboop = SteadyStateProblem(foop, u0) -prob = SteadyStateProblem(fiip, u0) - -sol = solve(proboop, DynamicSS(Tsit5(), tspan = 1.0f-3)) -@test typeof(u0) == typeof(sol.u) -proboop = SteadyStateProblem(ODEProblem(foop, u0, tspan)) -sol2 = solve(proboop, DynamicSS(Tsit5(), abstol = 1e-4)) -@test typeof(u0) == typeof(sol2.u) - -sol = solve(prob, DynamicSS(Tsit5(), tspan = 1.0f-3)) -@test typeof(u0) == typeof(sol.u) -prob = SteadyStateProblem(ODEProblem(fiip, u0, tspan)) -sol2 = solve(prob, DynamicSS(Tsit5(), abstol = 1e-4)) -@test typeof(u0) == typeof(sol2.u) - -for mode in instances(NLSolveTerminationMode.T) - mode == NLSolveTerminationMode.NLSolveDefault && continue - - termination_condition = NLSolveTerminationCondition(mode; abstol = 1e-4, reltol = 1e-4) - sol = solve(prob, - DynamicSS(Tsit5(); abstol = 1e-4, reltol = 1e-4, termination_condition), - save_everystep = mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES) - - @test sol.retcode == ReturnCode.Success - @test sol.u ≈ sol2.u - - sol = solve(proboop, - DynamicSS(Tsit5(); abstol = 1e-4, reltol = 1e-4, termination_condition), - save_everystep = mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES) - - @test sol.retcode == ReturnCode.Success - @test sol.u ≈ sol2.u -end - -# Complex u -u0 = [1.0im] - -function fcomplex(du, u, p, t) - du[1] = (0.1im - 1) * u[1] -end - -prob = SteadyStateProblem(ODEProblem(fcomplex, u0, (0.0, 1.0))) -sol = solve(prob, DynamicSS(Tsit5())) -@test sol.retcode == ReturnCode.Success -@test abs(sol.u[end]) < 1e-8 - -# Callbacks -using DiffEqCallbacks -u0 = zeros(2) -prob = SteadyStateProblem(f, u0) -saved_values = SavedValues(Float64, Vector{Float64}) -cb = SavingCallback((u, t, integrator) -> copy(u), saved_values, - save_everystep = true, save_start = true) -sol = solve(prob, - DynamicSS(Rodas5()), - callback = cb, - save_everystep = true, - save_start = true) -@test sol.retcode == ReturnCode.Success -@test isapprox(saved_values.saveval[end], sol.u) - -include("autodiff.jl")