diff --git a/Project.toml b/Project.toml
index f02da7c7ca..82d4066745 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,9 +17,12 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
+LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
 MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
@@ -42,23 +45,25 @@ Adapt = "4"
 AdvancedHMC = "0.6.1"
 Aqua = "0.8"
 ArrayInterface = "7.9"
-CUDA = "5.2"
-ChainRulesCore = "1.21"
-ComponentArrays = "0.15.8"
+CUDA = "5.3"
+ChainRulesCore = "1.24"
+ComponentArrays = "0.15.14"
 Cubature = "1.5"
 DiffEqNoiseProcess = "5.20"
 Distributions = "0.25.107"
-DocStringExtensions = "0.9"
+DocStringExtensions = "0.9.3"
 DomainSets = "0.6, 0.7"
 Flux = "0.14.11"
 ForwardDiff = "0.10.36"
-Functors = "0.4.4"
+Functors = "0.4.10"
 Integrals = "4.4"
+KernelAbstractions = "0.9.22"
 LineSearches = "7.2"
 LinearAlgebra = "1"
 LogDensityProblems = "2"
-Lux = "0.5.22"
+Lux = "0.5.57"
 LuxCUDA = "0.3.2"
+LuxDeviceUtils = "0.1.24"
 MCMCChains = "6"
 MethodOfLines = "0.11"
 ModelingToolkit = "9.9"
diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index 1122afc838..17e2df074e 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -32,7 +32,10 @@ using SciMLBase: @add_kwonly, parameterless_type
 using UnPack: @unpack
 import ChainRulesCore, Lux, ComponentArrays
 using Lux: FromFluxAdaptor
-using ChainRulesCore: @non_differentiable
+using ChainRulesCore: @ignore_derivatives
+using LuxDeviceUtils: LuxCUDADevice, LuxCPUDevice, cpu_device
+using LuxCUDA: CuArray, CUDABackend
+using KernelAbstractions: @kernel, @Const, @index
 
 RuntimeGeneratedFunctions.init(@__MODULE__)
diff --git a/src/ode_solve.jl b/src/ode_solve.jl
index 64d7b3ac6c..232360dab5 100644
--- a/src/ode_solve.jl
+++ b/src/ode_solve.jl
@@ -1,7 +1,7 @@
 abstract type NeuralPDEAlgorithm <: SciMLBase.AbstractODEAlgorithm end
 
 """
-    NNODE(chain, opt, init_params = nothing; autodiff = false, batch = 0, additional_loss = nothing, kwargs...)
+    NNODE(chain, opt, init_params = nothing; autodiff = false, batch = true, additional_loss = nothing, kwargs...)
 
 Algorithm for solving ordinary differential equations using a neural network. This is a specialization
 of the physics-informed neural network which is used as a solver for a standard `ODEProblem`.
@@ -21,6 +21,7 @@ of the physics-informed neural network which is used as a solver for a standard
   which thus uses the random initialization provided by the neural network library.
 
 ## Keyword Arguments
+
 * `additional_loss`: A function additional_loss(phi, θ) where phi are the neural network trial solutions, θ are the weights of the neural network(s).
 * `autodiff`: The switch between automatic and numerical differentiation for
@@ -71,7 +72,7 @@ is an accurate interpolation (up to the neural network training result). In addi
 Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks
 for solving ordinary and partial differential equations." IEEE Transactions on Neural Networks 9, no. 5 (1998): 987-1000.
 """
-struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function},
+struct NNODE{C, O, P, B, PE, K, D, AL <: Union{Nothing, Function},
     S <: Union{Nothing, AbstractTrainingStrategy}
 } <:
        NeuralPDEAlgorithm
@@ -83,15 +84,33 @@ struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function},
     strategy::S
     param_estim::PE
     additional_loss::AL
+    device::D
     kwargs::K
 end
 
 function NNODE(chain, opt, init_params = nothing;
         strategy = nothing,
-        autodiff = false, batch = true, param_estim = false, additional_loss = nothing, kwargs...)
+        autodiff = false, batch = true, param_estim = false,
+        additional_loss = nothing, device = cpu_device(), kwargs...)
     !(chain isa Lux.AbstractExplicitLayer) &&
         (chain = adapt(FromFluxAdaptor(false, false), chain))
     NNODE(chain, opt, init_params, autodiff, batch,
-        strategy, param_estim, additional_loss, kwargs)
+        strategy, param_estim, additional_loss, device, kwargs)
+end
+
+@kernel function custom_broadcast!(f, du, @Const(out), @Const(p), @Const(t))
+    i = @index(Global, Linear)
+    @views @inbounds x = f(out[:, i], p, t[i])
+    du[:, i] .= x
+end
+
+gpu_broadcast = custom_broadcast!(CUDABackend())
+
+function get_array_type(::LuxCUDADevice)
+    CuArray
+end
+
+function get_array_type(::LuxCPUDevice)
+    Array
 end
 
 """
@@ -100,53 +119,41 @@ end
 Internal struct, used for representing the ODE solution as a neural network in a form that respects boundary conditions, i.e. `phi(t) = u0 + t*NN(t)`.
 """
-mutable struct ODEPhi{C, T, U, S}
+mutable struct ODEPhi{C, T, U, S, D}
     chain::C
     t0::T
     u0::U
     st::S
-    function ODEPhi(chain::Lux.AbstractExplicitLayer, t::Number, u0, st)
-        new{typeof(chain), typeof(t), typeof(u0), typeof(st)}(chain, t, u0, st)
+    device::D
+    function ODEPhi(chain::Lux.AbstractExplicitLayer, t0::Number, u0, st, device)
+        new{typeof(chain), typeof(t0), typeof(u0), typeof(st), typeof(device)}(
+            chain, t0, u0, st, device)
     end
 end
 
-function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params)
+function generate_phi_θ(
+        chain::Lux.AbstractExplicitLayer, t0, u0, init_params, device, p, param_estim)
     θ, st = Lux.setup(Random.default_rng(), chain)
     isnothing(init_params) && (init_params = θ)
-    ODEPhi(chain, t, u0, st), init_params
-end
-
-function (f::ODEPhi{C, T, U})(t::Number,
-        θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
-    y, st = f.chain(
-        adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), [t]), θ.depvar, f.st)
-    ChainRulesCore.@ignore_derivatives f.st = st
-    f.u0 + (t - f.t0) * first(y)
-end
-
-function (f::ODEPhi{C, T, U})(t::AbstractVector,
-        θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
-    # Batch via data as row vectors
-    y, st = f.chain(
-        adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), t'), θ.depvar, f.st)
-    ChainRulesCore.@ignore_derivatives f.st = st
-    f.u0 .+ (t' .- f.t0) .* y
-end
-
-function (f::ODEPhi{C, T, U})(t::Number, θ) where {C <: Lux.AbstractExplicitLayer, T, U}
-    y, st = f.chain(
-        adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), [t]), θ.depvar, f.st)
-    ChainRulesCore.@ignore_derivatives f.st = st
-    f.u0 .+ (t .- f.t0) .* y
+    array_type = get_array_type(device)
+    init_params = if param_estim
+        ComponentArrays.ComponentArray(;
+            depvar = init_params, p = p)
+    else
+        ComponentArrays.ComponentArray(;
+            depvar = init_params)
+    end
+    u0_ = u0 isa Number ? u0 : array_type(u0)
+    ODEPhi(chain, t0, u0_, st, device), adapt(array_type, init_params)
 end
 
-function (f::ODEPhi{C, T, U})(t::AbstractVector,
-        θ) where {C <: Lux.AbstractExplicitLayer, T, U}
+function (f::ODEPhi{C, T, U})(
+        t::AbstractVector, θ) where {C <: Lux.AbstractExplicitLayer, T, U}
     # Batch via data as row vectors
     y, st = f.chain(
         adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), t'), θ.depvar, f.st)
-    ChainRulesCore.@ignore_derivatives f.st = st
-    f.u0 .+ (t' .- f.t0) .* y
+    @ignore_derivatives f.st = st
+    f.u0 .+ (t .- f.t0)' .* y
 end
 
 """
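The batched `ODEPhi` call above feeds the collocation times to the network as a row vector and returns one column per time point, with the trial form `u0 .+ (t .- t0)' .* NN(t)` enforcing the initial condition by construction. A standalone sketch of that layout (chain size, parameter wrapping, and values are illustrative, not the package API):

```julia
using Lux, Random, ComponentArrays

chain = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 2))
ps, st = Lux.setup(Random.default_rng(), chain)
θ = ComponentArray(; depvar = ComponentArray(ps))

t0, u0 = 0.0, [1.0, 0.5]
ts = collect(range(0.0, 1.0; length = 16))

y, _ = chain(ts', θ.depvar, st)        # 2×16 output: times batched as a row vector
phi_vals = u0 .+ (ts .- t0)' .* y      # trial solution; the network term vanishes at t0
@assert size(phi_vals) == (2, 16)
@assert phi_vals[:, 1] ≈ u0
```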
@@ -190,34 +197,37 @@ Simple L2 inner loss at a time `t` with parameters `θ` of the neural network.
 function inner_loss end
 
 function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
-        p, param_estim::Bool) where {C, T, U <: Number}
-    p_ = param_estim ? θ.p : p
-    sum(abs2, ode_dfdx(phi, t, θ, autodiff) - f(phi(t, θ), p_, t))
+        p, param_estim::Bool) where {C, T, U}
+    array_type = get_array_type(phi.device)
+    p = param_estim ? θ.p : p
+    p = p isa SciMLBase.NullParameters ? p : array_type(p)
+    t = array_type([t])
+    dxdtguess = ode_dfdx(phi, t, θ, autodiff)
+    out = phi(t, θ)
+    fs = rhs(phi.device, f, phi.u0, out, p, t)
+    sum(abs2, dxdtguess .- fs)
 end
 
 function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
-        p, param_estim::Bool) where {C, T, U <: Number}
-    p_ = param_estim ? θ.p : p
+        p, param_estim::Bool) where {C, T, U}
+    array_type = get_array_type(phi.device)
+    t = array_type(t)
+    p = param_estim ? θ.p : p
+    p = p isa SciMLBase.NullParameters ? p : array_type(p)
     out = phi(t, θ)
-    fs = reduce(hcat, [f(out[i], p_, t[i]) for i in axes(out, 2)])
-    dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
+    fs = rhs(phi.device, f, phi.u0, out, p, t)
+    dxdtguess = ode_dfdx(phi, t, θ, autodiff)
     sum(abs2, dxdtguess .- fs) / length(t)
 end
 
-function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
-        p, param_estim::Bool) where {C, T, U}
-    p_ = param_estim ? θ.p : p
-    sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), p_, t))
+function rhs(::LuxCPUDevice, f, u0, out, p, t)
+    u0 isa Number ? reduce(hcat, [f(out[i], p, t[i]) for i in axes(out, 2)]) :
+    reduce(hcat, [f(out[:, i], p, t[i]) for i in axes(out, 2)])
 end
 
-function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
-        p, param_estim::Bool) where {C, T, U}
-    p_ = param_estim ? θ.p : p
-    out = Array(phi(t, θ))
-    arrt = Array(t)
-    fs = reduce(hcat, [f(out[:, i], p_, arrt[i]) for i in 1:size(out, 2)])
-    dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
-    sum(abs2, dxdtguess .- fs) / length(t)
+function rhs(::LuxCUDADevice, f, u0, out, p, t)
+    du = similar(out)
+    # launch one work-item per collocation point and hand back the filled buffer
+    gpu_broadcast(f, du, out, p, t; workgroupsize = 64, ndrange = length(t))
+    return du
 end
 
 """
@@ -323,8 +333,10 @@ struct NNODEInterpolation{T <: ODEPhi, T2}
     phi::T
     θ::T2
 end
-(f::NNODEInterpolation)(t, idxs::Nothing, ::Type{Val{0}}, p, continuity) = f.phi(t, f.θ)
-(f::NNODEInterpolation)(t, idxs, ::Type{Val{0}}, p, continuity) = f.phi(t, f.θ)[idxs]
+function (f::NNODEInterpolation)(t, idxs::Nothing, ::Type{Val{0}}, p, continuity)
+    vec(f.phi([t], f.θ))
+end
+(f::NNODEInterpolation)(t, idxs, ::Type{Val{0}}, p, continuity) = vec(f.phi([t], f.θ))[idxs]
 
 function (f::NNODEInterpolation)(t::Vector, idxs::Nothing, ::Type{Val{0}}, p, continuity)
     out = f.phi(t, f.θ)
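On the CPU path, `inner_loss` now evaluates the trial solution once per batch, builds the right-hand side with `rhs`, and averages the squared residual between the time derivative and `f`. A standalone sketch of that quantity, using a forward-difference derivative in place of `ode_dfdx`; all names are local to the sketch:

```julia
# phi_fn(ts, θ) is any batched trial solution returning one column per time point.
function residual_loss_sketch(phi_fn, f, θ, p, ts)
    out = phi_fn(ts, θ)
    fs = reduce(hcat, [f(out[:, i], p, ts[i]) for i in axes(out, 2)])
    ε = sqrt(eps(eltype(ts)))
    dxdt = (phi_fn(ts .+ ε, θ) .- out) ./ ε      # forward-difference d(phi)/dt
    return sum(abs2, dxdt .- fs) / length(ts)
end
```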
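The CUDA path fills `du` with `gpu_broadcast`, the `custom_broadcast!` kernel instantiated for `CUDABackend()` earlier in this patch, one work-item per collocation point. A minimal sketch of the same launch pattern on the KernelAbstractions CPU backend; the kernel name, toy RHS, and sizes are illustrative and not part of the patch:

```julia
using KernelAbstractions

# Same pattern as `custom_broadcast!`: one work-item per column (time point).
@kernel function columnwise_rhs!(f, du, @Const(out), @Const(p), @Const(t))
    i = @index(Global, Linear)
    @views du[:, i] .= f(out[:, i], p, t[i])
end

f(u, p, t) = -u                                  # toy RHS
out = rand(2, 8)                                 # one column per collocation time
du = similar(out)
t = collect(range(0.0, 1.0; length = 8))

cpu_kernel = columnwise_rhs!(CPU())              # analogous to custom_broadcast!(CUDABackend())
cpu_kernel(f, du, out, nothing, t; ndrange = size(out, 2))
KernelAbstractions.synchronize(CPU())
@assert du ≈ -out
```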
@@ -358,36 +370,25 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
     p = prob.p
     t0 = tspan[1]
     param_estim = alg.param_estim
-
-    #hidden layer
     chain = alg.chain
     opt = alg.opt
     autodiff = alg.autodiff
-
-    #train points generation
     init_params = alg.init_params
+    device = alg.device
 
     !(chain isa Lux.AbstractExplicitLayer) &&
         error("Only Lux.AbstractExplicitLayer neural networks are supported")
-    phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
-    ((eltype(eltype(init_params).types[1]) <: Complex ||
-      eltype(eltype(init_params).types[2]) <: Complex) &&
+    phi, init_params = generate_phi_θ(chain, t0, u0, init_params, device, p, param_estim)
+
+    (eltype(init_params) <: Complex &&
      alg.strategy isa QuadratureTraining) &&
         error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")
 
-    init_params = if alg.param_estim
-        ComponentArrays.ComponentArray(;
-            depvar = ComponentArrays.ComponentArray(init_params), p = prob.p)
-    else
-        ComponentArrays.ComponentArray(;
-            depvar = ComponentArrays.ComponentArray(init_params))
-    end
-
     isinplace(prob) &&
         throw(error("The NNODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))
 
     try
-        phi(t0, init_params)
+        phi(get_array_type(device)([t0]), init_params)
     catch err
         if isa(err, DimensionMismatch)
             throw(DimensionMismatch("Dimensions of the initial u0 and chain should match"))
@@ -473,10 +474,11 @@
         ts = [tspan[1], tspan[2]]
     end
 
+    u = phi(ts, res.u)
     if u0 isa Number
-        u = [first(phi(t, res.u)) for t in ts]
+        u = vec(u)
     else
-        u = [phi(t, res.u) for t in ts]
+        u = [u[:, i] for i in 1:size(u, 2)]
     end
 
     sol = SciMLBase.build_solution(prob, alg, ts, u;
diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl
index 0cd688e310..7fe631ee33 100644
--- a/test/NNODE_tests.jl
+++ b/test/NNODE_tests.jl
@@ -190,7 +190,7 @@ end
     luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
     (u_, t_) = (u_analytical(ts), ts)
     function additional_loss(phi, θ)
-        return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+        return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
     end
     alg1 = NNODE(luxchain, opt, strategy = GridTraining(0.01),
         additional_loss = additional_loss)
@@ -203,7 +203,7 @@ end
     luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
     (u_, t_) = (u_analytical(ts), ts)
     function additional_loss(phi, θ)
-        return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+        return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
     end
     alg1 = NNODE(luxchain, opt, additional_loss = additional_loss)
     sol1 = solve(prob, alg1, verbose = false, abstol = 1e-10, maxiters = 200)
@@ -215,7 +215,7 @@ end
     luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
     (u_, t_) = (u_analytical(ts), ts)
     function additional_loss(phi, θ)
-        return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+        return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
    end
     alg1 = NNODE(luxchain, opt, strategy = StochasticTraining(1000),
         additional_loss = additional_loss)
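Taken together, these changes let a batched `NNODE` solve keep its parameters and collocation points on a chosen device. A hypothetical end-to-end sketch of opting in (problem, chain, optimizer, and training strategy are illustrative; `gpu_device()` comes from LuxDeviceUtils and needs a functional CUDA setup):

```julia
using NeuralPDE, Lux, OptimizationOptimisers
import SciMLBase
using LuxDeviceUtils: gpu_device

f(u, p, t) = -u                                            # du/dt = -u
prob = SciMLBase.ODEProblem(f, [1.0f0, 2.0f0], (0.0f0, 1.0f0))

chain = Lux.Chain(Lux.Dense(1, 16, tanh), Lux.Dense(16, 2))
alg = NNODE(chain, OptimizationOptimisers.Adam(0.01);
    strategy = GridTraining(0.05f0), device = gpu_device())

sol = SciMLBase.solve(prob, alg; verbose = false, maxiters = 2000)
```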