[WIP] feat: compatibility of NNODE with CUDA #866

Closed
wants to merge 5 commits
17 changes: 11 additions & 6 deletions Project.toml
@@ -17,9 +17,12 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
@@ -42,23 +45,25 @@ Adapt = "4"
AdvancedHMC = "0.6.1"
Aqua = "0.8"
ArrayInterface = "7.9"
CUDA = "5.2"
ChainRulesCore = "1.21"
ComponentArrays = "0.15.8"
CUDA = "5.3"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.14"
Cubature = "1.5"
DiffEqNoiseProcess = "5.20"
Distributions = "0.25.107"
DocStringExtensions = "0.9"
DocStringExtensions = "0.9.3"
DomainSets = "0.6, 0.7"
Flux = "0.14.11"
ForwardDiff = "0.10.36"
Functors = "0.4.4"
Functors = "0.4.10"
Integrals = "4.4"
KernelAbstractions = "0.9.22"
LineSearches = "7.2"
LinearAlgebra = "1"
LogDensityProblems = "2"
Lux = "0.5.22"
Lux = "0.5.57"
LuxCUDA = "0.3.2"
LuxDeviceUtils = "0.1.24"
MCMCChains = "6"
MethodOfLines = "0.11"
ModelingToolkit = "9.9"
5 changes: 4 additions & 1 deletion src/NeuralPDE.jl
@@ -32,7 +32,10 @@ using SciMLBase: @add_kwonly, parameterless_type
using UnPack: @unpack
import ChainRulesCore, Lux, ComponentArrays
using Lux: FromFluxAdaptor
using ChainRulesCore: @non_differentiable
using ChainRulesCore: @ignore_derivatives
using LuxDeviceUtils: LuxCUDADevice, LuxCPUDevice, cpu_device
using LuxCUDA: CuArray, CUDABackend
using KernelAbstractions: @kernel, @Const, @index

RuntimeGeneratedFunctions.init(@__MODULE__)

154 changes: 78 additions & 76 deletions src/ode_solve.jl
@@ -1,7 +1,7 @@
abstract type NeuralPDEAlgorithm <: SciMLBase.AbstractODEAlgorithm end

"""
NNODE(chain, opt, init_params = nothing; autodiff = false, batch = 0, additional_loss = nothing, kwargs...)
NNODE(chain, opt, init_params = nothing; autodiff = false, batch = true, additional_loss = nothing, kwargs...)

Algorithm for solving ordinary differential equations using a neural network. This is a specialization
of the physics-informed neural network which is used as a solver for a standard `ODEProblem`.
@@ -21,6 +21,7 @@ of the physics-informed neural network which is used as a solver for a standard
which thus uses the random initialization provided by the neural network library.

## Keyword Arguments

* `additional_loss`: A function additional_loss(phi, θ) where phi are the neural network trial solutions,
θ are the weights of the neural network(s).
* `autodiff`: The switch between automatic and numerical differentiation for
@@ -71,7 +72,7 @@ is an accurate interpolation (up to the neural network training result). In addi
Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks for solving
ordinary and partial differential equations." IEEE Transactions on Neural Networks 9, no. 5 (1998): 987-1000.
"""
struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function},
struct NNODE{C, O, P, B, PE, K, D, AL <: Union{Nothing, Function},
S <: Union{Nothing, AbstractTrainingStrategy}
} <:
NeuralPDEAlgorithm
@@ -83,15 +84,33 @@ struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function},
strategy::S
param_estim::PE
additional_loss::AL
device::D
kwargs::K
end
function NNODE(chain, opt, init_params = nothing;
strategy = nothing,
autodiff = false, batch = true, param_estim = false, additional_loss = nothing, kwargs...)
autodiff = false, batch = true, param_estim = false,
additional_loss = nothing, device = cpu_device(), kwargs...)
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
NNODE(chain, opt, init_params, autodiff, batch,
strategy, param_estim, additional_loss, kwargs)
strategy, param_estim, additional_loss, device, kwargs)
end

@kernel function custom_broadcast!(f, du, @Const(out), @Const(p), @Const(t))
i = @index(Global, Linear)
@views @inbounds x = f(out[:, i], p, t[i])
du[:, i] .= x
end

gpu_broadcast = custom_broadcast!(CUDABackend())

function get_array_type(::LuxCUDADevice)
CuArray
end

function get_array_type(::LuxCPUDevice)
Array
end
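
The snippet below is a minimal usage sketch of the `device` keyword added in this PR, not code taken from the PR itself: the API is still WIP, the problem and optimizer setup are illustrative, and `gpu_device()` is assumed to come from LuxDeviceUtils with LuxCUDA loaded.

```julia
using NeuralPDE, Lux, OptimizationOptimisers, OrdinaryDiffEq
using LuxCUDA, LuxDeviceUtils   # gpu_device() returns a LuxCUDADevice when CUDA is functional

# Toy scalar ODE du/dt = -u, u(0) = 1, trained with NNODE.
f(u, p, t) = -u
prob = ODEProblem(f, 1.0f0, (0.0f0, 1.0f0))

chain = Lux.Chain(Lux.Dense(1, 16, tanh), Lux.Dense(16, 1))
opt = OptimizationOptimisers.Adam(0.01)

# `device = cpu_device()` is the default; passing a CUDA device moves the
# parameters and training points to the GPU via this PR's `get_array_type` dispatch.
alg = NNODE(chain, opt; batch = true, strategy = GridTraining(0.01f0),
    device = gpu_device())

sol = solve(prob, alg; verbose = false, maxiters = 2000)
```
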

"""
Expand All @@ -100,53 +119,41 @@ end
Internal struct, used for representing the ODE solution as a neural network in a form that respects boundary conditions, i.e.
`phi(t) = u0 + t*NN(t)`.
"""
mutable struct ODEPhi{C, T, U, S}
mutable struct ODEPhi{C, T, U, S, D}
chain::C
t0::T
u0::U
st::S
function ODEPhi(chain::Lux.AbstractExplicitLayer, t::Number, u0, st)
new{typeof(chain), typeof(t), typeof(u0), typeof(st)}(chain, t, u0, st)
device::D
function ODEPhi(chain::Lux.AbstractExplicitLayer, t0::Number, u0, st, device)
new{typeof(chain), typeof(t0), typeof(u0), typeof(st), typeof(device)}(
chain, t0, u0, st, device)
end
end
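
As context for the struct above: the trial solution is `phi(t) = u0 + (t - t0) * NN(t)`, so the initial condition holds by construction because the network term vanishes at `t = t0`. A standalone sketch of that transform (illustrative names only, not code from this PR):

```julia
using Lux, Random

chain  = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1))
ps, st = Lux.setup(Random.default_rng(), chain)

u0, t0 = 2.0f0, 0.0f0
NN(t)  = first(first(chain([t], ps, st)))   # scalar output of the Lux chain
phi(t) = u0 + (t - t0) * NN(t)              # same form ODEPhi wires around the chain

phi(t0) == u0   # true for any parameters ps
```
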

function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params)
function generate_phi_θ(
chain::Lux.AbstractExplicitLayer, t0, u0, init_params, device, p, param_estim)
θ, st = Lux.setup(Random.default_rng(), chain)
isnothing(init_params) && (init_params = θ)
ODEPhi(chain, t, u0, st), init_params
end

function (f::ODEPhi{C, T, U})(t::Number,
θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
y, st = f.chain(
adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), [t]), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 + (t - f.t0) * first(y)
end

function (f::ODEPhi{C, T, U})(t::AbstractVector,
θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
# Batch via data as row vectors
y, st = f.chain(
adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), t'), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t' .- f.t0) .* y
end

function (f::ODEPhi{C, T, U})(t::Number, θ) where {C <: Lux.AbstractExplicitLayer, T, U}
y, st = f.chain(
adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), [t]), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t .- f.t0) .* y
array_type = get_array_type(device)
init_params = if param_estim
ComponentArrays.ComponentArray(;
depvar = init_params, p = p)
else
ComponentArrays.ComponentArray(;
depvar = init_params)
end
u0_ = u0 isa Number ? u0 : array_type(u0)
ODEPhi(chain, t0, u0_, st, device), adapt(array_type, init_params)
end

function (f::ODEPhi{C, T, U})(t::AbstractVector,
θ) where {C <: Lux.AbstractExplicitLayer, T, U}
function (f::ODEPhi{C, T, U})(
t::AbstractVector, θ) where {C <: Lux.AbstractExplicitLayer, T, U}
# Batch via data as row vectors
y, st = f.chain(
adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), t'), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t' .- f.t0) .* y
@ignore_derivatives f.st = st
f.u0 .+ (t .- f.t0)' .* y
end

"""
Expand Down Expand Up @@ -190,34 +197,37 @@ Simple L2 inner loss at a time `t` with parameters `θ` of the neural network.
function inner_loss end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
p, param_estim::Bool) where {C, T, U <: Number}
p_ = param_estim ? θ.p : p
sum(abs2, ode_dfdx(phi, t, θ, autodiff) - f(phi(t, θ), p_, t))
p, param_estim::Bool) where {C, T, U}
array_type = get_array_type(phi.device)
p = param_estim ? θ.p : p
p = p isa SciMLBase.NullParameters ? p : array_type(p)
t = array_type([t])
dxdtguess = ode_dfdx(phi, t, θ, autodiff)
out = phi(t, θ)
fs = rhs(phi.device, f, phi.u0, out, p, t)
sum(abs2, dxdtguess .- fs)
end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
p, param_estim::Bool) where {C, T, U <: Number}
p_ = param_estim ? θ.p : p
p, param_estim::Bool) where {C, T, U}
array_type = get_array_type(phi.device)
t = array_type(t)
p = param_estim ? θ.p : p
p = p isa SciMLBase.NullParameters ? p : array_type(p)
out = phi(t, θ)
fs = reduce(hcat, [f(out[i], p_, t[i]) for i in axes(out, 2)])
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
fs = rhs(phi.device, f, phi.u0, out, p, t)
dxdtguess = ode_dfdx(phi, t, θ, autodiff)
sum(abs2, dxdtguess .- fs) / length(t)
end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
p, param_estim::Bool) where {C, T, U}
p_ = param_estim ? θ.p : p
sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), p_, t))
function rhs(::LuxCPUDevice, f, u0, out, p, t)
u0 isa Number ? reduce(hcat, [f(out[i], p, t[i]) for i in axes(out, 2)]) :
reduce(hcat, [f(out[:, i], p, t[i]) for i in axes(out, 2)])
end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
p, param_estim::Bool) where {C, T, U}
p_ = param_estim ? θ.p : p
out = Array(phi(t, θ))
arrt = Array(t)
fs = reduce(hcat, [f(out[:, i], p_, arrt[i]) for i in 1:size(out, 2)])
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
sum(abs2, dxdtguess .- fs) / length(t)
function rhs(::LuxCUDADevice, f, u0, out, p, t)
du = similar(out)
gpu_broadcast(f, du, out, p, t; workgroupsize = 64, ndrange = 100)
end
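
For readers unfamiliar with KernelAbstractions, the following self-contained sketch shows the column-wise launch pattern that `custom_broadcast!` above relies on. Names are illustrative and the CPU backend is used so it runs anywhere; with `CuArray` inputs the backend would be `CUDABackend()` and `ndrange` should equal the number of columns.

```julia
using KernelAbstractions

# Apply g to each column of u (paired with the matching entry of t), writing into du.
@kernel function colwise!(g, du, @Const(u), p, @Const(t))
    i = @index(Global, Linear)
    @views du[:, i] .= g(u[:, i], p, t[i])
end

u  = rand(Float32, 3, 8)        # 3 states × 8 time points
t  = rand(Float32, 8)
du = similar(u)
g(ui, p, ti) = p .* ui .+ ti    # toy right-hand side

backend = CPU()                 # swap for CUDABackend() on the GPU
kernel! = colwise!(backend)
kernel!(g, du, u, 2.0f0, t; ndrange = size(u, 2))
KernelAbstractions.synchronize(backend)
```
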

"""
Expand Down Expand Up @@ -323,8 +333,10 @@ struct NNODEInterpolation{T <: ODEPhi, T2}
phi::T
θ::T2
end
(f::NNODEInterpolation)(t, idxs::Nothing, ::Type{Val{0}}, p, continuity) = f.phi(t, f.θ)
(f::NNODEInterpolation)(t, idxs, ::Type{Val{0}}, p, continuity) = f.phi(t, f.θ)[idxs]
function (f::NNODEInterpolation)(t, idxs::Nothing, ::Type{Val{0}}, p, continuity)
vec(f.phi([t], f.θ))
end
(f::NNODEInterpolation)(t, idxs, ::Type{Val{0}}, p, continuity) = vec(f.phi([t], f.θ))[idxs]

function (f::NNODEInterpolation)(t::Vector, idxs::Nothing, ::Type{Val{0}}, p, continuity)
out = f.phi(t, f.θ)
Expand Down Expand Up @@ -358,36 +370,25 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
p = prob.p
t0 = tspan[1]
param_estim = alg.param_estim

#hidden layer
chain = alg.chain
opt = alg.opt
autodiff = alg.autodiff

#train points generation
init_params = alg.init_params
device = alg.device

!(chain isa Lux.AbstractExplicitLayer) &&
error("Only Lux.AbstractExplicitLayer neural networks are supported")
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
((eltype(eltype(init_params).types[1]) <: Complex ||
eltype(eltype(init_params).types[2]) <: Complex) &&
phi, init_params = generate_phi_θ(chain, t0, u0, init_params, device, p, param_estim)

(eltype(init_params) <: Complex &&
alg.strategy isa QuadratureTraining) &&
error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")

init_params = if alg.param_estim
ComponentArrays.ComponentArray(;
depvar = ComponentArrays.ComponentArray(init_params), p = prob.p)
else
ComponentArrays.ComponentArray(;
depvar = ComponentArrays.ComponentArray(init_params))
end

isinplace(prob) &&
throw(error("The NNODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))

try
phi(t0, init_params)
phi(get_array_type(device)([t0]), init_params)
catch err
if isa(err, DimensionMismatch)
throw(DimensionMismatch("Dimensions of the initial u0 and chain should match"))
@@ -473,10 +474,11 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
ts = [tspan[1], tspan[2]]
end

u = phi(ts, res.u)
if u0 isa Number
u = [first(phi(t, res.u)) for t in ts]
u = vec(u)
else
u = [phi(t, res.u) for t in ts]
u = [u[:, i] for i in 1:size(u, 2)]
end

sol = SciMLBase.build_solution(prob, alg, ts, u;
6 changes: 3 additions & 3 deletions test/NNODE_tests.jl
@@ -190,7 +190,7 @@ end
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
end
alg1 = NNODE(luxchain, opt, strategy = GridTraining(0.01),
additional_loss = additional_loss)
@@ -203,7 +203,7 @@ end
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
end
alg1 = NNODE(luxchain, opt, additional_loss = additional_loss)
sol1 = solve(prob, alg1, verbose = false, abstol = 1e-10, maxiters = 200)
@@ -215,7 +215,7 @@ end
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
end
alg1 = NNODE(luxchain, opt, strategy = StochasticTraining(1000),
additional_loss = additional_loss)