From 2a72fa8c58ac8736559cfcda739738bd357b0df1 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Thu, 21 Mar 2024 06:56:28 +0000 Subject: [PATCH 1/8] refactor: correctly lower quadrature training strategy in NNODE --- src/ode_solve.jl | 35 +++++++++++------------------------ src/training_strategies.jl | 6 +----- 2 files changed, 12 insertions(+), 29 deletions(-) diff --git a/src/ode_solve.jl b/src/ode_solve.jl index f93183d76f..b9c46d3463 100644 --- a/src/ode_solve.jl +++ b/src/ode_solve.jl @@ -27,11 +27,12 @@ of the physics-informed neural network which is used as a solver for a standard the PDE operators. The reverse mode of the loss function is always automatic differentiation (via Zygote), this is only for the derivative in the loss function (the derivative with respect to time). -* `batch`: The batch size to use for the internal quadrature. Defaults to `0`, which +* `batch`: The batch size for the loss computation. Defaults to `false`, which means the application of the neural network is done at individual time points one - at a time. `batch>0` means the neural network is applied at a row vector of values + at a time. `true` means the neural network is applied at a row vector of values `t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data. + This is not applicable to `QuadratureTraining` where `batch` is passed in the `strategy` which is the number of points it can parallelly compute the integrand. * `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network. * `strategy`: The training strategy used to choose the points for the evaluations. Default of `nothing` means that `QuadratureTraining` with QuadGK is used if no @@ -88,7 +89,7 @@ struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function}, end function NNODE(chain, opt, init_params = nothing; strategy = nothing, - autodiff = false, batch = nothing, param_estim = false, additional_loss = nothing, kwargs...) + autodiff = false, batch = false, param_estim = false, additional_loss = nothing, kwargs...) !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain)) NNODE(chain, opt, init_params, autodiff, batch, strategy, param_estim, additional_loss, kwargs) end @@ -111,11 +112,7 @@ end function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params) θ, st = Lux.setup(Random.default_rng(), chain) - if init_params === nothing - init_params = ComponentArrays.ComponentArray(θ) - else - init_params = ComponentArrays.ComponentArray(init_params) - end + isnothing(init_params) && (init_params = θ) ODEPhi(chain, t, u0, st), init_params end @@ -182,7 +179,7 @@ function ode_dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool) end """ - inner_loss(phi, f, autodiff, t, θ, p) + inner_loss(phi, f, autodiff, t, θ, p, param_estim) Simple L2 inner loss at a time `t` with parameters `θ` of the neural network. """ @@ -220,7 +217,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, end """ - generate_loss(strategy, phi, f, autodiff, tspan, p, batch) + generate_loss(strategy, phi, f, autodiff, tspan, p, batch, param_estim) Representation of the loss function, parametric on the training strategy `strategy`. """ @@ -229,14 +226,13 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim)) integrand(ts, θ) = [abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim)) for t in ts] - @assert batch == 0 # not implemented function loss(θ, _) - intprob = IntegralProblem(integrand, (tspan[1], tspan[2]), θ) - sol = solve(intprob, QuadGKJL(); abstol = strategy.abstol, reltol = strategy.reltol) + intf = BatchIntegralFunction(integrand, max_batch = strategy.batch) + intprob = IntegralProblem(intf, (tspan[1], tspan[2]), θ) + sol = solve(intprob, strategy.quadrature_alg; abstol = strategy.abstol, reltol = strategy.reltol, maxiters = strategy.maxiters) sol.u end - return loss end @@ -395,16 +391,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem, alg.strategy end - batch = if alg.batch === nothing - if strategy isa QuadratureTraining - strategy.batch - else - true - end - else - alg.batch - end - + batch = alg.batch inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch, param_estim) additional_loss = alg.additional_loss (param_estim && isnothing(additional_loss)) && throw(ArgumentError("Please provide `additional_loss` in `NNODE` for parameter estimation (`param_estim` is true).")) diff --git a/src/training_strategies.jl b/src/training_strategies.jl index 1db8780941..c997e6c4cc 100644 --- a/src/training_strategies.jl +++ b/src/training_strategies.jl @@ -272,7 +272,7 @@ struct QuadratureTraining{Q <: SciMLBase.AbstractIntegralAlgorithm, T} <: batch::Int64 end -function QuadratureTraining(; quadrature_alg = CubatureJLh(), reltol = 1e-6, abstol = 1e-3, +function QuadratureTraining(; quadrature_alg = CubatureJLh(), reltol = 1e-3, abstol = 1e-6, maxiters = 1_000, batch = 100) QuadratureTraining(quadrature_alg, reltol, abstol, maxiters, batch) end @@ -306,11 +306,7 @@ function get_loss_function(loss_function, lb, ub, eltypeθ, strategy::Quadrature end area = eltypeθ(prod(abs.(ub .- lb))) f_ = (lb, ub, loss_, θ) -> begin - # last_x = 1 function integrand(x, θ) - # last_x = x - # mean(abs2,loss_(x,θ), dims=2) - # size_x = fill(size(x)[2],(1,1)) x = adapt(parameterless_type(ComponentArrays.getdata(θ)), x) sum(abs2, view(loss_(x, θ), 1, :), dims = 2) #./ size_x end From a80cd034e68e3482ff525c82cf1303023cacda59 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Thu, 21 Mar 2024 08:36:15 +0000 Subject: [PATCH 2/8] test: remove tests for assertion error with batch to be true for QuadratureTraining --- test/NNODE_tests.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl index 122adcceb3..b3731f059d 100644 --- a/test/NNODE_tests.jl +++ b/test/NNODE_tests.jl @@ -69,9 +69,6 @@ end sol = solve(prob, NNODE(luxchain, opt), verbose = true, maxiters = 400) @test sol.errors[:l2] < 0.5 - @test_throws AssertionError solve(prob, NNODE(luxchain, opt; batch = true), verbose = true, - maxiters = 400) - sol = solve(prob, NNODE(luxchain, opt; batch = false, strategy = StochasticTraining(100)), @@ -105,10 +102,6 @@ end abstol = 1.0f-8) @test sol.errors[:l2] < 0.5 - @test_throws AssertionError solve(prob, NNODE(luxchain, opt; batch = true), verbose = true, - maxiters = 400, - abstol = 1.0f-8) - sol = solve(prob, NNODE(luxchain, opt; batch = false, strategy = StochasticTraining(100)), From 118102989f25117c6aa919480578526d5422db91 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Thu, 21 Mar 2024 12:19:04 +0000 Subject: [PATCH 3/8] test: make verbose = false for NNODE tests --- test/NNODE_tests.jl | 61 +++++++++++++++++++++++---------------- test/NNODE_tstops_test.jl | 21 ++++++++++---- 2 files changed, 51 insertions(+), 31 deletions(-) diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl index b3731f059d..8475737e05 100644 --- a/test/NNODE_tests.jl +++ b/test/NNODE_tests.jl @@ -9,6 +9,7 @@ Random.seed!(100) @testset "Scalar" begin # Run a solve on scalars + println("Scalar") linear = (u, p, t) -> cos(2pi * t) tspan = (0.0f0, 1.0f0) u0 = 0.0f0 @@ -16,26 +17,27 @@ Random.seed!(100) luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95)) - sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = true, + sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = false, abstol = 1.0f-10, maxiters = 200) @test_throws ArgumentError solve(prob, NNODE(luxchain, opt; autodiff = true), dt = 1 / 20.0f0, - verbose = true, abstol = 1.0f-10, maxiters = 200) + verbose = false, abstol = 1.0f-10, maxiters = 200) - sol = solve(prob, NNODE(luxchain, opt), verbose = true, + sol = solve(prob, NNODE(luxchain, opt), verbose = false, abstol = 1.0f-6, maxiters = 200) opt = OptimizationOptimJL.BFGS() - sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = true, + sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = false, abstol = 1.0f-10, maxiters = 200) - sol = solve(prob, NNODE(luxchain, opt), verbose = true, + sol = solve(prob, NNODE(luxchain, opt), verbose = false, abstol = 1.0f-6, maxiters = 200) end @testset "Vector" begin # Run a solve on vectors + println("Vector") linear = (u, p, t) -> [cos(2pi * t)] tspan = (0.0f0, 1.0f0) u0 = [0.0f0] @@ -44,14 +46,14 @@ end opt = OptimizationOptimJL.BFGS() sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, abstol = 1e-10, - verbose = true, maxiters = 200) + verbose = false, maxiters = 200) @test_throws ArgumentError solve(prob, NNODE(luxchain, opt; autodiff = true), dt = 1 / 20.0f0, - abstol = 1e-10, verbose = true, maxiters = 200) + abstol = 1e-10, verbose = false, maxiters = 200) sol = solve(prob, NNODE(luxchain, opt), abstol = 1.0f-6, - verbose = true, maxiters = 200) + verbose = false, maxiters = 200) @test sol(0.5) isa Vector @test sol(0.5; idxs = 1) isa Number @@ -59,6 +61,7 @@ end end @testset "Example 1" begin + println("Example 1") linear = (u, p, t) -> @. t^3 + 2 * t + (t^2) * ((1 + 3 * (t^2)) / (1 + t + (t^3))) - u * (t + ((1 + 3 * (t^2)) / (1 + t + t^3))) linear_analytic = (u0, p, t) -> [exp(-(t^2) / 2) / (1 + t + t^3) + t^2] @@ -66,68 +69,70 @@ end luxchain = Lux.Chain(Lux.Dense(1, 128, Lux.σ), Lux.Dense(128, 1)) opt = OptimizationOptimisers.Adam(0.01) - sol = solve(prob, NNODE(luxchain, opt), verbose = true, maxiters = 400) + sol = solve(prob, NNODE(luxchain, opt), verbose = false, maxiters = 400) @test sol.errors[:l2] < 0.5 sol = solve(prob, NNODE(luxchain, opt; batch = false, strategy = StochasticTraining(100)), - verbose = true, maxiters = 400) + verbose = false, maxiters = 400) @test sol.errors[:l2] < 0.5 sol = solve(prob, NNODE(luxchain, opt; batch = true, strategy = StochasticTraining(100)), - verbose = true, maxiters = 400) + verbose = false, maxiters = 400) @test sol.errors[:l2] < 0.5 - sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = true, + sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = false, maxiters = 400, dt = 1 / 5.0f0) @test sol.errors[:l2] < 0.5 - sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = true, + sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = false, maxiters = 400, dt = 1 / 5.0f0) @test sol.errors[:l2] < 0.5 end @testset "Example 2" begin + println("Example 2") linear = (u, p, t) -> -u / 5 + exp(-t / 5) .* cos(t) linear_analytic = (u0, p, t) -> exp(-t / 5) * (u0 + sin(t)) prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), 0.0f0, (0.0f0, 1.0f0)) luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) opt = OptimizationOptimisers.Adam(0.1) - sol = solve(prob, NNODE(luxchain, opt), verbose = true, maxiters = 400, + sol = solve(prob, NNODE(luxchain, opt), verbose = false, maxiters = 400, abstol = 1.0f-8) @test sol.errors[:l2] < 0.5 sol = solve(prob, NNODE(luxchain, opt; batch = false, strategy = StochasticTraining(100)), - verbose = true, maxiters = 400, + verbose = false, maxiters = 400, abstol = 1.0f-8) @test sol.errors[:l2] < 0.5 sol = solve(prob, NNODE(luxchain, opt; batch = true, strategy = StochasticTraining(100)), - verbose = true, maxiters = 400, + verbose = false, maxiters = 400, abstol = 1.0f-8) @test sol.errors[:l2] < 0.5 - sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = true, + sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = false, maxiters = 400, abstol = 1.0f-8, dt = 1 / 5.0f0) @test sol.errors[:l2] < 0.5 - sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = true, + sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = false, maxiters = 400, abstol = 1.0f-8, dt = 1 / 5.0f0) @test sol.errors[:l2] < 0.5 end @testset "Example 3" begin + println("Example 3") linear = (u, p, t) -> [cos(2pi * t), sin(2pi * t)] tspan = (0.0f0, 1.0f0) u0 = [0.0f0, -1.0f0 / 2pi] @@ -139,13 +144,14 @@ end alg = NNODE(luxchain, opt; autodiff = false) sol = solve(prob, - alg, verbose = true, dt = 1 / 40.0f0, + alg, verbose = false, dt = 1 / 40.0f0, maxiters = 2000, abstol = 1.0f-7) @test sol.errors[:l2] < 0.5 end @testset "Training Strategies" begin @testset "WeightedIntervalTraining" begin + println("WeightedIntervalTraining") function f(u, p, t) [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]] end @@ -162,7 +168,7 @@ end points = 200 alg = NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, points)) - sol = solve(prob_oop, alg, verbose = true, maxiters = 100000, saveat = 0.01) + sol = solve(prob_oop, alg, verbose = false, maxiters = 5000, saveat = 0.01) @test abs(mean(sol) - mean(true_sol)) < 0.2 end @@ -176,6 +182,7 @@ end u_analytical(x) = (1 / (2pi)) .* sin.(2pi .* x) @testset "GridTraining" begin + println("GridTraining") luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) (u_, t_) = (u_analytical(ts), ts) function additional_loss(phi, θ) @@ -183,22 +190,24 @@ end end alg1 = NNODE(luxchain, opt, strategy = GridTraining(0.01), additional_loss = additional_loss) - sol1 = solve(prob, alg1, verbose = true, abstol = 1e-8, maxiters = 500) + sol1 = solve(prob, alg1, verbose = false, abstol = 1e-8, maxiters = 500) @test sol1.errors[:l2] < 0.5 end @testset "QuadratureTraining" begin + println("QuadratureTraining") luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) (u_, t_) = (u_analytical(ts), ts) function additional_loss(phi, θ) return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_) end alg1 = NNODE(luxchain, opt, additional_loss = additional_loss) - sol1 = solve(prob, alg1, verbose = true, abstol = 1e-10, maxiters = 200) + sol1 = solve(prob, alg1, verbose = false, abstol = 1e-10, maxiters = 200) @test sol1.errors[:l2] < 0.5 end @testset "StochasticTraining" begin + println("StochasticTraining") luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) (u_, t_) = (u_analytical(ts), ts) function additional_loss(phi, θ) @@ -206,12 +215,13 @@ end end alg1 = NNODE(luxchain, opt, strategy = StochasticTraining(1000), additional_loss = additional_loss) - sol1 = solve(prob, alg1, verbose = true, abstol = 1e-8, maxiters = 500) + sol1 = solve(prob, alg1, verbose = false, abstol = 1e-8, maxiters = 500) @test sol1.errors[:l2] < 0.5 end end @testset "Parameter Estimation" begin + println("Parameter Estimation") function lorenz(u, p, t) return [p[1]*(u[2]-u[1]), u[1]*(p[2]-u[3])-u[2], @@ -235,12 +245,13 @@ end ) opt = OptimizationOptimJL.LBFGS(linesearch = BackTracking()) alg = NNODE(luxchain, opt, strategy = GridTraining(0.01), param_estim = true, additional_loss = additional_loss) - sol = solve(prob, alg, verbose = true, abstol = 1e-8, maxiters = 5000, saveat = t_) + sol = solve(prob, alg, verbose = false, abstol = 1e-8, maxiters = 5000, saveat = t_) @test sol.k.u.p≈true_p atol=1e-2 @test reduce(hcat, sol.u)≈u_ atol=1e-2 end @testset "Translating from Flux" begin + println("Translating from Flux") linear = (u, p, t) -> cos(2pi * t) linear_analytic = (u, p, t) -> (1 / (2pi)) * sin(2pi * t) tspan = (0.0, 1.0) @@ -252,6 +263,6 @@ end fluxchain = Flux.Chain(Flux.Dense(1, 5, Flux.σ), Flux.Dense(5, 1)) alg1 = NNODE(fluxchain, opt) @test alg1.chain isa Lux.AbstractExplicitLayer - sol1 = solve(prob, alg1, verbose = true, abstol = 1e-10, maxiters = 200) + sol1 = solve(prob, alg1, verbose = false, abstol = 1e-10, maxiters = 200) @test sol1.errors[:l2] < 0.5 end diff --git a/test/NNODE_tstops_test.jl b/test/NNODE_tstops_test.jl index c0f8422a09..bc4a4b08d6 100644 --- a/test/NNODE_tstops_test.jl +++ b/test/NNODE_tstops_test.jl @@ -31,46 +31,55 @@ points = 3 dx = 1.0 @testset "GridTraining" begin + println("GridTraining") @testset "Without added points" begin + println("Without added points") # (difference between solutions should be high) alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx)) - sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat) + sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat) @test abs(mean(sol) - mean(true_sol)) > threshold end @testset "With added points" begin + println("With added points") # (difference between solutions should be low) alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx)) - sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints) + sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints) @test abs(mean(sol) - mean(true_sol)) < threshold end end @testset "WeightedIntervalTraining" begin + println("WeightedIntervalTraining") @testset "Without added points" begin + println("Without added points") # (difference between solutions should be high) alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points)) - sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat) + sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat) @test abs(mean(sol) - mean(true_sol)) > threshold end @testset "With added points" begin + println("With added points") # (difference between solutions should be low) alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points)) - sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints) + sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints) @test abs(mean(sol) - mean(true_sol)) < threshold end end @testset "StochasticTraining" begin + println("StochasticTraining") @testset "Without added points" begin + println("Without added points") # (difference between solutions should be high) alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points)) - sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat) + sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat) @test abs(mean(sol) - mean(true_sol)) > threshold end @testset "With added points" begin + println("With added points") # (difference between solutions should be low) alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points)) - sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints) + sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints) @test abs(mean(sol) - mean(true_sol)) < threshold end end From c74384caa5fb7a5e360d5c3a98b85a6fa469328c Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Fri, 22 Mar 2024 07:11:24 +0000 Subject: [PATCH 4/8] test: use BFGS instead of LBFGS for parameter estimation in NNODE --- test/NNODE_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl index 8475737e05..1e2ba3c055 100644 --- a/test/NNODE_tests.jl +++ b/test/NNODE_tests.jl @@ -243,9 +243,9 @@ end Lux.Dense(n, n, Lux.σ), Lux.Dense(n, 3) ) - opt = OptimizationOptimJL.LBFGS(linesearch = BackTracking()) + opt = OptimizationOptimJL.BFGS(linesearch = BackTracking()) alg = NNODE(luxchain, opt, strategy = GridTraining(0.01), param_estim = true, additional_loss = additional_loss) - sol = solve(prob, alg, verbose = false, abstol = 1e-8, maxiters = 5000, saveat = t_) + sol = solve(prob, alg, verbose = false, abstol = 1e-8, maxiters = 1000, saveat = t_) @test sol.k.u.p≈true_p atol=1e-2 @test reduce(hcat, sol.u)≈u_ atol=1e-2 end From 4ab36749592e9b32c79d89940ff1457f0880b94f Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Fri, 22 Mar 2024 09:10:30 +0000 Subject: [PATCH 5/8] refactor: use FromFluxAdaptor for converting Flux to Lux as Lux.transform is deprecated --- src/BPINN_ode.jl | 2 +- src/NeuralPDE.jl | 1 + src/advancedHMC_MCMC.jl | 2 +- src/dae_solve.jl | 2 +- src/ode_solve.jl | 4 ++-- src/pinn_types.jl | 6 +++--- 6 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/BPINN_ode.jl b/src/BPINN_ode.jl index 3bbf1afea8..087b9d41cc 100644 --- a/src/BPINN_ode.jl +++ b/src/BPINN_ode.jl @@ -113,7 +113,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000, targetacceptancerate = 0.8), Integratorkwargs = (Integrator = Leapfrog,), autodiff = false, progress = false, verbose = false) - !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain)) + !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain)) BNNODE(chain, Kernel, strategy, draw_samples, priorsNNw, param, l2std, phystd, dataset, physdt, MCMCkwargs, diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index d367bf8b6c..2ba1de25b2 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -30,6 +30,7 @@ using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, righte using SciMLBase: @add_kwonly, parameterless_type using UnPack: @unpack import ChainRulesCore, Lux, ComponentArrays +using Lux: FromFluxAdaptor using ChainRulesCore: @non_differentiable RuntimeGeneratedFunctions.init(@__MODULE__) diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl index 6f30149257..c86c87599c 100644 --- a/src/advancedHMC_MCMC.jl +++ b/src/advancedHMC_MCMC.jl @@ -439,7 +439,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain; MCMCkwargs = (n_leapfrog = 30,), progress = false, verbose = false) - !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain)) + !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain)) # NN parameter prior mean and variance(PriorsNN must be a tuple) if isinplace(prob) throw(error("The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t).")) diff --git a/src/dae_solve.jl b/src/dae_solve.jl index 3f6bf8f0fb..0c9d1323de 100644 --- a/src/dae_solve.jl +++ b/src/dae_solve.jl @@ -42,7 +42,7 @@ end function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false, kwargs...) - !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain)) + !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain)) NNDAE(chain, opt, init_params, autodiff, strategy, kwargs) end diff --git a/src/ode_solve.jl b/src/ode_solve.jl index b9c46d3463..ef57beabe8 100644 --- a/src/ode_solve.jl +++ b/src/ode_solve.jl @@ -15,7 +15,7 @@ of the physics-informed neural network which is used as a solver for a standard ## Positional Arguments * `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`. - `Flux.Chain` will be converted to `Lux` using `Lux.transform`. + `Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`. * `opt`: The optimizer to train the neural network. * `init_params`: The initial parameter of the neural network. By default, this is `nothing` which thus uses the random initialization provided by the neural network library. @@ -90,7 +90,7 @@ end function NNODE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false, batch = false, param_estim = false, additional_loss = nothing, kwargs...) - !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain)) + !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain)) NNODE(chain, opt, init_params, autodiff, batch, strategy, param_estim, additional_loss, kwargs) end diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 69116c4da3..50f7649dc6 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -48,7 +48,7 @@ methodology. * `chain`: a vector of Lux/Flux chains with a d-dimensional input and a 1-dimensional output corresponding to each of the dependent variables. Note that this specification respects the order of the dependent variables as specified in the PDESystem. - Flux chains will be converted to Lux internally using `Lux.transform`. + Flux chains will be converted to Lux internally using `adapt(FromFluxAdaptor(false, false), chain)`. * `strategy`: determines which training strategy will be used. See the Training Strategy documentation for more details. @@ -107,7 +107,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN if multioutput !all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain)) else - !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain)) + !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain)) end if phi === nothing if multioutput @@ -243,7 +243,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN if multioutput !all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain)) else - !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain)) + !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain)) end if phi === nothing if multioutput From 6c94adff6203504f8d7d5c63724ae31645cda660 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Fri, 22 Mar 2024 09:28:12 +0000 Subject: [PATCH 6/8] docs: adjust tolerance in QuadratureTraining in neural adapter example --- docs/src/tutorials/neural_adapter.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/tutorials/neural_adapter.md b/docs/src/tutorials/neural_adapter.md index a56e30a269..93f0dd036f 100644 --- a/docs/src/tutorials/neural_adapter.md +++ b/docs/src/tutorials/neural_adapter.md @@ -69,7 +69,7 @@ function loss(cord, θ) ch2 .- phi(cord, res.u) end -strategy = NeuralPDE.QuadratureTraining() +strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6) prob_ = NeuralPDE.neural_adapter(loss, init_params2, pde_system, strategy) res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 10000) @@ -173,7 +173,7 @@ for i in 1:count_decomp bcs_ = create_bcs(domains_[1].domain, phi_bound) @named pde_system_ = PDESystem(eq, bcs_, domains_, [x, y], [u(x, y)]) push!(pde_system_map, pde_system_) - strategy = NeuralPDE.QuadratureTraining() + strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6) discretization = NeuralPDE.PhysicsInformedNN(chains[i], strategy; init_params = init_params[i]) @@ -243,10 +243,10 @@ callback = function (p, l) end prob_ = NeuralPDE.neural_adapter(losses, init_params2, pde_system_map, - NeuralPDE.QuadratureTraining()) + NeuralPDE.QuadratureTraining(; reltol = 1e-6)) res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000) prob_ = NeuralPDE.neural_adapter(losses, res_.u, pde_system_map, - NeuralPDE.QuadratureTraining()) + NeuralPDE.QuadratureTraining(; reltol = 1e-6)) res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000) phi_ = PhysicsInformedNN(chain2, strategy; init_params = res_.u).phi From 7f3cc74adfa811f1ff9a1d30bd816d92b9ab732c Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Fri, 22 Mar 2024 10:30:19 +0000 Subject: [PATCH 7/8] build: bump lower bound of Lux --- Project.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index d76565ff25..45468985c5 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ AdvancedHMC = "0.6.1" Aqua = "0.8" ArrayInterface = "7.7" CUDA = "5.2" -ChainRulesCore = "1.18" +ChainRulesCore = "1.21" ComponentArrays = "0.15.8" Cubature = "1.5" DiffEqBase = "6.144" @@ -59,7 +59,7 @@ Integrals = "4" LineSearches = "7.2" LinearAlgebra = "1" LogDensityProblems = "2" -Lux = "0.5.14" +Lux = "0.5.22" LuxCUDA = "0.3.2" MCMCChains = "6" MethodOfLines = "0.10.7" @@ -82,7 +82,7 @@ SymbolicUtils = "1.4" Symbolics = "5.17" Test = "1" UnPack = "1" -Zygote = "0.6.68" +Zygote = "0.6.69" julia = "1.10" [extras] @@ -91,12 +91,12 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4" [targets] test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux", "MethodOfLines"] From 7e3de9879de54292ac4a0ec8f41375104303120b Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Fri, 22 Mar 2024 11:15:10 +0000 Subject: [PATCH 8/8] refactor: make batch to be true by default for NNODE --- src/ode_solve.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/ode_solve.jl b/src/ode_solve.jl index ef57beabe8..8af6a708d9 100644 --- a/src/ode_solve.jl +++ b/src/ode_solve.jl @@ -27,11 +27,9 @@ of the physics-informed neural network which is used as a solver for a standard the PDE operators. The reverse mode of the loss function is always automatic differentiation (via Zygote), this is only for the derivative in the loss function (the derivative with respect to time). -* `batch`: The batch size for the loss computation. Defaults to `false`, which - means the application of the neural network is done at individual time points one - at a time. `true` means the neural network is applied at a row vector of values - `t` simultaneously, i.e. it's the batch size for the neural network evaluations. - This requires a neural network compatible with batched data. +* `batch`: The batch size for the loss computation. Defaults to `true`, means the neural network is applied at a row vector of values + `t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data. + `false` means which means the application of the neural network is done at individual time points one at a time. This is not applicable to `QuadratureTraining` where `batch` is passed in the `strategy` which is the number of points it can parallelly compute the integrand. * `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network. * `strategy`: The training strategy used to choose the points for the evaluations. @@ -89,7 +87,7 @@ struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function}, end function NNODE(chain, opt, init_params = nothing; strategy = nothing, - autodiff = false, batch = false, param_estim = false, additional_loss = nothing, kwargs...) + autodiff = false, batch = true, param_estim = false, additional_loss = nothing, kwargs...) !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain)) NNODE(chain, opt, init_params, autodiff, batch, strategy, param_estim, additional_loss, kwargs) end