Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add infrastructure for initialization of different problem types #885

Merged
merged 30 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
ae62632
feat: add `constructorof` for `SDEProblem`
AayushSabharwal Dec 3, 2024
97e1725
feat: add `constructorof` for `SDDEProblem`
AayushSabharwal Dec 3, 2024
73c5d7d
feat: add `constructorof` for `DDEProblem`
AayushSabharwal Dec 3, 2024
6989157
feat: add proper remake for `SDEProblem`
AayushSabharwal Dec 3, 2024
0e2a089
feat: run `remake_initialization_data` when remaking `SDEProblem`
AayushSabharwal Dec 3, 2024
1980fa6
feat: add proper `remake` for `DDEProblem`
AayushSabharwal Dec 3, 2024
555cf5d
feat: add proper `remake` for `SDDEProblem`
AayushSabharwal Dec 3, 2024
d0b31ca
fix: support non-markovian index providers in `updated_u0_p`
AayushSabharwal Dec 3, 2024
abf7ff6
feat: implement `get_history_function` for `AbstractSDDEProblem`
AayushSabharwal Dec 3, 2024
071ddee
feat: implement `get_history_function` for `AbstractSDDEIntegrator`
AayushSabharwal Dec 3, 2024
0488aaf
build: bump SymbolicIndexingInterface compat
AayushSabharwal Dec 4, 2024
86d61b4
fix: fix type stability of `remake(::SDEProblem)`
AayushSabharwal Dec 4, 2024
c309693
test: test `remake` for `DDEProblem`, `SDDEProblem`
AayushSabharwal Dec 4, 2024
071723f
fix: add `.initializeprob` syntax to all applicable SciMLFunctions
AayushSabharwal Dec 4, 2024
68f74f2
refactor: update `remake_initializeprob` fallback
AayushSabharwal Dec 4, 2024
041c539
test: fix DDE indexing test
AayushSabharwal Dec 4, 2024
7bffc07
test: test SDDE indexing
AayushSabharwal Dec 4, 2024
b75e3a4
build: bump MTK compat in downstream CI
AayushSabharwal Dec 4, 2024
1624c01
feat: add lazy initialization to new `remake` methods
AayushSabharwal Dec 5, 2024
93483ef
feat: add `constructorof` for `NonlinearProblem`, `NonlinearLeastSqua…
AayushSabharwal Dec 6, 2024
8b8d8f2
feat: add proper `remake` for `NonlinearProblem`, `NonlinearLeastSqua…
AayushSabharwal Dec 6, 2024
2188104
fix: allow specifying `f` for `remake` of `SCCNonlinearProblem`
AayushSabharwal Dec 6, 2024
58f6852
fix: handle `initialization_data` in `f` passed to `remake`
AayushSabharwal Dec 9, 2024
fbcc39e
test: test lazy initialization in `remake` for supported problem types
AayushSabharwal Dec 9, 2024
0b8e460
fix: check if `initialization_data` exists before running eager initi…
AayushSabharwal Dec 9, 2024
665582e
test: add StochasticDelayDiffEq to downstream CI
AayushSabharwal Dec 9, 2024
dcde577
test: do not test unimplemented SDDE integrator stepping
AayushSabharwal Dec 9, 2024
089e31a
feat: generalize `get_history_function` to `AbstractODESolution`
AayushSabharwal Dec 10, 2024
ee118ef
feat: generalize `CheckInit` to DDEs
AayushSabharwal Dec 10, 2024
dfcb209
Update ensemble_nondes.jl
ChrisRackauckas Dec 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ StableRNGs = "1.0"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
SymbolicIndexingInterface = "0.3.34"
SymbolicIndexingInterface = "0.3.36"
Tables = "1.11"
Zygote = "0.6.67"
julia = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import FunctionWrappersWrappers
import RuntimeGeneratedFunctions
import EnumX
import ADTypes: ADTypes, AbstractADType
import Accessors: @set, @reset
import Accessors: @set, @reset, @delete
using Expronicon.ADT: @match

using Reexport
Expand Down
25 changes: 23 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,27 @@ function _evaluate_f(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

"""
Utility function to evaluate the RHS, adding extra arguments (such as history function for
DDEs) wherever necessary.
"""
function evaluate_f(integrator::DEIntegrator, prob, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, u, p, t)
end

function evaluate_f(
integrator::DEIntegrator, prob::AbstractDAEProblem, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, integrator.du, u, p, t)
end

function evaluate_f(integrator::AbstractDDEIntegrator, prob::AbstractDDEProblem, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
end

function evaluate_f(integrator::AbstractSDDEIntegrator, prob::AbstractSDDEProblem, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -147,7 +168,7 @@ function get_initial_values(
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true
update_coefficients!(M, u0, p, t)
tmp = _evaluate_f(integrator, f, isinplace, u0, p, t)
tmp = evaluate_f(integrator, prob, f, isinplace, u0, p, t)
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))

normresid = isdefined(integrator.opts, :internalnorm) ?
Expand All @@ -165,7 +186,7 @@ function get_initial_values(
p = parameter_values(integrator)
t = current_time(integrator)

resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t)
resid = evaluate_f(integrator, prob, f, isinplace, u0, p, t)
normresid = isdefined(integrator.opts, :internalnorm) ?
integrator.opts.internalnorm(resid, t) : norm(resid)

Expand Down
3 changes: 2 additions & 1 deletion src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,7 @@ function isadaptive(integrator::DEIntegrator)
isdefined(integrator.opts, :adaptive) ? integrator.opts.adaptive : false
end

function SymbolicIndexingInterface.get_history_function(integ::AbstractDDEIntegrator)
function SymbolicIndexingInterface.get_history_function(integ::Union{
AbstractDDEIntegrator, AbstractSDDEIntegrator})
DDESolutionHistoryWrapper(get_sol(integ))
end
13 changes: 13 additions & 0 deletions src/problems/dde_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,19 @@ struct DDEProblem{uType, tType, lType, lType2, isinplace, P, F, H, K, PT} <:
end
end

function ConstructionBase.constructorof(::Type{P}) where {P <: DDEProblem}
function ctor(f, u0, h, tspan, p, constant_lags, dependent_lags,
kw, neutral, order_discontinuity_t0, problem_type)
if f isa AbstractDDEFunction
iip = isinplace(f)
else
iip = isinplace(f, 5)
end
return DDEProblem{iip}(f, u0, h, tspan, p; kw..., constant_lags, dependent_lags,
neutral, order_discontinuity_t0, problem_type)
end
end

DDEProblem(f, args...; kwargs...) = DDEProblem(DDEFunction(f), args...; kwargs...)

function DDEProblem(f::AbstractDDEFunction, args...; kwargs...)
Expand Down
22 changes: 22 additions & 0 deletions src/problems/nonlinear_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,17 @@ function NonlinearProblem(f::AbstractODEFunction, u0, p = NullParameters(); kwar
NonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
end

function ConstructionBase.constructorof(::Type{P}) where {P <: NonlinearProblem}
function ctor(f, u0, p, pt, kw)
if f isa AbstractNonlinearFunction
iip = isinplace(f)
else
iip = isinplace(f, 4)
end
return NonlinearProblem{iip}(f, u0, p, pt; kw...)
end
end

"""
$(SIGNATURES)

Expand Down Expand Up @@ -322,6 +333,17 @@ function NonlinearLeastSquaresProblem(f, u0, p = NullParameters(); kwargs...)
return NonlinearLeastSquaresProblem(NonlinearFunction(f), u0, p; kwargs...)
end

function ConstructionBase.constructorof(::Type{P}) where {P <: NonlinearLeastSquaresProblem}
function ctor(f, u0, p, kw)
if f isa AbstractNonlinearFunction
iip = isinplace(f)
else
iip = isinplace(f, 4)
end
return NonlinearProblem{iip}(f, u0, p; kw...)
end
end

@doc doc"""
SCCNonlinearProblem(probs, explicitfuns!)

Expand Down
16 changes: 16 additions & 0 deletions src/problems/sdde_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,19 @@ end
function SDDEProblem(f::AbstractSDDEFunction, args...; kwargs...)
SDDEProblem{isinplace(f)}(f, args...; kwargs...)
end

function ConstructionBase.constructorof(::Type{P}) where {P <: SDDEProblem}
function ctor(f, g, u0, h, tspan, p, noise, constant_lags, dependent_lags, kw,
noise_rate_prototype, seed, neutral, order_discontinuity_t0)
if f isa AbstractSDDEFunction
iip = isinplace(f)
else
iip = isinplace(f, 5)
end
return SDDEProblem{iip}(
f, g, u0, h, tspan, p; kw..., noise, constant_lags, dependent_lags,
noise_rate_prototype, seed, neutral, order_discontinuity_t0)
end
end

SymbolicIndexingInterface.get_history_function(prob::AbstractSDDEProblem) = prob.h
11 changes: 11 additions & 0 deletions src/problems/sde_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...)
SDEProblem{iip}(SDEFunction{iip}(f, g), u0, tspan, p; kwargs...)
end

function ConstructionBase.constructorof(::Type{P}) where {P <: SDEProblem}
function ctor(f, g, u0, tspan, p, noise, kw, noise_rate_prototype, seed)
if f isa AbstractSDEFunction
iip = isinplace(f)
else
iip = isinplace(f, 4)
end
return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed)
end
end

"""
$(TYPEDEF)
"""
Expand Down
Loading
Loading