Adding Training Strategies to dae_solvers.jl #838

Closed · wants to merge 28 commits into from

Changes from all commits · 28 commits
280a43d
Update dae_solve.jl
hippyhippohops Mar 26, 2024
911b3a3
Update dae_solve.jl
hippyhippohops Mar 26, 2024
9d98ae1
Update NNDAE_tests.jl
hippyhippohops Mar 28, 2024
35700c4
Added strategy::QuadratureTraining
hippyhippohops Mar 28, 2024
5773e91
Formatted indentation in strategy::WeightIntervalTraining
hippyhippohops Mar 28, 2024
c9f54de
Formatted Indentation in strategy::QuadratureTraining
hippyhippohops Mar 28, 2024
abc61bb
Refactored generate_losses in ode_solve.jl
hippyhippohops Mar 31, 2024
e4f06d5
Reverted back the ode_solve.jl to the previous set of codes
hippyhippohops Apr 1, 2024
0f4cfde
Edits to dae_solve.jl and NNDAE_tests.jl
hippyhippohops Apr 9, 2024
911d68c
Modified dae_solve.jl and NNDAE_tests
hippyhippohops Apr 12, 2024
2f4c505
Removed param_estim
hippyhippohops Apr 21, 2024
52cdea8
Update dae_solve.jl
hippyhippohops Apr 25, 2024
95457b9
Merge branch 'SciML:master' into patch-1
hippyhippohops Apr 30, 2024
d6f2e5f
Reset the code to match master code. Planning to start from scratch a…
hippyhippohops May 3, 2024
1ed7683
Implemented WeightedIntervalTraining and it's Test
hippyhippohops May 6, 2024
2f9db68
Formatted Code
hippyhippohops May 6, 2024
c2453d2
Added in failed Quadature training
hippyhippohops May 8, 2024
7c6c2bf
trying to workout quadature training strategy.
hippyhippohops May 16, 2024
70e0657
Stochastic training passes
hippyhippohops May 16, 2024
0098c6d
updates on NNDAE_tests.jl
hippyhippohops May 26, 2024
3e9473e
Updates
hippyhippohops May 26, 2024
92ec11c
Merge branch 'SciML:master' into patch-1
hippyhippohops Jun 4, 2024
b00c8cf
removing empty line
hippyhippohops Jun 4, 2024
afd05ee
Merge branch 'patch-1' of https://github.com/hippyhippohops/NeuralPDE…
hippyhippohops Jun 6, 2024
a4e2877
changes to quadrature training
hippyhippohops Jun 7, 2024
41dbf62
Added Quadrature training
hippyhippohops Jun 10, 2024
9c72fee
Changing to float64
hippyhippohops Jul 8, 2024
47b5aea
Merge branch 'SciML:master' into patch-1
hippyhippohops Jul 8, 2024
98 changes: 97 additions & 1 deletion src/dae_solve.jl
@@ -47,6 +47,25 @@ function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff =
NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
end


function dfdx(phi::ODEPhi{C, T, U}, t::Number, θ,
        autodiff::Bool, differential_vars::AbstractVector) where {C, T, U <: Number}
    if autodiff
        ForwardDiff.derivative(t -> phi(t, θ), t)
    else
        (phi(t + sqrt(eps(typeof(t))), θ) - phi(t, θ)) / sqrt(eps(typeof(t)))
    end
end

function dfdx(phi::ODEPhi{C, T, U}, t::Number, θ,
        autodiff::Bool, differential_vars::AbstractVector) where {C, T, U <: AbstractVector}
    if autodiff
        # ForwardDiff.derivative also covers a vector-valued phi of a scalar t;
        # ForwardDiff.jacobian would require an array argument here.
        ForwardDiff.derivative(t -> phi(t, θ), t)
    else
        (phi(t + sqrt(eps(typeof(t))), θ) - phi(t, θ)) / sqrt(eps(typeof(t)))
    end
end
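Both fallback branches use a first-order forward difference with step h = sqrt(eps(t)). A minimal standalone sketch of the same approximation (illustrative only, not part of the diff):

# Forward difference with step sqrt(eps), matching the non-autodiff branch of dfdx.
fd_deriv(g, t) = (g(t + sqrt(eps(typeof(t)))) - g(t)) / sqrt(eps(typeof(t)))

fd_deriv(sin, 1.0)  # ≈ cos(1.0), accurate to roughly sqrt(eps()) ≈ 1.5e-8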

function dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool,
differential_vars::AbstractVector)
if autodiff
@@ -69,6 +88,19 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
sum(abs2, loss) / length(t)
end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
        p, differential_vars::AbstractVector) where {C, T, U}
    dphi = dfdx(phi, t, θ, autodiff, differential_vars)
    sum(abs2, f(dphi, phi(t, θ), p, t))
end
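Unlike the ODE solver, which penalizes dφ/dt - f(φ, p, t), the DAE loss pushes the full implicit residual f(du, u, p, t) toward zero. A small sketch of that pointwise evaluation, using the residual from the test problems further down (all names here are illustrative):

# DAE residual of the test problem below; zero along an exact solution.
residual = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
pointwise_loss(du, u, t) = sum(abs2, residual(du, u, nothing, t))

pointwise_loss([1.0, 1.0], [1.2, 1.2], 0.2)  # 0.0: this (du, u) satisfies the DAE at t = 0.2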

function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
differential_vars::AbstractVector)
ts = tspan[1]:(strategy.dx):tspan[2]
@@ -79,6 +111,65 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
return loss
end

function generate_loss(
        strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
        differential_vars::AbstractVector)
    autodiff && throw(ArgumentError("autodiff not supported for WeightedIntervalTraining."))
    minT = tspan[1]
    maxT = tspan[2]

    weights = strategy.weights ./ sum(strategy.weights)

    N = length(weights)
    points = strategy.points

    difference = (maxT - minT) / N

    data = Float64[]
    for (index, item) in enumerate(weights)
        temp_data = rand(trunc(Int, points * item)) .* difference .+ minT .+
                    ((index - 1) * difference)
        append!(data, temp_data)
    end

    ts = data

    function loss(θ, _)
        sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
    end
    return loss
end
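The weights split tspan into length(weights) equal subintervals and give the i-th subinterval roughly points * weights[i] uniform samples. A self-contained sketch of that sampling step, assuming weights [0.7, 0.2, 0.1] and 200 points over (0.0, 3.0):

weights = [0.7, 0.2, 0.1]
weights = weights ./ sum(weights)             # normalize, as generate_loss does
points, minT, maxT = 200, 0.0, 3.0
difference = (maxT - minT) / length(weights)  # each subinterval has width 1.0
data = Float64[]
for (index, w) in enumerate(weights)
    # ~140, ~40 and ~20 samples land in [0, 1), [1, 2) and [2, 3) respectively
    append!(data, rand(trunc(Int, points * w)) .* difference .+ minT .+ (index - 1) * difference)
end
length(data)  # ≈ 200, concentrated near the left end of the span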


function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
        differential_vars::AbstractVector)
    integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, differential_vars))

    function integrand(ts, θ)
        [sum(abs2, inner_loss(phi, f, autodiff, t, θ, p, differential_vars)) for t in ts]
    end

    function loss(θ, _)
        intf = BatchIntegralFunction(integrand, max_batch = strategy.batch)
        intprob = IntegralProblem(intf, (tspan[1], tspan[2]), θ)
        sol = solve(intprob, strategy.quadrature_alg; abstol = strategy.abstol,
            reltol = strategy.reltol, maxiters = strategy.maxiters)
        sol.u
    end
    return loss
end
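Here the loss is the integral of the pointwise residual over tspan, computed by Integrals.jl; BatchIntegralFunction lets the quadrature rule request many evaluation points per call (capped at strategy.batch). A minimal sketch of the same pattern on a toy integrand, assuming the current Integrals.jl API:

using Integrals

# Batched integrand: receives a vector of points ts and a parameter p and
# returns one value per point, like the vector method of integrand above.
g = BatchIntegralFunction((ts, p) -> abs2.(sin.(p .* ts)), max_batch = 100)
prob = IntegralProblem(g, (0.0, 1.0), 2.0)
solve(prob, QuadGKJL(); abstol = 1e-8, reltol = 1e-8).u  # ∫₀¹ sin²(2t) dt ≈ 0.5946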

function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
        differential_vars::AbstractVector)
    autodiff && throw(ArgumentError("autodiff not supported for StochasticTraining."))
    function loss(θ, _)
        ts = adapt(parameterless_type(θ),
            [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
        sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
    end
    return loss
end
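In contrast to GridTraining's fixed grid, each evaluation of this loss draws strategy.points fresh uniform samples from tspan, so successive optimizer steps see different collocation points. The sampling line in isolation (illustrative values):

tspan, points = (0.0, pi / 2), 1000
ts = [(tspan[2] - tspan[1]) * rand() + tspan[1] for _ in 1:points]  # fresh draws on every loss call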

function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
alg::NNDAE,
args...;
@@ -136,8 +227,13 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
if dt !== nothing
    GridTraining(dt)
else
-   error("dt is not defined")
+   QuadratureTraining(; quadrature_alg = QuadGKJL(),
+       reltol = convert(eltype(u0), reltol),
+       abstol = convert(eltype(u0), abstol), maxiters = maxiters,
+       batch = 0)
end
else
    alg.strategy
end

inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, differential_vars)
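With this change, omitting dt no longer errors: the solver now falls back to a QuadratureTraining default. Hypothetical call sites, with prob and alg as in the testsets below:

# dt given   -> GridTraining(dt)
# dt omitted -> QuadratureTraining(quadrature_alg = QuadGKJL(), batch = 0, ...)
sol_grid = solve(prob, alg; dt = 1 / 100.0, maxiters = 3000, abstol = 1e-10)
sol_quad = solve(prob, alg; maxiters = 3000, abstol = 1e-10)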
4 changes: 4 additions & 0 deletions src/ode_solve.jl
@@ -225,6 +225,7 @@ end

Representation of the loss function, parametric on the training strategy `strategy`.
"""

function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
batch, param_estim::Bool)
integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim))
@@ -304,6 +305,8 @@ function generate_loss(
return loss
end



function evaluate_tstops_loss(phi, f, autodiff::Bool, tstops, p, batch, param_estim::Bool)
function loss(θ, _)
if batch
@@ -319,6 +322,7 @@ function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, ts
error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.")
end


struct NNODEInterpolation{T <: ODEPhi, T2}
phi::T
θ::T2
105 changes: 95 additions & 10 deletions test/NNDAE_tests.jl
@@ -16,7 +16,7 @@ Random.seed!(100)
M = [1.0 0
0 0]
f = ODEFunction(example1, mass_matrix = M)
-tspan = (0.0f0, 1.0f0)
+tspan = (0.0, 1.0)

prob_mm = ODEProblem(f, u₀, tspan)
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)
@@ -25,13 +25,13 @@ Random.seed!(100)
differential_vars = [true, false]
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, cos), Lux.Dense(15, 15, sin), Lux.Dense(15, 2))
-opt = OptimizationOptimisers.Adam(0.1)
-alg = NeuralPDE.NNDAE(chain, opt; autodiff = false)
+opt = OptimizationOptimJL.BFGS(linesearch = BackTracking())
+alg = NNDAE(chain, opt; autodiff = false)

sol = solve(prob,
-    alg, verbose = false, dt = 1 / 100.0f0,
-    maxiters = 3000, abstol = 1.0f-10)
-@test ground_sol(0:(1 / 100):1)≈sol atol=0.4
+    alg, verbose = false, dt = 1 / 100.0,
+    maxiters = 3000, abstol = 1e-10)
+@test reduce(hcat, ground_sol(0:(1 / 100):1).u)≈reduce(hcat, sol.u) rtol=1e-1
end

@testset "Example 2" begin
@@ -44,7 +44,7 @@ end
0 1]
u₀ = [0.0, 0.0]
du₀ = [0.0, 0.0]
-tspan = (0.0f0, pi / 2.0f0)
+tspan = (0.0, pi / 2.0)
f = ODEFunction(example2, mass_matrix = M)
prob_mm = ODEProblem(f, u₀, tspan)
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)
@@ -57,8 +57,93 @@ end
alg = NNDAE(chain, OptimizationOptimisers.Adam(0.1); autodiff = false)

sol = solve(prob,
-    alg, verbose = false, dt = 1 / 100.0f0,
-    maxiters = 3000, abstol = 1.0f-10)
+    alg, verbose = false, dt = 1 / 100.0,
+    maxiters = 3000, abstol = 1e-10)

@test reduce(hcat, ground_sol(0:(1 / 100):(pi / 2.0)).u)≈reduce(hcat, sol.u) rtol=1e-2
end

@testset "WeightedIntervalTraining" begin
function example2(du, u, p, t)
du[1] = u[1] - t
du[2] = u[2] - t
nothing
end
M = [0.0 0.0
0.0 1.0]
u₀ = [0.0, 0.0]
du₀ = [0.0, 0.0]
tspan = (0.0, pi / 2.0)
f = ODEFunction(example2, mass_matrix = M)
prob_mm = ODEProblem(f, u₀, tspan)
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)

example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
differential_vars = [false, true]
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
opt = OptimizationOptimisers.Adam(0.1)
weights = [0.7, 0.2, 0.1]
points = 200
alg = NNDAE(chain, opt,
    strategy = WeightedIntervalTraining(weights, points); autodiff = false)

sol = solve(prob,
alg, verbose = false, dt = 1 / 100.0,
maxiters = 3000, abstol = 1e-10)

@test ground_sol(0:(1 / 100):(pi / 2))≈sol atol=0.4
@test reduce(hcat, ground_sol(0:(1 / 100):(pi / 2.0)).u)≈reduce(hcat, sol.u) rtol=1e-2
end

@testset "StochasticTraining" begin
function example2(du, u, p, t)
du[1] = u[1] - t
du[2] = u[2] - t
nothing
end
M = [0.0 0.0
0.0 1.0]
u₀ = [0.0, 0.0]
du₀ = [0.0, 0.0]
tspan = (0.0, pi / 2.0)
f = ODEFunction(example2, mass_matrix = M)
prob_mm = ODEProblem(f, u₀, tspan)
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)

example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
differential_vars = [false, true]
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
opt = OptimizationOptimisers.Adam(0.1)
alg = NeuralPDE.NNDAE(chain, opt,
    strategy = NeuralPDE.StochasticTraining(1000); autodiff = false)
sol = solve(prob,
alg, verbose = false, dt = 1 / 100.0,
maxiters = 3000, abstol = 1e-10)
@test reduce(hcat, ground_sol(0:(1 / 100):(pi / 2.0)).u)≈reduce(hcat, sol.u) rtol=1e-2
end

@testset "QuadratureTraining" begin
function example2(du, u, p, t)
du[1] = u[1] - t
du[2] = u[2] - t
nothing
end
M = [0.0 0.0
0.0 1.0]
u₀ = [0.0, 0.0]
du₀ = [0.0, 0.0]
tspan = (0.0, pi / 2.0)
f = ODEFunction(example2, mass_matrix = M)
prob_mm = ODEProblem(f, u₀, tspan)
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)

example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
differential_vars = [false, true]
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
opt = OptimizationOptimJL.BFGS(linesearch = BackTracking())
alg = NeuralPDE.NNDAE(chain, opt; autodiff = false)
sol = solve(prob, alg, verbose = true, maxiters = 6000, abstol = 1e-10)
@test reduce(hcat, ground_sol(0:(1 / 100):(pi / 2.0)).u)≈reduce(hcat, sol.u) rtol=1e-2
end
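For reference, a condensed end-to-end sketch of the API these tests exercise, mirroring the WeightedIntervalTraining testset above (package setup assumed as in the test suite):

using NeuralPDE, Lux, OptimizationOptimisers, OrdinaryDiffEq

example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
u₀, du₀, tspan = [0.0, 0.0], [0.0, 0.0], (0.0, pi / 2)
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = [false, true])

chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
alg = NNDAE(chain, OptimizationOptimisers.Adam(0.1);
    strategy = WeightedIntervalTraining([0.7, 0.2, 0.1], 200), autodiff = false)
sol = solve(prob, alg; verbose = false, dt = 1 / 100.0, maxiters = 3000, abstol = 1e-10)
sol.u[end]  # network prediction at t = π/2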