diff --git a/Project.toml b/Project.toml
index d76565ff25..45468985c5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -44,7 +44,7 @@ AdvancedHMC = "0.6.1"
 Aqua = "0.8"
 ArrayInterface = "7.7"
 CUDA = "5.2"
-ChainRulesCore = "1.18"
+ChainRulesCore = "1.21"
 ComponentArrays = "0.15.8"
 Cubature = "1.5"
 DiffEqBase = "6.144"
@@ -59,7 +59,7 @@ Integrals = "4"
 LineSearches = "7.2"
 LinearAlgebra = "1"
 LogDensityProblems = "2"
-Lux = "0.5.14"
+Lux = "0.5.22"
 LuxCUDA = "0.3.2"
 MCMCChains = "6"
 MethodOfLines = "0.10.7"
@@ -82,7 +82,7 @@ SymbolicUtils = "1.4"
 Symbolics = "5.17"
 Test = "1"
 UnPack = "1"
-Zygote = "0.6.68"
+Zygote = "0.6.69"
 julia = "1.10"

 [extras]
@@ -91,12 +91,12 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
+MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"

 [targets]
 test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux", "MethodOfLines"]
diff --git a/docs/src/tutorials/neural_adapter.md b/docs/src/tutorials/neural_adapter.md
index a56e30a269..93f0dd036f 100644
--- a/docs/src/tutorials/neural_adapter.md
+++ b/docs/src/tutorials/neural_adapter.md
@@ -69,7 +69,7 @@ function loss(cord, θ)
     ch2 .- phi(cord, res.u)
 end

-strategy = NeuralPDE.QuadratureTraining()
+strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)
 prob_ = NeuralPDE.neural_adapter(loss, init_params2, pde_system, strategy)
 res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 10000)
@@ -173,7 +173,7 @@ for i in 1:count_decomp
     bcs_ = create_bcs(domains_[1].domain, phi_bound)
     @named pde_system_ = PDESystem(eq, bcs_, domains_, [x, y], [u(x, y)])
     push!(pde_system_map, pde_system_)
-    strategy = NeuralPDE.QuadratureTraining()
+    strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)

     discretization = NeuralPDE.PhysicsInformedNN(chains[i], strategy;
         init_params = init_params[i])
@@ -243,10 +243,10 @@ callback = function (p, l)
 end

 prob_ = NeuralPDE.neural_adapter(losses, init_params2, pde_system_map,
-    NeuralPDE.QuadratureTraining())
+    NeuralPDE.QuadratureTraining(; reltol = 1e-6))
 res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)
 prob_ = NeuralPDE.neural_adapter(losses, res_.u, pde_system_map,
-    NeuralPDE.QuadratureTraining())
+    NeuralPDE.QuadratureTraining(; reltol = 1e-6))
 res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)

 phi_ = PhysicsInformedNN(chain2, strategy; init_params = res_.u).phi
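Reviewer note: the tutorial change above swaps the bare `NeuralPDE.QuadratureTraining()` calls for an explicit relative tolerance. A minimal sketch of the keyword form, mirroring the calls in the diff (only `reltol` is overridden; every other option falls back to the constructor defaults):

```julia
using NeuralPDE

# Tighten the relative tolerance of the adaptive quadrature that integrates
# the PDE residual over the domain; abstol, maxiters, and batch keep the
# QuadratureTraining defaults.
strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)
```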
diff --git a/src/BPINN_ode.jl b/src/BPINN_ode.jl
index 3bbf1afea8..087b9d41cc 100644
--- a/src/BPINN_ode.jl
+++ b/src/BPINN_ode.jl
@@ -113,7 +113,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
         targetacceptancerate = 0.8),
         Integratorkwargs = (Integrator = Leapfrog,),
         autodiff = false, progress = false, verbose = false)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     BNNODE(chain, Kernel, strategy, draw_samples, priorsNNw, param, l2std, phystd, dataset,
         physdt, MCMCkwargs,
diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index d367bf8b6c..2ba1de25b2 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -30,6 +30,7 @@ using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, righte
 using SciMLBase: @add_kwonly, parameterless_type
 using UnPack: @unpack
 import ChainRulesCore, Lux, ComponentArrays
+using Lux: FromFluxAdaptor
 using ChainRulesCore: @non_differentiable

 RuntimeGeneratedFunctions.init(@__MODULE__)
diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl
index 6f30149257..c86c87599c 100644
--- a/src/advancedHMC_MCMC.jl
+++ b/src/advancedHMC_MCMC.jl
@@ -439,7 +439,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
         MCMCkwargs = (n_leapfrog = 30,),
         progress = false, verbose = false)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     # NN parameter prior mean and variance(PriorsNN must be a tuple)
     if isinplace(prob)
         throw(error("The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))
diff --git a/src/dae_solve.jl b/src/dae_solve.jl
index 3f6bf8f0fb..0c9d1323de 100644
--- a/src/dae_solve.jl
+++ b/src/dae_solve.jl
@@ -42,7 +42,7 @@ end
 function NNDAE(chain, opt, init_params = nothing; strategy = nothing,
         autodiff = false, kwargs...)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
 end
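Reviewer note: every solver constructor above now converts legacy `Flux.Chain` models through Lux's `FromFluxAdaptor` rather than the deprecated `Lux.transform`. A self-contained sketch of the conversion, assuming `Adapt`, `Flux`, and `Lux` are loaded (my reading of the two positional flags is that they disable preserving the Flux parameters/states and the forced-preserve fallback, so the resulting Lux layer is freshly initialized):

```julia
using Adapt, Flux, Lux
using Lux: FromFluxAdaptor

fluxchain = Flux.Chain(Flux.Dense(1, 5, Flux.σ), Flux.Dense(5, 1))

# Same call the solvers make internally: turn the Flux model into a native
# Lux layer without carrying over the Flux parameters.
luxchain = adapt(FromFluxAdaptor(false, false), fluxchain)
@assert luxchain isa Lux.AbstractExplicitLayer
```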
diff --git a/src/ode_solve.jl b/src/ode_solve.jl
index f93183d76f..8af6a708d9 100644
--- a/src/ode_solve.jl
+++ b/src/ode_solve.jl
@@ -15,7 +15,7 @@ of the physics-informed neural network which is used as a solver for a standard
 ## Positional Arguments

 * `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`.
-  `Flux.Chain` will be converted to `Lux` using `Lux.transform`.
+  `Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`.
 * `opt`: The optimizer to train the neural network.
 * `init_params`: The initial parameter of the neural network. By default, this is `nothing`
   which thus uses the random initialization provided by the neural network library.
@@ -27,11 +27,10 @@ of the physics-informed neural network which is used as a solver for a standard
   the PDE operators. The reverse mode of the loss function is always automatic differentiation
   (via Zygote), this is only for the derivative in the loss function (the derivative with
   respect to time).
-* `batch`: The batch size to use for the internal quadrature. Defaults to `0`, which
-  means the application of the neural network is done at individual time points one
-  at a time. `batch>0` means the neural network is applied at a row vector of values
-  `t` simultaneously, i.e. it's the batch size for the neural network evaluations.
-  This requires a neural network compatible with batched data.
+* `batch`: Whether to evaluate the loss over all training points at once. Defaults to `true`,
+  which means the neural network is applied at a row vector of values `t` simultaneously;
+  this requires a neural network compatible with batched data. `false` means the neural
+  network is applied at individual time points, one at a time. This does not apply to
+  `QuadratureTraining`, where `batch` is instead part of the `strategy` and sets the number
+  of points at which the integrand can be evaluated in parallel.
 * `param_estim`: Boolean to indicate whether parameters of the differential equations are
   learnt along with parameters of the neural network.
 * `strategy`: The training strategy used to choose the points for the evaluations. Default
   of `nothing` means that `QuadratureTraining` with QuadGK is used if no
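Reviewer note: to make the revised `batch` semantics concrete, a hedged sketch of the two modes described in the docstring hunk above (the chain, optimizer, and strategy are illustrative placeholders, not values from this PR):

```julia
using NeuralPDE, Lux, OptimizationOptimisers

chain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
opt = OptimizationOptimisers.Adam(0.1)

# batch = true (the new default): all sampled time points are fed to the
# network as one row vector, so the residual is computed in a single
# batched forward pass.
alg_batched = NNODE(chain, opt; batch = true, strategy = StochasticTraining(100))

# batch = false: the network is evaluated at one time point at a time.
alg_pointwise = NNODE(chain, opt; batch = false, strategy = StochasticTraining(100))
```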
""" @@ -229,14 +224,13 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim)) integrand(ts, θ) = [abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim)) for t in ts] - @assert batch == 0 # not implemented function loss(θ, _) - intprob = IntegralProblem(integrand, (tspan[1], tspan[2]), θ) - sol = solve(intprob, QuadGKJL(); abstol = strategy.abstol, reltol = strategy.reltol) + intf = BatchIntegralFunction(integrand, max_batch = strategy.batch) + intprob = IntegralProblem(intf, (tspan[1], tspan[2]), θ) + sol = solve(intprob, strategy.quadrature_alg; abstol = strategy.abstol, reltol = strategy.reltol, maxiters = strategy.maxiters) sol.u end - return loss end @@ -395,16 +389,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem, alg.strategy end - batch = if alg.batch === nothing - if strategy isa QuadratureTraining - strategy.batch - else - true - end - else - alg.batch - end - + batch = alg.batch inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch, param_estim) additional_loss = alg.additional_loss (param_estim && isnothing(additional_loss)) && throw(ArgumentError("Please provide `additional_loss` in `NNODE` for parameter estimation (`param_estim` is true).")) diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 69116c4da3..50f7649dc6 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -48,7 +48,7 @@ methodology. * `chain`: a vector of Lux/Flux chains with a d-dimensional input and a 1-dimensional output corresponding to each of the dependent variables. Note that this specification respects the order of the dependent variables as specified in the PDESystem. - Flux chains will be converted to Lux internally using `Lux.transform`. + Flux chains will be converted to Lux internally using `adapt(FromFluxAdaptor(false, false), chain)`. * `strategy`: determines which training strategy will be used. See the Training Strategy documentation for more details. 
diff --git a/src/pinn_types.jl b/src/pinn_types.jl
index 69116c4da3..50f7649dc6 100644
--- a/src/pinn_types.jl
+++ b/src/pinn_types.jl
@@ -48,7 +48,7 @@ methodology.
 * `chain`: a vector of Lux/Flux chains with a d-dimensional input and a 1-dimensional output
   corresponding to each of the dependent variables. Note that this specification respects the
   order of the dependent variables as specified in the PDESystem.
-  Flux chains will be converted to Lux internally using `Lux.transform`.
+  Flux chains will be converted to Lux internally using `adapt(FromFluxAdaptor(false, false), chain)`.
 * `strategy`: determines which training strategy will be used. See the Training Strategy
   documentation for more details.
@@ -107,7 +107,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN
         if multioutput
             !all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain))
         else
-            !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+            !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
         end
         if phi === nothing
             if multioutput
@@ -243,7 +243,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN
         if multioutput
             !all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain))
         else
-            !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+            !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
         end
         if phi === nothing
             if multioutput
diff --git a/src/training_strategies.jl b/src/training_strategies.jl
index 1db8780941..c997e6c4cc 100644
--- a/src/training_strategies.jl
+++ b/src/training_strategies.jl
@@ -272,7 +272,7 @@ struct QuadratureTraining{Q <: SciMLBase.AbstractIntegralAlgorithm, T} <:
     batch::Int64
 end

-function QuadratureTraining(; quadrature_alg = CubatureJLh(), reltol = 1e-6, abstol = 1e-3,
+function QuadratureTraining(; quadrature_alg = CubatureJLh(), reltol = 1e-3, abstol = 1e-6,
         maxiters = 1_000, batch = 100)
     QuadratureTraining(quadrature_alg, reltol, abstol, maxiters, batch)
 end
@@ -306,11 +306,7 @@ function get_loss_function(loss_function, lb, ub, eltypeθ, strategy::Quadrature
     end
     area = eltypeθ(prod(abs.(ub .- lb)))
     f_ = (lb, ub, loss_, θ) -> begin
-        # last_x = 1
         function integrand(x, θ)
-            # last_x = x
-            # mean(abs2,loss_(x,θ), dims=2)
-            # size_x = fill(size(x)[2],(1,1))
             x = adapt(parameterless_type(ComponentArrays.getdata(θ)), x)
             sum(abs2, view(loss_(x, θ), 1, :), dims = 2) #./ size_x
         end
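Reviewer note: the hunk above swaps the default tolerances (`reltol` was `1e-6`/`abstol` was `1e-3`; they are now `1e-3`/`1e-6`). For reference, the revised defaults written out in full, assuming `Integrals` and `Cubature` are loaded so `CubatureJLh` is available:

```julia
using NeuralPDE, Integrals, Cubature

# Equivalent to QuadratureTraining() after this change: looser relative
# tolerance, tighter absolute tolerance than before.
strategy = QuadratureTraining(; quadrature_alg = CubatureJLh(),
    reltol = 1e-3,     # relative tolerance of the adaptive quadrature
    abstol = 1e-6,     # absolute tolerance
    maxiters = 1_000,  # maximum quadrature refinement iterations
    batch = 100)       # max points per batched integrand evaluation
```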
diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl
index 122adcceb3..1e2ba3c055 100644
--- a/test/NNODE_tests.jl
+++ b/test/NNODE_tests.jl
@@ -9,6 +9,7 @@ Random.seed!(100)

 @testset "Scalar" begin
     # Run a solve on scalars
+    println("Scalar")
     linear = (u, p, t) -> cos(2pi * t)
     tspan = (0.0f0, 1.0f0)
     u0 = 0.0f0
@@ -16,26 +17,27 @@
     luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
     opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))

-    sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = false,
         abstol = 1.0f-10, maxiters = 200)

     @test_throws ArgumentError solve(prob, NNODE(luxchain, opt; autodiff = true), dt = 1 / 20.0f0,
-        verbose = true, abstol = 1.0f-10, maxiters = 200)
+        verbose = false, abstol = 1.0f-10, maxiters = 200)

-    sol = solve(prob, NNODE(luxchain, opt), verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt), verbose = false,
         abstol = 1.0f-6, maxiters = 200)

     opt = OptimizationOptimJL.BFGS()
-    sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = false,
         abstol = 1.0f-10, maxiters = 200)

-    sol = solve(prob, NNODE(luxchain, opt), verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt), verbose = false,
         abstol = 1.0f-6, maxiters = 200)
 end

 @testset "Vector" begin
     # Run a solve on vectors
+    println("Vector")
     linear = (u, p, t) -> [cos(2pi * t)]
     tspan = (0.0f0, 1.0f0)
     u0 = [0.0f0]
@@ -44,14 +46,14 @@ end
     opt = OptimizationOptimJL.BFGS()

     sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, abstol = 1e-10,
-        verbose = true, maxiters = 200)
+        verbose = false, maxiters = 200)

     @test_throws ArgumentError solve(prob, NNODE(luxchain, opt; autodiff = true), dt = 1 / 20.0f0,
-        abstol = 1e-10, verbose = true, maxiters = 200)
+        abstol = 1e-10, verbose = false, maxiters = 200)

     sol = solve(prob, NNODE(luxchain, opt), abstol = 1.0f-6,
-        verbose = true, maxiters = 200)
+        verbose = false, maxiters = 200)

     @test sol(0.5) isa Vector
     @test sol(0.5; idxs = 1) isa Number
 end
 end

@@ -59,6 +61,7 @@
 @testset "Example 1" begin
+    println("Example 1")
     linear = (u, p, t) -> @. t^3 + 2 * t + (t^2) * ((1 + 3 * (t^2)) / (1 + t + (t^3))) -
                              u * (t + ((1 + 3 * (t^2)) / (1 + t + t^3)))
     linear_analytic = (u0, p, t) -> [exp(-(t^2) / 2) / (1 + t + t^3) + t^2]
@@ -66,75 +69,70 @@
     luxchain = Lux.Chain(Lux.Dense(1, 128, Lux.σ), Lux.Dense(128, 1))
     opt = OptimizationOptimisers.Adam(0.01)

-    sol = solve(prob, NNODE(luxchain, opt), verbose = true, maxiters = 400)
+    sol = solve(prob, NNODE(luxchain, opt), verbose = false, maxiters = 400)
     @test sol.errors[:l2] < 0.5

-    @test_throws AssertionError solve(prob, NNODE(luxchain, opt; batch = true), verbose = true,
-        maxiters = 400)
-
     sol = solve(prob, NNODE(luxchain, opt; batch = false, strategy = StochasticTraining(100)),
-        verbose = true, maxiters = 400)
+        verbose = false, maxiters = 400)
     @test sol.errors[:l2] < 0.5

     sol = solve(prob, NNODE(luxchain, opt; batch = true, strategy = StochasticTraining(100)),
-        verbose = true, maxiters = 400)
+        verbose = false, maxiters = 400)
     @test sol.errors[:l2] < 0.5

-    sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = false,
         maxiters = 400, dt = 1 / 5.0f0)
     @test sol.errors[:l2] < 0.5

-    sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = false,
         maxiters = 400, dt = 1 / 5.0f0)
     @test sol.errors[:l2] < 0.5
 end

 @testset "Example 2" begin
+    println("Example 2")
     linear = (u, p, t) -> -u / 5 + exp(-t / 5) .* cos(t)
     linear_analytic = (u0, p, t) -> exp(-t / 5) * (u0 + sin(t))
     prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), 0.0f0, (0.0f0, 1.0f0))
     luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
     opt = OptimizationOptimisers.Adam(0.1)
-    sol = solve(prob, NNODE(luxchain, opt), verbose = true, maxiters = 400,
+    sol = solve(prob, NNODE(luxchain, opt), verbose = false, maxiters = 400,
         abstol = 1.0f-8)
     @test sol.errors[:l2] < 0.5

-    @test_throws AssertionError solve(prob, NNODE(luxchain, opt; batch = true), verbose = true,
-        maxiters = 400,
-        abstol = 1.0f-8)
-
     sol = solve(prob, NNODE(luxchain, opt; batch = false, strategy = StochasticTraining(100)),
-        verbose = true, maxiters = 400,
+        verbose = false, maxiters = 400,
         abstol = 1.0f-8)
     @test sol.errors[:l2] < 0.5

     sol = solve(prob, NNODE(luxchain, opt; batch = true, strategy = StochasticTraining(100)),
-        verbose = true, maxiters = 400,
+        verbose = false, maxiters = 400,
         abstol = 1.0f-8)
     @test sol.errors[:l2] < 0.5

-    sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = false,
         maxiters = 400,
         abstol = 1.0f-8, dt = 1 / 5.0f0)
     @test sol.errors[:l2] < 0.5

-    sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = false,
         maxiters = 400,
         abstol = 1.0f-8, dt = 1 / 5.0f0)
     @test sol.errors[:l2] < 0.5
 end

 @testset "Example 3" begin
+    println("Example 3")
     linear = (u, p, t) -> [cos(2pi * t), sin(2pi * t)]
     tspan = (0.0f0, 1.0f0)
     u0 = [0.0f0, -1.0f0 / 2pi]
@@ -146,13 +144,14 @@ end
     alg = NNODE(luxchain, opt; autodiff = false)

     sol = solve(prob,
-        alg, verbose = true, dt = 1 / 40.0f0,
+        alg, verbose = false, dt = 1 / 40.0f0,
         maxiters = 2000, abstol = 1.0f-7)
     @test sol.errors[:l2] < 0.5
 end

 @testset "Training Strategies" begin
     @testset "WeightedIntervalTraining" begin
+        println("WeightedIntervalTraining")
         function f(u, p, t)
             [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]]
         end
@@ -169,7 +168,7 @@ end
         points = 200
         alg = NNODE(chain, opt, autodiff = false,
             strategy = NeuralPDE.WeightedIntervalTraining(weights, points))
-        sol = solve(prob_oop, alg, verbose = true, maxiters = 100000, saveat = 0.01)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = 5000, saveat = 0.01)
         @test abs(mean(sol) - mean(true_sol)) < 0.2
     end
@@ -183,6 +182,7 @@ end
     u_analytical(x) = (1 / (2pi)) .* sin.(2pi .* x)

     @testset "GridTraining" begin
+        println("GridTraining")
         luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
         (u_, t_) = (u_analytical(ts), ts)
         function additional_loss(phi, θ)
@@ -190,22 +190,24 @@ end
         end
         alg1 = NNODE(luxchain, opt, strategy = GridTraining(0.01),
             additional_loss = additional_loss)
-        sol1 = solve(prob, alg1, verbose = true, abstol = 1e-8, maxiters = 500)
+        sol1 = solve(prob, alg1, verbose = false, abstol = 1e-8, maxiters = 500)
         @test sol1.errors[:l2] < 0.5
     end

     @testset "QuadratureTraining" begin
+        println("QuadratureTraining")
         luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
         (u_, t_) = (u_analytical(ts), ts)
         function additional_loss(phi, θ)
             return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
         end
         alg1 = NNODE(luxchain, opt, additional_loss = additional_loss)
-        sol1 = solve(prob, alg1, verbose = true, abstol = 1e-10, maxiters = 200)
+        sol1 = solve(prob, alg1, verbose = false, abstol = 1e-10, maxiters = 200)
         @test sol1.errors[:l2] < 0.5
     end

     @testset "StochasticTraining" begin
+        println("StochasticTraining")
         luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
         (u_, t_) = (u_analytical(ts), ts)
         function additional_loss(phi, θ)
@@ -213,12 +215,13 @@ end
         end
         alg1 = NNODE(luxchain, opt, strategy = StochasticTraining(1000),
             additional_loss = additional_loss)
-        sol1 = solve(prob, alg1, verbose = true, abstol = 1e-8, maxiters = 500)
+        sol1 = solve(prob, alg1, verbose = false, abstol = 1e-8, maxiters = 500)
         @test sol1.errors[:l2] < 0.5
     end
 end

 @testset "Parameter Estimation" begin
+    println("Parameter Estimation")
     function lorenz(u, p, t)
         return [p[1]*(u[2]-u[1]),
             u[1]*(p[2]-u[3])-u[2],
@@ -240,14 +243,15 @@ end
         Lux.Dense(n, n, Lux.σ),
         Lux.Dense(n, 3)
     )
-    opt = OptimizationOptimJL.LBFGS(linesearch = BackTracking())
+    opt = OptimizationOptimJL.BFGS(linesearch = BackTracking())
     alg = NNODE(luxchain, opt, strategy = GridTraining(0.01), param_estim = true,
         additional_loss = additional_loss)
-    sol = solve(prob, alg, verbose = true, abstol = 1e-8, maxiters = 5000, saveat = t_)
+    sol = solve(prob, alg, verbose = false, abstol = 1e-8, maxiters = 1000, saveat = t_)
     @test sol.k.u.p≈true_p atol=1e-2
     @test reduce(hcat, sol.u)≈u_ atol=1e-2
 end

 @testset "Translating from Flux" begin
+    println("Translating from Flux")
     linear = (u, p, t) -> cos(2pi * t)
     linear_analytic = (u, p, t) -> (1 / (2pi)) * sin(2pi * t)
     tspan = (0.0, 1.0)
@@ -259,6 +263,6 @@ end
     fluxchain = Flux.Chain(Flux.Dense(1, 5, Flux.σ), Flux.Dense(5, 1))
     alg1 = NNODE(fluxchain, opt)
     @test alg1.chain isa Lux.AbstractExplicitLayer
-    sol1 = solve(prob, alg1, verbose = true, abstol = 1e-10, maxiters = 200)
+    sol1 = solve(prob, alg1, verbose = false, abstol = 1e-10, maxiters = 200)
     @test sol1.errors[:l2] < 0.5
 end
diff --git a/test/NNODE_tstops_test.jl b/test/NNODE_tstops_test.jl
index c0f8422a09..bc4a4b08d6 100644
--- a/test/NNODE_tstops_test.jl
+++ b/test/NNODE_tstops_test.jl
@@ -31,46 +31,55 @@ points = 3
 dx = 1.0

 @testset "GridTraining" begin
+    println("GridTraining")
     @testset "Without added points" begin
+        println("Without added points")
         # (difference between solutions should be high)
         alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx))
-        sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat)
         @test abs(mean(sol) - mean(true_sol)) > threshold
     end
     @testset "With added points" begin
+        println("With added points")
         # (difference between solutions should be low)
         alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx))
-        sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
         @test abs(mean(sol) - mean(true_sol)) < threshold
     end
 end

 @testset "WeightedIntervalTraining" begin
+    println("WeightedIntervalTraining")
     @testset "Without added points" begin
+        println("Without added points")
         # (difference between solutions should be high)
         alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points))
-        sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat)
         @test abs(mean(sol) - mean(true_sol)) > threshold
     end
     @testset "With added points" begin
+        println("With added points")
         # (difference between solutions should be low)
         alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points))
-        sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
         @test abs(mean(sol) - mean(true_sol)) < threshold
     end
 end

 @testset "StochasticTraining" begin
+    println("StochasticTraining")
     @testset "Without added points" begin
+        println("Without added points")
         # (difference between solutions should be high)
         alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points))
-        sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat)
         @test abs(mean(sol) - mean(true_sol)) > threshold
     end
     @testset "With added points" begin
+        println("With added points")
         # (difference between solutions should be low)
         alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points))
-        sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
         @test abs(mean(sol) - mean(true_sol)) < threshold
     end
 end
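Reviewer note: the tstops tests above exercise NNODE's ability to take extra training points at solve time. A minimal hedged sketch of that pattern (the problem, network, and point values here are illustrative, not taken from the test file):

```julia
using NeuralPDE, Lux, OptimizationOptimisers, SciMLBase

f(u, p, t) = cos(2pi * t)
prob = ODEProblem(f, 0.0, (0.0, 1.0))
chain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
alg = NNODE(chain, OptimizationOptimisers.Adam(0.1); strategy = GridTraining(0.5))

# tstops appends extra time points to the coarse training grid, which is
# what the "With added points" testsets rely on to drive the error down.
sol = solve(prob, alg; maxiters = 200, saveat = 0.01, tstops = [0.25, 0.75])
```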