From 1ed7683952705f51b19992d6ce68852259605ee5 Mon Sep 17 00:00:00 2001 From: hippyhippohops Date: Sun, 5 May 2024 23:01:25 -0500 Subject: [PATCH] Implemented WeightedIntervalTraining and it's Test --- src/dae_solve.jl | 31 +++++++++++++++++++++++++++++++ test/NNDAE_tests.jl | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/src/dae_solve.jl b/src/dae_solve.jl index 5a5ee83be3..df18ad08d9 100644 --- a/src/dae_solve.jl +++ b/src/dae_solve.jl @@ -79,6 +79,35 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, return loss end +function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p, + differential_vars::AbstractVector) + autodiff && throw(ArgumentError("autodiff not supported for GridTraining.")) + minT = tspan[1] + maxT = tspan[2] + + weights = strategy.weights ./ sum(strategy.weights) + + N = length(weights) + points = strategy.points + + difference = (maxT-minT)/N + + data = Float64[] + for (index, item) in enumerate(weights) + temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+ + ((index - 1) * difference) + data = append!(data, temp_data) + end + + ts = data + + function loss(θ, _) + sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars)) + end + return loss +end + + function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem, alg::NNDAE, args...; @@ -138,6 +167,8 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem, else error("dt is not defined") end + else + alg.strategy end inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, differential_vars) diff --git a/test/NNDAE_tests.jl b/test/NNDAE_tests.jl index 7199d190c7..e5930ec063 100644 --- a/test/NNDAE_tests.jl +++ b/test/NNDAE_tests.jl @@ -54,7 +54,38 @@ end prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars) chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2)) opt = OptimizationOptimisers.Adam(0.1) - alg = NNDAE(chain, OptimizationOptimisers.Adam(0.1); autodiff = false) + alg = NeuralPDE.NNDAE(chain, OptimizationOptimisers.Adam(0.1); autodiff = false) + + sol = solve(prob, + alg, verbose = false, dt = 1 / 100.0f0, + maxiters = 3000, abstol = 1.0f-10) + + @test ground_sol(0:(1 / 100):(pi / 2))≈sol atol=0.4 +end + +@testset "WeightedIntervalTraining" begin + function example2(du, u, p, t) + du[1] = u[1] - t + du[2] = u[2] - t + nothing + end + M = [0.0 0 + 0 1] + u₀ = [0.0, 0.0] + du₀ = [0.0, 0.0] + tspan = (0.0f0, pi / 2.0f0) + f = ODEFunction(example2, mass_matrix = M) + prob_mm = ODEProblem(f, u₀, tspan) + ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8) + + example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]] + differential_vars = [false, true] + prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars) + chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2)) + opt = OptimizationOptimisers.Adam(0.1) + weights = [0.7, 0.2, 0.1] + points = 200 + alg = NeuralPDE.NNDAE(chain, OptimizationOptimisers.Adam(0.1), strategy = NeuralPDE.WeightedIntervalTraining(weights, points); autodiff = false) sol = solve(prob, alg, verbose = false, dt = 1 / 100.0f0,