Skip to content

Commit

Permalink
Implemented WeightedIntervalTraining and its Test
Browse files Browse the repository at this point in the history
  • Loading branch information
hippyhippohops committed May 6, 2024
1 parent d6f2e5f commit 1ed7683
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
31 changes: 31 additions & 0 deletions src/dae_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,35 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
return loss
end

"""
    generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
                  differential_vars::AbstractVector)

Return a loss function `(θ, _) -> scalar` for a `WeightedIntervalTraining` strategy.

The time span is split into `length(strategy.weights)` equal subintervals; within the
`i`-th subinterval, a number of collocation points proportional to the normalized
`i`-th weight is sampled uniformly at random (≈ `strategy.points * weights[i]` points
in total). The loss is the sum of `inner_loss` evaluated over those sampled times.

# Throws
- `ArgumentError` if `autodiff` is `true` (not supported for this strategy).
"""
function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
        differential_vars::AbstractVector)
    # Copy-paste fix: the original message incorrectly referenced GridTraining.
    autodiff &&
        throw(ArgumentError("autodiff not supported for WeightedIntervalTraining."))
    minT = tspan[1]
    maxT = tspan[2]

    # Normalize the weights so they sum to 1, giving each subinterval its
    # fraction of the total point budget.
    weights = strategy.weights ./ sum(strategy.weights)

    N = length(weights)
    points = strategy.points

    # Width of each of the N equal subintervals of [minT, maxT].
    difference = (maxT - minT) / N

    ts = Float64[]
    for (index, w) in enumerate(weights)
        # trunc(Int, …) keeps the per-subinterval count integral; the total may
        # therefore be slightly below `points` due to rounding down.
        n_sub = trunc(Int, points * w)
        # Uniform samples within the index-th subinterval
        # [minT + (index-1)*difference, minT + index*difference].
        segment = rand(n_sub) .* difference .+ minT .+ ((index - 1) * difference)
        append!(ts, segment)
    end

    function loss(θ, _)
        sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
    end
    return loss
end


function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
alg::NNDAE,
args...;
Expand Down Expand Up @@ -138,6 +167,8 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
else
error("dt is not defined")
end
else
alg.strategy
end

inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, differential_vars)
Expand Down
33 changes: 32 additions & 1 deletion test/NNDAE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,38 @@ end
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
opt = OptimizationOptimisers.Adam(0.1)
alg = NNDAE(chain, OptimizationOptimisers.Adam(0.1); autodiff = false)
alg = NeuralPDE.NNDAE(chain, OptimizationOptimisers.Adam(0.1); autodiff = false)

sol = solve(prob,
alg, verbose = false, dt = 1 / 100.0f0,
maxiters = 3000, abstol = 1.0f-10)

    @test ground_sol(0:(1 / 100):(pi / 2)) ≈ sol atol=0.4
end

@testset "WeightedIntervalTraining" begin
function example2(du, u, p, t)
du[1] = u[1] - t
du[2] = u[2] - t
nothing
end
M = [0.0 0
0 1]
u₀ = [0.0, 0.0]
du₀ = [0.0, 0.0]
tspan = (0.0f0, pi / 2.0f0)
f = ODEFunction(example2, mass_matrix = M)
prob_mm = ODEProblem(f, u₀, tspan)
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)

example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
differential_vars = [false, true]
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
opt = OptimizationOptimisers.Adam(0.1)
weights = [0.7, 0.2, 0.1]
points = 200
alg = NeuralPDE.NNDAE(chain, OptimizationOptimisers.Adam(0.1), strategy = NeuralPDE.WeightedIntervalTraining(weights, points); autodiff = false)

sol = solve(prob,
alg, verbose = false, dt = 1 / 100.0f0,
Expand Down

0 comments on commit 1ed7683

Please sign in to comment.